diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 57bea7a9..b20d3e40 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -17,13 +17,15 @@ jobs: steps: - name: Cleanup previous run - run: docker run --rm -v ${{ github.workspace }}:/workspace alpine sh -c "rm -rf /workspace/test/aperturedb" + run: | + docker run --rm -v ${{ github.workspace }}:/workspace alpine sh -c "rm -rf /workspace/test/aperturedb /workspace/test/seattle-* /workspace/test/output /workspace/docker/tests" || true + sudo rm -rf test/aperturedb test/seattle-* test/output docker/tests || true continue-on-error: true - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Login to DockerHub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USER }} password: ${{ secrets.DOCKER_PASS }} @@ -80,13 +82,15 @@ jobs: steps: - name: Cleanup previous run - run: docker run --rm -v ${{ github.workspace }}:/workspace alpine sh -c "rm -rf /workspace/test/aperturedb" + run: | + docker run --rm -v ${{ github.workspace }}:/workspace alpine sh -c "rm -rf /workspace/test/aperturedb /workspace/test/seattle-* /workspace/test/output /workspace/docker/tests" || true + sudo rm -rf test/aperturedb test/seattle-* test/output docker/tests || true continue-on-error: true - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Login to DockerHub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USER }} password: ${{ secrets.DOCKER_PASS }} @@ -96,7 +100,6 @@ jobs: with: credentials_json: ${{ secrets.GCP_SERVICE_ACCOUNT_KEY }} project_id: ${{ secrets.GCP_SERVICE_ACCOUNT_PROJECT_ID }} - - name: Set up Cloud SDK uses: google-github-actions/setup-gcloud@v2 - name: Build tests on pytorch GPU image diff --git a/aperturedb/DaskManager.py b/aperturedb/DaskManager.py index 6a79d0c5..8e7e167b 100644 --- a/aperturedb/DaskManager.py +++ b/aperturedb/DaskManager.py @@ -48,8 +48,8 @@ def __del__(self): self._client.close() self._cluster.close() - def run(self, QueryClass: type[ParallelQuery], client: Connector, generator, batchsize, stats, dry_run=False): - def process(df, host, port, use_ssl, ca_cert, verify_hostname, session, connector_type, dry_run): + def run(self, QueryClass: type['ParallelQuery'], client: Connector, generator, batchsize, stats, dry_run=False, **kwargs): + def process(df, host, port, use_ssl, ca_cert, verify_hostname, session, connector_type, dry_run, kwargs_dict): metrics = Stats() # Dask reads data in partitions, and the first partition is of 2 rows, with all # values as 'foo'. This is for sampling the column names and types. Should not process @@ -88,7 +88,7 @@ def process(df, host, port, use_ssl, ca_cert, verify_hostname, session, connecto blobs_relative_to_csv=generator.blobs_relative_to_csv) loader.query(generator=data, batchsize=len( - slice), numthreads=1, stats=False) + slice), numthreads=1, stats=False, **kwargs_dict) count += 1 metrics.times_arr.extend(loader.times_arr) metrics.error_counter += loader.error_counter @@ -109,7 +109,8 @@ def process(df, host, port, use_ssl, ca_cert, verify_hostname, session, connecto client.config.verify_hostname, client.shared_data.session, type(client), - dry_run) + dry_run, + kwargs) computation = computation.persist() if stats: progress(computation) diff --git a/aperturedb/ParallelQuery.py b/aperturedb/ParallelQuery.py index cd3a2180..211d5ced 100644 --- a/aperturedb/ParallelQuery.py +++ b/aperturedb/ParallelQuery.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import List, Tuple +from typing import List, Tuple, Optional from aperturedb import Parallelizer import numpy as np import logging @@ -174,7 +174,7 @@ def process_responses(requests, input_blobs, responses, output_blobs): parameter_count = len(inspect.signature( response_handler).parameters) if parameter_count < 4 or parameter_count > 5: - raise Exception("Bad Signature for response_handler :" + raise Exception("Bad Signature for response_handler : " f"expected 6 > args > 3, got {parameter_count}") if parameter_count == 4: indexless_handler = response_handler @@ -241,35 +241,136 @@ def worker(self, thid: int, generator, start: int, end: int, run_event) -> None: # A new connection will be created for each thread client = self.client.clone() - try: - total_batches = (end - start) // self.batchsize - - if (end - start) % self.batchsize > 0: - total_batches += 1 - - logger.info( - f"Worker {thid} executing {total_batches} batches, {self.stats=}") - executed_batches = 0 - for i in range(total_batches): - if not run_event.is_set(): - break - batch_start = start + i * self.batchsize - batch_end = min(batch_start + self.batchsize, end) - - try: - self.do_batch(client, batch_start, - generator[batch_start:batch_end]) - except Exception as e: - logger.exception(e) - logger.warning( - f"Worker {thid} failed to execute batch {i}: [{batch_start},{batch_end}]") - with self.error_counter_lock: - self.error_counter += 1 + max_bytes = getattr(self, "max_bytes_per_batch", None) - executed_batches += 1 - if self.stats: - self.pb.update(batch_end - batch_start) - logger.info(f"Worker {thid} executed {executed_batches} batches") + try: + if max_bytes is not None and max_bytes > 0: + logger.info( + f"Worker {thid} executing dynamically sized batches (max {max_bytes} bytes), {self.stats=}") + current_batch = [] + current_bytes = 0 + batch_start = start + executed_batches = 0 + + for i in range(start, end): + if not run_event.is_set(): + break + + try: + item = generator[i] + + # Estimate item size (mostly blobs + some json overhead) + item_bytes = len(str(item[0])) + for blob in item[1]: + if isinstance(blob, (bytes, bytearray, memoryview)): + item_bytes += len(blob) + else: + item_bytes += 100 + except Exception as e: + logger.exception(e) + logger.warning( + f"Worker {thid} failed to retrieve/estimate item {i}") + with self.error_counter_lock: + self.error_counter += 1 + if self.stats: + self.pb.update(1) + + if len(current_batch) > 0: + try: + self.do_batch( + client, batch_start, current_batch) + executed_batches += 1 + except Exception as e2: + logger.exception(e2) + logger.warning( + f"Worker {thid} failed to execute dynamic batch starting at {batch_start}") + with self.error_counter_lock: + self.error_counter += 1 + + if self.stats: + self.pb.update(len(current_batch)) + + current_batch = [] + current_bytes = 0 + + continue + + if len(current_batch) == 0: + batch_start = i + + if len(current_batch) > 0 and (current_bytes + item_bytes > max_bytes or len(current_batch) >= self.batchsize): + try: + self.do_batch(client, batch_start, current_batch) + executed_batches += 1 + except Exception as e: + logger.exception(e) + logger.warning( + f"Worker {thid} failed to execute dynamic batch starting at {batch_start}") + with self.error_counter_lock: + self.error_counter += 1 + + if self.stats: + self.pb.update(len(current_batch)) + + current_batch = [] + current_bytes = 0 + batch_start = i + + if len(current_batch) == 0 and item_bytes > max_bytes: + logger.warning( + f"Worker {thid} executing batch starting at {batch_start} that exceeds max_bytes_per_batch: {item_bytes} > {max_bytes}") + + current_batch.append(item) + current_bytes += item_bytes + + # Remainder + if len(current_batch) > 0 and run_event.is_set(): + try: + self.do_batch(client, batch_start, current_batch) + executed_batches += 1 + except Exception as e: + logger.exception(e) + logger.warning( + f"Worker {thid} failed to execute dynamic batch remainder starting at {batch_start}") + with self.error_counter_lock: + self.error_counter += 1 + + if self.stats: + self.pb.update(len(current_batch)) + + logger.info( + f"Worker {thid} finished executing {executed_batches} dynamically sized batches") + + else: + total_batches = (end - start) // self.batchsize + + if (end - start) % self.batchsize > 0: + total_batches += 1 + + logger.info( + f"Worker {thid} executing {total_batches} batches, {self.stats=}") + executed_batches = 0 + for i in range(total_batches): + if not run_event.is_set(): + break + batch_start = start + i * self.batchsize + batch_end = min(batch_start + self.batchsize, end) + + try: + self.do_batch(client, batch_start, + generator[batch_start:batch_end]) + except Exception as e: + logger.exception(e) + logger.warning( + f"Worker {thid} failed to execute batch {i}: [{batch_start},{batch_end}]") + with self.error_counter_lock: + self.error_counter += 1 + + executed_batches += 1 + if self.stats: + self.pb.update(batch_end - batch_start) + msg = f"Worker {thid} executed {executed_batches} batches" + logger.info(msg) finally: # Explicitly close the connection to avoid exhausting server connection limits if client is not self.client and hasattr(client, 'close') and callable(client.close): @@ -287,7 +388,7 @@ def get_succeeded_commands(self) -> int: return sum(stat["succeeded_commands"] for stat in self.actual_stats) - def query(self, generator, batchsize: int = 1, numthreads: int = 4, stats: bool = False) -> None: + def query(self, generator, batchsize: int = 1, numthreads: int = 4, stats: bool = False, max_bytes_per_batch: Optional[int] = None) -> None: """ This function takes as input the data to be executed in specified number of threads. The generator yields a tuple : (array of commands, array of blobs) @@ -296,7 +397,9 @@ def query(self, generator, batchsize: int = 1, numthreads: int = 4, stats: bool batchsize (int, optional): Number of queries per transaction. Defaults to 1. numthreads (int, optional): Number of parallel workers. Defaults to 4. stats (bool, optional): Show statistics at end of ingestion. Defaults to False. + max_bytes_per_batch (int, optional): The maximum number of bytes allowed per batch. Acts as an additional cap on batch size; batches will be split if they reach either this byte limit or the `batchsize` item limit. If None or <= 0, dynamic batching is disabled and only `batchsize` is used. Default is None. """ + self.max_bytes_per_batch = max_bytes_per_batch use_dask = hasattr(generator, "use_dask") and generator.use_dask if use_dask: @@ -308,7 +411,7 @@ def query(self, generator, batchsize: int = 1, numthreads: int = 4, stats: bool if use_dask: results, self.total_actions_time = self.daskManager.run( - self.__class__, self.client, generator, batchsize, stats=stats, dry_run=self.dry_run) + self.__class__, self.client, generator, batchsize, stats=stats, dry_run=self.dry_run, max_bytes_per_batch=self.max_bytes_per_batch) self.actual_stats = [] for result in results: if result is not None: @@ -339,10 +442,8 @@ def query(self, generator, batchsize: int = 1, numthreads: int = 4, stats: bool logger.error( f"Could not determine query structure from:\n{generator[0]}") logger.error(type(generator[0])) - logger.info( - f"Commands per query = {self.commands_per_query}, " - f"Blobs per query = {self.blobs_per_query}" - ) + logger.info(f"Commands per query = {self.commands_per_query}, " + f"Blobs per query = {self.blobs_per_query}") self.batched_run(generator, batchsize, numthreads, stats) def print_stats(self) -> None: diff --git a/test/run_test_container.sh b/test/run_test_container.sh index 6ea9b0e0..460b4bc5 100755 --- a/test/run_test_container.sh +++ b/test/run_test_container.sh @@ -34,6 +34,7 @@ function run_aperturedb_instance(){ docker network create ${TAG}_host_default GATEWAY=$(docker network inspect ${TAG}_host_default | jq -r .[0].IPAM.Config[0].Gateway) GATEWAY=$GATEWAY RUNNER_NAME=$TAG docker compose -f docker-compose.yml up -d + if [[ "$TAG" == *_non_http ]]; then PORT=$(RUNNER_NAME=$TAG docker compose -f docker-compose.yml port lenz 55551 | awk -F: '{print $NF}') elif [[ "$TAG" == *_http ]]; then @@ -101,17 +102,21 @@ wait_for_stack() { local lenz_ready=0 local nginx_ready=0 - if docker run --rm --network=${network} curlimages/curl:latest \ - nc -z -w 2 lenz 58085 >/dev/null 2>&1; then + if docker run --rm --network=${network} --entrypoint nc curlimages/curl:latest \ + -z -w 2 lenz 58085 >/dev/null 2>&1; then lenz_ready=1 fi - if [[ "$tag" == *"_http" ]]; then + if [[ "$tag" == *"_non_http" ]]; then + if [ "$lenz_ready" -eq 1 ]; then + echo "Stack ${tag} is ready after ${elapsed}s" + return 0 + fi + elif [[ "$tag" == *"_http" ]]; then if docker run --rm --network=${network} curlimages/curl:latest \ - -sS -o /dev/null -m 2 http://nginx:80/ >/dev/null 2>&1; then + -fsS -o /dev/null -m 2 http://nginx:80/ >/dev/null 2>&1; then nginx_ready=1 fi - if [ "$lenz_ready" -eq 1 ] && [ "$nginx_ready" -eq 1 ]; then echo "Stack ${tag} is ready after ${elapsed}s" return 0 diff --git a/test/test_Parallel.py b/test/test_Parallel.py index 488cc3a0..b2faba27 100644 --- a/test/test_Parallel.py +++ b/test/test_Parallel.py @@ -164,6 +164,155 @@ def mock_do_batch(client, batch_start, data): assert closed_count[0] == 1 +class GeneratorWithLargeBlobs(Subscriptable): + def __init__(self, elements=10, blob_size=100) -> None: + super().__init__() + self.elements = elements + self.blob_size = blob_size + + def __len__(self): + return self.elements + + def getitem(self, subscript): + query = [{"FindBlob": {}}] + blobs = [b"0" * self.blob_size] + return query, blobs + + +class GeneratorWithSmallImages(Subscriptable): + def __init__(self, elements=10, blob_size=10) -> None: + super().__init__() + self.elements = elements + self.blob_size = blob_size + + def __len__(self): + return self.elements + + def getitem(self, subscript): + query = [{"AddImage": {}}] + blobs = [b"0" * self.blob_size] + return query, blobs + + +class MockClient: + def __init__(self): + from types import SimpleNamespace + self.config = SimpleNamespace(host="localhost", port=55555, use_ssl=False, + verify_hostname=False, username="admin", password="password") + + def clone(self): + return self + + def query(self, q, b): + self.queries.append(q) + return ([{"FindBlob": {"status": 0}} for _ in range(len(q))], []) + + def last_query_ok(self): + return True + + def get_last_query_time(self): + return 0 + + +def test_dynamic_batching(): + db = MockClient() + db.queries = [] + + # 10 elements, 100 bytes each + generator = GeneratorWithLargeBlobs(10, 100) + querier = ParallelQuery(db) + db.queries = [] + + # limit to 150 bytes -> should process 1 element per batch despite batchsize=5 + querier.query(generator, batchsize=5, numthreads=1, + max_bytes_per_batch=150) + + # It should have succeeded in processing all queries + assert querier.get_succeeded_queries() == 10 + assert len(db.queries) == 10 + for q in db.queries: + assert len(q) == 1 + + +def test_dynamic_batching_oversized_item(): + db = MockClient() + db.queries = [] + + # 10 elements, 100 bytes each + generator = GeneratorWithLargeBlobs(10, 100) + querier = ParallelQuery(db) + db.queries = [] + + # limit to 50 bytes -> item size (100) > max_bytes (50). Should log warning and process 1 per batch. + querier.query(generator, batchsize=5, numthreads=1, max_bytes_per_batch=50) + + # It should have succeeded in processing all queries + assert querier.get_succeeded_queries() == 10 + assert len(db.queries) == 10 + for q in db.queries: + assert len(q) == 1 + + +def test_dynamic_batching_add_image(): + db = MockClient() + db.queries = [] + + # 10 elements, 10 bytes each + generator = GeneratorWithSmallImages(10, 10) + querier = ParallelQuery(db) + db.queries = [] + + # Expected item size: len(str([{"AddImage": {}}])) -> 18 + 10 bytes blob = 28 bytes. + # With a limit of 100 bytes, we should be able to fit 3 items per batch (3 * 28 = 84 bytes). + # Since batchsize=5 is larger than 3, the max_bytes_per_batch limit will be the bottleneck, + # producing 3, 3, 3, 1 item batches. + querier.query(generator, batchsize=5, numthreads=1, + max_bytes_per_batch=100) + + # It should have succeeded in processing all queries + assert querier.get_succeeded_queries() == 10 + assert len(db.queries) == 4 + for q in db.queries[:-1]: + assert len(q) == 3 + assert len(db.queries[-1]) == 1 + + +def test_dynamic_batching_add_image_variable_sizes(): + db = MockClient() + db.queries = [] + + # 1. Smaller images (2 bytes each) -> larger batches + generator_small = GeneratorWithSmallImages(10, 2) + querier_small = ParallelQuery(db) + db.queries = [] + + # Expected item size: 18 + 2 = 20 bytes. + # Max bytes = 100 -> 100 // 20 = 5 items per batch. + querier_small.query(generator_small, batchsize=10, + numthreads=1, max_bytes_per_batch=100) + + assert querier_small.get_succeeded_queries() == 10 + assert len(db.queries) == 2 + for q in db.queries: + assert len(q) == 5 + + # 2. Larger images (32 bytes each) -> smaller batches + generator_large = GeneratorWithSmallImages(10, 32) + db.queries = [] + querier_large = ParallelQuery(db) + db.queries = [] + + # Expected item size: 18 + 32 = 50 bytes. + # Max bytes = 100 -> 100 // 50 = 2 items per batch. + querier_large.query(generator_large, batchsize=10, + numthreads=1, max_bytes_per_batch=100) + + assert querier_large.get_succeeded_queries() == 10 + assert len(db.queries) == 5 + for q in db.queries: + assert len(q) == 2 + + def test_dask_dry_run(db: Connector): from aperturedb.ParallelLoader import ParallelLoader from aperturedb.EntityDataCSV import EntityDataCSV