diff --git a/aperturedb/ParallelLoader.py b/aperturedb/ParallelLoader.py index 0d7efb30..de30e0f1 100644 --- a/aperturedb/ParallelLoader.py +++ b/aperturedb/ParallelLoader.py @@ -180,7 +180,7 @@ def query_setup(self, generator: Subscriptable) -> None: logger.warning( f"Failed to create index for {connection_class}.{property_name}") - def ingest(self, generator: Subscriptable, batchsize: int = 1, numthreads: int = 4, stats: bool = False) -> None: + def ingest(self, generator: Subscriptable, batchsize: int = 1, numthreads: int = 4, stats: bool = False, transformers: list = None) -> None: """ **Method to ingest data into the database** @@ -189,10 +189,13 @@ def ingest(self, generator: Subscriptable, batchsize: int = 1, numthreads: int = batchsize (int, optional): The size of batch to be used. Defaults to 1. numthreads (int, optional): Number of workers to create. Defaults to 4. stats (bool, optional): If stats need to be presented, realtime. Defaults to False. + transformers (list, optional): A list of Transformer classes to apply to the data. Defaults to None. """ logger.info( f"Starting ingestion with batchsize={batchsize}, numthreads={numthreads}") - self.query(generator, batchsize, numthreads, stats) + + self.query(generator, batchsize, numthreads, + stats, transformers=transformers) def print_stats(self) -> None: diff --git a/aperturedb/ParallelQuery.py b/aperturedb/ParallelQuery.py index cd3a2180..4022ced2 100644 --- a/aperturedb/ParallelQuery.py +++ b/aperturedb/ParallelQuery.py @@ -287,7 +287,7 @@ def get_succeeded_commands(self) -> int: return sum(stat["succeeded_commands"] for stat in self.actual_stats) - def query(self, generator, batchsize: int = 1, numthreads: int = 4, stats: bool = False) -> None: + def query(self, generator, batchsize: int = 1, numthreads: int = 4, stats: bool = False, transformers: list = None) -> None: """ This function takes as input the data to be executed in specified number of threads. The generator yields a tuple : (array of commands, array of blobs) @@ -295,17 +295,34 @@ def query(self, generator, batchsize: int = 1, numthreads: int = 4, stats: bool generator (_type_): The class that generates the queries to be executed. batchsize (int, optional): Number of queries per transaction. Defaults to 1. numthreads (int, optional): Number of parallel workers. Defaults to 4. - stats (bool, optional): Show statistics at end of ingestion. Defaults to False. + stats (bool, optional): Show statistics at end of query execution. Defaults to False. + transformers (list, optional): A list of Transformer classes to apply to the data. Defaults to None. """ + from aperturedb.transformers.transformer import Transformer + + if transformers is not None: + if not isinstance(transformers, (list, tuple)): + transformers = [transformers] + for t in transformers: + if not (isinstance(t, type) and issubclass(t, Transformer)) and not callable(t): + raise TypeError( + "Each transformer must be a subclass of Transformer or a callable.") + use_dask = hasattr(generator, "use_dask") and generator.use_dask if use_dask: + if transformers or isinstance(generator, Transformer): + raise ValueError("Transformers cannot be used with Dask mode.") self._reset(batchsize=batchsize, numthreads=numthreads) self.daskManager = DaskManager(num_workers=numthreads) if hasattr(self, "query_setup"): self.query_setup(generator) + if transformers and len(generator) > 0: + for transformer in transformers: + generator = transformer(generator, client=self.client) + if use_dask: results, self.total_actions_time = self.daskManager.run( self.__class__, self.client, generator, batchsize, stats=stats, dry_run=self.dry_run) diff --git a/aperturedb/transformers/transformer.py b/aperturedb/transformers/transformer.py index 5370364f..d0813788 100644 --- a/aperturedb/transformers/transformer.py +++ b/aperturedb/transformers/transformer.py @@ -1,6 +1,7 @@ from aperturedb.Subscriptable import Subscriptable from aperturedb.CommonLibrary import create_connector from aperturedb.Utils import Utils +import threading import logging logger = logging.getLogger(__name__) @@ -59,6 +60,7 @@ def __init__(self, data: Subscriptable, client=None, **kwargs) -> None: self._blob_index = [] self._add_image_index = [] self._client = client + self._thread_local = threading.local() bc = 0 for i, c in enumerate(x[0]): @@ -82,9 +84,35 @@ def __len__(self): return len(self.data) def get_client(self): - if self._client is None: - self._client = create_connector() - return self._client + if not hasattr(self, "_thread_local"): + self._thread_local = threading.local() + + if not hasattr(self._thread_local, "client"): + if self._client is not None: + if hasattr(self._client, "clone"): + self._thread_local.client = self._client.clone() + else: + self._thread_local.client = self._client + else: + self._thread_local.client = create_connector() + return self._thread_local.client def get_utils(self): return Utils(self.get_client()) + + def __getattr__(self, name): + # Delegate specific attribute access to the underlying data (generator) + # to preserve behaviors from original generators (like CSVParser). + allowed_attributes = { + "use_dask", + "strict_response_validation", + "response_handler", + "error_handler", + "blobs_relative_to_csv", + "commands_per_query", + "blobs_per_query" + } + if name in allowed_attributes and "data" in self.__dict__: + return getattr(self.data, name) + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'") diff --git a/test/test_Parallel.py b/test/test_Parallel.py index 488cc3a0..c632ab3c 100644 --- a/test/test_Parallel.py +++ b/test/test_Parallel.py @@ -1,12 +1,26 @@ import logging import random +import pytest from aperturedb.Connector import Connector from aperturedb.ParallelQuery import ParallelQuery +from aperturedb.ParallelLoader import ParallelLoader from aperturedb.Subscriptable import Subscriptable +from aperturedb.transformers.transformer import Transformer logger = logging.getLogger(__name__) + +class DummyTransformer(Transformer): + def __init__(self, generator, client=None): + super().__init__(generator, client=client) + if client is None: + pytest.fail("Client was not passed to transformer!") + + def getitem(self, idx): + query, blobs = self.data[idx] + return query, blobs + # Tests for parallel which don't involve data. @@ -24,7 +38,7 @@ def getitem(self, subscript): query = [] blobs = [] for i in range(self.commands_per_query): - if random.randint(0, 100) <= (self.error_pct * 100): + if random.random() < self.error_pct: query.append({ "BadCommand": { } @@ -114,6 +128,86 @@ def test_dictResponseHandling(self): print(e) raise + def test_transformers(self, db: Connector): + """ + Verifies that transformers are correctly applied. + """ + elements = 10 + generator = GeneratorWithErrors(elements=elements, error_pct=0) + + loader = ParallelLoader(db) + loader.ingest(generator, batchsize=2, numthreads=2, + stats=False, transformers=[DummyTransformer]) + + assert loader.get_succeeded_queries() > 0 + + def test_transformers_rejects_dask(self, db: Connector): + elements = 10 + generator = GeneratorWithErrors(elements=elements, error_pct=0) + generator.use_dask = True + + loader = ParallelLoader(db) + with pytest.raises(ValueError, match="Transformers cannot be used with Dask"): + loader.ingest(generator, batchsize=2, numthreads=2, + stats=False, transformers=[DummyTransformer]) + + # Test manual wrapping also gets rejected + transformer = DummyTransformer(generator, client=db) + with pytest.raises(ValueError, match="Transformers cannot be used with Dask"): + loader.ingest(transformer, batchsize=2, numthreads=2, stats=False) + + def test_transformers_equivalence(self, db: Connector): + """ + Verifies that using transformers parameter is equivalent to manual wrapping. + """ + elements = 10 + + # Manual wrapping + generator1 = GeneratorWithErrors(elements=elements, error_pct=0) + transformer1 = DummyTransformer(generator1, client=db) + loader1 = ParallelLoader(db) + loader1.ingest(transformer1, batchsize=2, numthreads=2, stats=False) + + # transformers parameter + generator2 = GeneratorWithErrors(elements=elements, error_pct=0) + loader2 = ParallelLoader(db) + loader2.ingest(generator2, batchsize=2, numthreads=2, + stats=False, transformers=[DummyTransformer]) + + assert loader1.get_succeeded_queries() == loader2.get_succeeded_queries() + assert loader1.get_succeeded_queries() > 0 + + def test_query_transformers(self, db: Connector): + """ + Verifies that transformers are correctly applied when calling ParallelQuery.query(). + """ + elements = 10 + generator = GeneratorWithErrors(elements=elements, error_pct=0) + + querier = ParallelQuery(db) + querier.query(generator, batchsize=2, numthreads=2, + stats=False, transformers=[DummyTransformer]) + + assert querier.get_succeeded_queries() > 0 + + def test_transformers_single_and_invalid(self, db: Connector): + """ + Verifies that passing a single transformer works and invalid inputs raise TypeError. + """ + elements = 10 + generator = GeneratorWithErrors(elements=elements, error_pct=0) + loader = ParallelLoader(db) + + # Single transformer + loader.ingest(generator, batchsize=2, numthreads=2, + stats=False, transformers=DummyTransformer) + assert loader.get_succeeded_queries() > 0 + + # Invalid input + with pytest.raises(TypeError, match="must be a subclass of Transformer or a callable"): + loader.ingest(generator, batchsize=2, numthreads=2, + stats=False, transformers="invalid_transformer") + def test_parallel_query_worker_closes_connection(self, db, monkeypatch): from aperturedb.QueryGenerator import QueryGenerator import threading