diff --git a/aperturedb/ParallelLoader.py b/aperturedb/ParallelLoader.py index 6f2769fd..44d588a7 100644 --- a/aperturedb/ParallelLoader.py +++ b/aperturedb/ParallelLoader.py @@ -164,7 +164,7 @@ def query_setup(self, generator: Subscriptable) -> None: logger.warning( f"Failed to create index for {connection_class}.{property_name}") - def ingest(self, generator: Subscriptable, batchsize: int = 1, numthreads: int = 4, stats: bool = False) -> None: + def ingest(self, generator: Subscriptable, batchsize: int = 1, numthreads: int = 4, stats: bool = False, max_bytes_per_batch: int = None) -> None: """ **Method to ingest data into the database** @@ -173,10 +173,11 @@ def ingest(self, generator: Subscriptable, batchsize: int = 1, numthreads: int = batchsize (int, optional): The size of batch to be used. Defaults to 1. numthreads (int, optional): Number of workers to create. Defaults to 4. stats (bool, optional): If stats need to be presented, realtime. Defaults to False. + max_bytes_per_batch (int, optional): Automatic batch sizing to keep encoded queries under this limit. """ logger.info( f"Starting ingestion with batchsize={batchsize}, numthreads={numthreads}") - self.query(generator, batchsize, numthreads, stats) + self.query(generator, batchsize, numthreads, stats, max_bytes_per_batch=max_bytes_per_batch) def print_stats(self) -> None: diff --git a/aperturedb/ParallelQuery.py b/aperturedb/ParallelQuery.py index 31fb840e..64bee3b0 100644 --- a/aperturedb/ParallelQuery.py +++ b/aperturedb/ParallelQuery.py @@ -218,31 +218,90 @@ def worker(self, thid: int, generator, start: int, end: int, run_event) -> None: # A new connection will be created for each thread client = self.client.clone() - total_batches = (end - start) // self.batchsize - - if (end - start) % self.batchsize > 0: - total_batches += 1 - - logger.info( - f"Worker {thid} executing {total_batches} batches, {self.stats=}") - for i in range(total_batches): - if not run_event.is_set(): - break - batch_start = start + i * self.batchsize - batch_end = min(batch_start + self.batchsize, end) - - try: - self.do_batch(client, batch_start, - generator[batch_start:batch_end]) - except Exception as e: - logger.exception(e) - logger.warning( - f"Worker {thid} failed to execute batch {i}: [{batch_start},{batch_end}]") - self.error_counter += 1 + max_bytes = getattr(self, "max_bytes_per_batch", None) + + if max_bytes is not None and max_bytes > 0: + logger.info( + f"Worker {thid} executing dynamically sized batches (max {max_bytes} bytes), {self.stats=}") + current_batch = [] + current_bytes = 0 + batch_start = start - if self.stats: - self.pb.update(batch_end - batch_start) - logger.info(f"Worker {thid} executed {total_batches} batches") + for i in range(start, end): + if not run_event.is_set(): + break + + item = generator[i] + + # Estimate item size (mostly blobs + some json overhead) + item_bytes = 0 + for blob in item[1]: + if isinstance(blob, bytes): + item_bytes += len(blob) + else: + item_bytes += len(str(blob)) + + if len(current_batch) > 0 and (current_bytes + item_bytes > max_bytes): + try: + self.do_batch(client, batch_start, current_batch) + except Exception as e: + logger.exception(e) + logger.warning( + f"Worker {thid} failed to execute dynamic batch starting at {batch_start}") + self.error_counter += 1 + + if self.stats: + self.pb.update(len(current_batch)) + + current_batch = [] + current_bytes = 0 + batch_start = i + + current_batch.append(item) + current_bytes += item_bytes + + # Remainder + if len(current_batch) > 0 and run_event.is_set(): + try: + self.do_batch(client, batch_start, current_batch) + except Exception as e: + logger.exception(e) + logger.warning( + f"Worker {thid} failed to execute dynamic batch remainder starting at {batch_start}") + self.error_counter += 1 + + if self.stats: + self.pb.update(len(current_batch)) + + logger.info( + f"Worker {thid} finished executing dynamically sized batches") + + else: + total_batches = (end - start) // self.batchsize + + if (end - start) % self.batchsize > 0: + total_batches += 1 + + logger.info( + f"Worker {thid} executing {total_batches} batches, {self.stats=}") + for i in range(total_batches): + if not run_event.is_set(): + break + batch_start = start + i * self.batchsize + batch_end = min(batch_start + self.batchsize, end) + + try: + self.do_batch(client, batch_start, + generator[batch_start:batch_end]) + except Exception as e: + logger.exception(e) + logger.warning( + f"Worker {thid} failed to execute batch {i}: [{batch_start},{batch_end}]") + self.error_counter += 1 + + if self.stats: + self.pb.update(batch_end - batch_start) + logger.info(f"Worker {thid} executed {total_batches} batches") def get_objects_existed(self) -> int: return sum([stat["objects_existed"] @@ -256,7 +315,8 @@ def get_succeeded_commands(self) -> int: return sum([stat["succeeded_commands"] for stat in self.actual_stats]) - def query(self, generator, batchsize: int = 1, numthreads: int = 4, stats: bool = False) -> None: + def query(self, generator, batchsize: int = 1, numthreads: int = 4, stats: bool = False, max_bytes_per_batch: int = None) -> None: + self.max_bytes_per_batch = max_bytes_per_batch """ This function takes as input the data to be executed in specified number of threads. The generator yields a tuple : (array of commands, array of blobs)