diff --git a/backend/onyx/document_index/opensearch/opensearch_document_index.py b/backend/onyx/document_index/opensearch/opensearch_document_index.py index 140a64c7993..d62e0fcbeb4 100644 --- a/backend/onyx/document_index/opensearch/opensearch_document_index.py +++ b/backend/onyx/document_index/opensearch/opensearch_document_index.py @@ -758,7 +758,7 @@ def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None: except BulkIndexError as e: # There are several reasons why this might be raised, but the # most likely one is if the deletion has not had enough time to - # propogate throughout the index, in which case this would be + # propagate throughout the index, in which case this would be # raised with some form of "version_conflict_engine_exception # version conflict, document already exists" messaging. # Refresh the index and try one more time. We do not refresh diff --git a/backend/onyx/natural_language_processing/search_nlp_models.py b/backend/onyx/natural_language_processing/search_nlp_models.py index 06e8538ce46..1f74ea4fef4 100644 --- a/backend/onyx/natural_language_processing/search_nlp_models.py +++ b/backend/onyx/natural_language_processing/search_nlp_models.py @@ -6,7 +6,6 @@ from collections.abc import Callable from concurrent.futures import as_completed from concurrent.futures import ThreadPoolExecutor -from functools import partial from functools import wraps from types import TracebackType from typing import Any @@ -97,7 +96,7 @@ def _get_or_create_event_loop() -> asyncio.AbstractEventLoop: which was causing memory leaks. Instead, each thread reuses the same loop. Returns: - asyncio.AbstractEventLoop: The thread-local event loop + asyncio.AbstractEventLoop: The thread-local event loop. """ if ( not hasattr(_thread_local, "loop") @@ -706,8 +705,6 @@ def __init__( async def _make_direct_api_call( self, embed_request: EmbedRequest, - tenant_id: str | None = None, # noqa: ARG002 - request_id: str | None = None, # noqa: ARG002 ) -> EmbedResponse: """Make direct API call to cloud provider, bypassing model server.""" if self.provider_type is None: @@ -846,15 +843,16 @@ def _batch_encode_texts( request_id: str | None = None, ) -> list[Embedding]: text_batches = batch_list(texts, batch_size) + num_of_batches = len(text_batches) - logger.debug(f"Encoding {len(texts)} texts in {len(text_batches)} batches") + logger.debug(f"Encoding {len(texts)} texts in {num_of_batches} batches.") embeddings: list[Embedding] = [] @_cleanup_thread_local def process_batch( batch_idx: int, - batch_len: int, + num_of_batches: int, text_batch: list[str], tenant_id: str | None = None, request_id: str | None = None, @@ -882,74 +880,96 @@ def process_batch( ) start_time = time.monotonic() - - # Route between direct API calls and model server calls + response: EmbedResponse + # Route between direct API calls and model server calls. if self.provider_type is not None: - # For API providers, make direct API call - # Use thread-local event loop to prevent memory leaks from creating - # thousands of event loops during batch processing - loop = _get_or_create_event_loop() - response = loop.run_until_complete( - self._make_direct_api_call( - embed_request, tenant_id=tenant_id, request_id=request_id + # For API providers, make direct API call. + try: + # Detect if this code is being called from an event loop + # or not. + asyncio.get_running_loop() + except RuntimeError: + # This code is being called synchronously, safe to use + # run_until_complete. 
+ # Use thread-local event loop to prevent memory leaks + # from creating thousands of event loops during batch + # processing. + loop = _get_or_create_event_loop() + response = loop.run_until_complete( + self._make_direct_api_call(embed_request) ) - ) + else: + # This code is being called from an event loop, can't + # block on it from the same thread without deadlocking. + # Run in a separate thread with its own loop. + with ThreadPoolExecutor(max_workers=1) as pool: + response = cast( + EmbedResponse, + pool.submit( + asyncio.run, self._make_direct_api_call(embed_request) + ).result(), + ) else: - # For local models, use model server + # For local models, use model server. response = self._make_model_server_request( embed_request, tenant_id=tenant_id, request_id=request_id ) - end_time = time.monotonic() - - processing_time = end_time - start_time + processing_time = time.monotonic() - start_time logger.debug( - f"EmbeddingModel.process_batch: Batch {batch_idx}/{batch_len} processing time: {processing_time:.2f} seconds" + f"process_batch: Batch idx {batch_idx}, total num {num_of_batches}, processing time: {processing_time:.2f}s." ) return batch_idx, response.embeddings - # only multi thread if: - # 1. num_threads is greater than 1 - # 2. we are using an API-based embedding model (provider_type is not None) - # 3. there are more than 1 batch (no point in threading if only 1) + # Only multi-thread if: + # 1. num_threads is greater than 1. + # 2. we are using an API-based embedding model (provider_type is not + # None). + # 3. there is more than 1 batch (no point in threading if only 1). if num_threads >= 1 and self.provider_type and len(text_batches) > 1: with ThreadPoolExecutor(max_workers=num_threads) as executor: - future_to_batch = { + # NOTE: Be careful with closures, we explicitly pass in idx and + # batch here because if we were to pass them in via enclosing + # scope, they would be passed in as references not values and + # would be evaluated at lambda execution time, in which case + # every lambda would point to the same values for idx and batch. + futures = [ executor.submit( - partial( - process_batch, - idx, - len(text_batches), - batch, + lambda idx, batch: process_batch( + batch_idx=idx, + num_of_batches=num_of_batches, + text_batch=batch, tenant_id=tenant_id, request_id=request_id, - ) - ): idx - for idx, batch in enumerate(text_batches, start=1) - } + ), + idx, + batch, + ) + for idx, batch in enumerate(text_batches) + ] - # Collect results in order + # Collect results in order. batch_results: list[tuple[int, list[Embedding]]] = [] - for future in as_completed(future_to_batch): + for future in as_completed(futures): try: result = future.result() batch_results.append(result) except Exception as e: - logger.exception("Embedding model failed to process batch") + logger.exception("Embedding model failed to process batch.") raise e - # Sort by batch index and extend embeddings + # Sort by batch index and extend embeddings. batch_results.sort(key=lambda x: x[0]) for _, batch_embeddings in batch_results: embeddings.extend(batch_embeddings) else: - # Original sequential processing - for idx, text_batch in enumerate(text_batches, start=1): + # Original sequential processing. 
+ for idx, text_batch in enumerate(text_batches): _, batch_embeddings = process_batch( - idx, - len(text_batches), - text_batch, + batch_idx=idx, + num_of_batches=num_of_batches, + text_batch=text_batch, tenant_id=tenant_id, request_id=request_id, ) diff --git a/backend/tests/unit/onyx/natural_language_processing/test_search_nlp_models.py b/backend/tests/unit/onyx/natural_language_processing/test_search_nlp_models.py index d9f480a8071..c4c9f242e3d 100644 --- a/backend/tests/unit/onyx/natural_language_processing/test_search_nlp_models.py +++ b/backend/tests/unit/onyx/natural_language_processing/test_search_nlp_models.py @@ -1,5 +1,5 @@ from collections.abc import AsyncGenerator -from typing import List +from threading import Lock from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch @@ -10,8 +10,11 @@ from onyx.llm.constants import LlmProviderNames from onyx.natural_language_processing.search_nlp_models import CloudEmbedding +from onyx.natural_language_processing.search_nlp_models import EmbeddingModel from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType +from shared_configs.model_server_models import EmbedRequest +from shared_configs.model_server_models import EmbedResponse @pytest.fixture @@ -25,7 +28,7 @@ async def mock_http_client() -> AsyncGenerator[AsyncMock, None]: @pytest.fixture -def sample_embeddings() -> List[List[float]]: +def sample_embeddings() -> list[list[float]]: return [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] @@ -47,7 +50,7 @@ async def test_cloud_embedding_explicit_close() -> None: @pytest.mark.asyncio async def test_openai_embedding( mock_http_client: AsyncMock, # noqa: ARG001 - sample_embeddings: List[List[float]], + sample_embeddings: list[list[float]], ) -> None: with patch("openai.AsyncOpenAI") as mock_openai: mock_client = AsyncMock() @@ -85,3 +88,309 @@ async def test_rate_limit_handling() -> None: model_name="fake-model", text_type=EmbedTextType.QUERY, ) + + +# ------------------------------------------------------------------------------ +# _batch_encode_texts tests +# +# Tests correct ordering of the embedding results, and that sync and async +# caller contexts both work. 
+# ------------------------------------------------------------------------------ + +_SEARCH_NLP_MODULE = "onyx.natural_language_processing.search_nlp_models" + + +def _text_for_idx(i: int) -> str: + return f"text_{i}" + + +def _embedding_for_idx(i: int) -> list[float]: + return [float(i)] + + +def _embedding_for_text(text: str) -> list[float]: + return _embedding_for_idx(int(text.split("_")[1])) + + +def _fake_direct_api_call(embed_request: EmbedRequest) -> EmbedResponse: + return EmbedResponse( + embeddings=[_embedding_for_text(t) for t in embed_request.texts] + ) + + +def _fake_model_server_call( + embed_request: EmbedRequest, + tenant_id: str | None = None, # noqa: ARG001 + request_id: str | None = None, # noqa: ARG001 +) -> EmbedResponse: + return EmbedResponse( + embeddings=[_embedding_for_text(t) for t in embed_request.texts] + ) + + +def _make_cloud_embedding_model() -> EmbeddingModel: + with patch(f"{_SEARCH_NLP_MODULE}.get_tokenizer", return_value=MagicMock()): + return EmbeddingModel( + server_host="localhost", + server_port=9000, + model_name="text-embedding-3-small", + normalize=True, + query_prefix=None, + passage_prefix=None, + api_key="fake-key", + api_url=None, + provider_type=EmbeddingProvider.OPENAI, + ) + + +def _make_local_embedding_model() -> EmbeddingModel: + with patch(f"{_SEARCH_NLP_MODULE}.get_tokenizer", return_value=MagicMock()): + return EmbeddingModel( + server_host="localhost", + server_port=9000, + model_name="nomic-ai/nomic-embed-text-v1", + normalize=True, + query_prefix=None, + passage_prefix=None, + api_key=None, + api_url=None, + provider_type=None, + ) + + +def test_batch_encode_multi_batch_partial_last() -> None: + """ + Tests that the multi-threaded path with non-uniform batches preserves + expected ordering and cardinality of embeddings given an input. + """ + # Precondition. + model = _make_cloud_embedding_model() + n_texts = 13 # 3 batches of 4 + 1 partial batch of 1. + texts = [_text_for_idx(i) for i in range(n_texts)] + + # Under test. + with patch.object( + EmbeddingModel, + "_make_direct_api_call", + new=AsyncMock(side_effect=_fake_direct_api_call), + ): + result = model.encode( + texts=texts, + text_type=EmbedTextType.PASSAGE, # Arbitrary. + api_embedding_batch_size=4, + ) + + # Postcondition. + assert result == [_embedding_for_idx(i) for i in range(n_texts)] + + +def test_batch_encode_multi_batch_uniform() -> None: + """ + Tests that the multi-threaded path with uniform batches preserves expected + ordering and cardinality of embeddings given an input. + """ + # Precondition. + model = _make_cloud_embedding_model() + n_texts = 16 # 4 batches of 4. + texts = [_text_for_idx(i) for i in range(n_texts)] + + # Under test. + with patch.object( + EmbeddingModel, + "_make_direct_api_call", + new=AsyncMock(side_effect=_fake_direct_api_call), + ): + result = model.encode( + texts=texts, + text_type=EmbedTextType.PASSAGE, # Arbitrary. + api_embedding_batch_size=4, + ) + + # Postcondition. + assert result == [_embedding_for_idx(i) for i in range(n_texts)] + + +def test_batch_encode_single_batch_sequential() -> None: + """ + Tests that the sequential path with a single batch preserves expected + ordering and cardinality of embeddings given an input. + """ + # Precondition. + model = _make_cloud_embedding_model() + n_texts = 3 # Less than the batch size. + texts = [_text_for_idx(i) for i in range(n_texts)] + + # Under test. 
+ with patch.object( + EmbeddingModel, + "_make_direct_api_call", + new=AsyncMock(side_effect=_fake_direct_api_call), + ): + result = model.encode( + texts=texts, + text_type=EmbedTextType.PASSAGE, # Arbitrary. + api_embedding_batch_size=4, + ) + + # Postcondition. + assert result == [_embedding_for_idx(i) for i in range(n_texts)] + + +def test_batch_encode_local_model_sequential() -> None: + """ + Tests that the sequential path with a local model preserves expected + ordering and cardinality of embeddings given an input. + """ + # Precondition. + model = _make_local_embedding_model() + n_texts = 10 # 2 batches of 4 + 1 partial batch of 2. + texts = [_text_for_idx(i) for i in range(n_texts)] + + # Under test. + with patch.object( + EmbeddingModel, + "_make_model_server_request", + side_effect=_fake_model_server_call, + ): + result = model.encode( + texts=texts, + text_type=EmbedTextType.PASSAGE, # Arbitrary. + local_embedding_batch_size=4, + ) + + # Postcondition. + assert result == [_embedding_for_idx(i) for i in range(n_texts)] + + +def test_batch_encode_error_propagates() -> None: + """ + Tests that a failing batch propagates its exception out of encode(). + """ + # Precondition. + model = _make_cloud_embedding_model() + texts = [_text_for_idx(i) for i in range(8)] + + call_count = {"n": 0} + call_count_lock = Lock() + + def _fail_on_second_call(embed_request: EmbedRequest) -> EmbedResponse: + with call_count_lock: + call_count["n"] += 1 + if call_count["n"] == 2: + raise RuntimeError("simulated provider failure") + return _fake_direct_api_call(embed_request) + + # Under test and postcondition. + with patch.object( + EmbeddingModel, + "_make_direct_api_call", + new=AsyncMock(side_effect=_fail_on_second_call), + ): + with pytest.raises(RuntimeError, match="simulated provider failure"): + model.encode( + texts=texts, + text_type=EmbedTextType.PASSAGE, # Arbitrary. + api_embedding_batch_size=2, + ) + + +def test_batch_encode_sync_caller_uses_thread_local_loop() -> None: + """ + Tests that a sync call uses the thread-local event loop and does not call + asyncio.run. + """ + # Precondition. + model = _make_cloud_embedding_model() + texts = [_text_for_idx(i) for i in range(4)] + + # Under test. + with ( + patch.object( + EmbeddingModel, + "_make_direct_api_call", + new=AsyncMock(side_effect=_fake_direct_api_call), + ), + patch(f"{_SEARCH_NLP_MODULE}.asyncio.run") as mock_asyncio_run, + ): + result = model.encode( + texts=texts, + text_type=EmbedTextType.PASSAGE, # Arbitrary. + api_embedding_batch_size=4, + ) + + # Postcondition. + assert result == [_embedding_for_idx(i) for i in range(4)] + assert mock_asyncio_run.call_count == 0 + + +@pytest.mark.asyncio +async def test_batch_encode_async_caller_single_batch_no_deadlock() -> None: + """ + Tests that an async call using the sequential path calls asyncio.run exactly + once, and that this call succeeds. In this path the caller is in an event + loop, so calling asyncio.run would raise as a thread running an event loop + cannot wait on itself. Calling asyncio.run in a thread with no event loop is + safe. + """ + # Precondition. + model = _make_cloud_embedding_model() + n_texts = 4 # 1 batch of 4. + texts = [_text_for_idx(i) for i in range(n_texts)] + + # Under test. 
+ with ( + patch.object( + EmbeddingModel, + "_make_direct_api_call", + new=AsyncMock(side_effect=_fake_direct_api_call), + ), + patch( + f"{_SEARCH_NLP_MODULE}.asyncio.run", + wraps=__import__("asyncio").run, + ) as spy_asyncio_run, + ): + result = model.encode( + texts=texts, + text_type=EmbedTextType.PASSAGE, # Arbitrary. + api_embedding_batch_size=4, + ) + + # Postcondition. + assert result == [_embedding_for_idx(i) for i in range(n_texts)] + assert spy_asyncio_run.call_count == 1 + + +@pytest.mark.asyncio +async def test_batch_encode_async_caller_multi_batch() -> None: + """ + Tests that an async call using the multi-threaded path does not call + asyncio.run, and that the encode call succeeds. In this path the caller is + in an event loop, but the batches are processed in separate threads which do + not have running event loops, so we do not expect to call asyncio.run. + """ + # Precondition. + model = _make_cloud_embedding_model() + n_texts = 13 # 3 batches of 4 + 1 partial batch of 1. + texts = [_text_for_idx(i) for i in range(n_texts)] + + # Under test. + with ( + patch.object( + EmbeddingModel, + "_make_direct_api_call", + new=AsyncMock(side_effect=_fake_direct_api_call), + ), + patch( + f"{_SEARCH_NLP_MODULE}.asyncio.run", + wraps=__import__("asyncio").run, + ) as spy_asyncio_run, + ): + result = model.encode( + texts=texts, + text_type=EmbedTextType.PASSAGE, # Arbitrary. + api_embedding_batch_size=4, + ) + + # Postcondition. + assert result == [_embedding_for_idx(i) for i in range(n_texts)] + assert spy_asyncio_run.call_count == 0
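
A minimal, self-contained sketch of the loop-routing pattern that the new process_batch code follows: reuse a per-thread event loop via run_until_complete when the calling thread has no running loop, and hand the coroutine to asyncio.run on a dedicated worker thread when a loop is already running. _get_or_create_event_loop mirrors the helper of the same name in search_nlp_models.py; run_coro_from_any_context and _demo are illustrative names only and are not part of this change.

import asyncio
import threading
from collections.abc import Coroutine
from concurrent.futures import ThreadPoolExecutor
from typing import Any, TypeVar

T = TypeVar("T")

# Each thread keeps one reusable loop instead of creating a new loop per call.
_thread_local = threading.local()


def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
    loop = getattr(_thread_local, "loop", None)
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        _thread_local.loop = loop
    return loop


def run_coro_from_any_context(coro: Coroutine[Any, Any, T]) -> T:
    """Run coro to completion whether or not the caller is inside an event loop."""
    try:
        # Raises RuntimeError when no loop is running in this thread.
        asyncio.get_running_loop()
    except RuntimeError:
        # Synchronous caller: safe to block on the thread-local loop.
        return _get_or_create_event_loop().run_until_complete(coro)
    # Async caller: can't block on the running loop from the same thread
    # without deadlocking, so run the coroutine on a worker thread that gets
    # its own loop via asyncio.run.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


async def _demo() -> str:
    await asyncio.sleep(0)
    return "ok"


if __name__ == "__main__":
    # Sync caller path: no running loop, uses the thread-local loop.
    print(run_coro_from_any_context(_demo()))

    async def _async_caller() -> None:
        # Async caller path: a loop is running, so the work moves to a worker thread.
        print(run_coro_from_any_context(_demo()))

    asyncio.run(_async_caller())

Note that the worker-thread fallback still blocks the calling thread until the coroutine finishes; it avoids re-entering the running loop rather than making the call non-blocking, which matches the behavior of the batch-processing code above.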
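
The NOTE about closures in _batch_encode_texts refers to Python's late binding of free variables in lambdas. A short illustration of the pitfall and of binding the values at submission time, which is what passing idx and batch as positional arguments to executor.submit achieves (the snippet below is illustrative only and not part of this change):

# Late binding: each lambda looks up i when it is called, after the loop ended.
late = [lambda: i for i in range(3)]
print([f() for f in late])             # [2, 2, 2]

# Binding i as a call argument captures the current value at creation time,
# analogous to executor.submit(lambda idx, batch: ..., idx, batch).
early = [(lambda i: i, i) for i in range(3)]
print([fn(arg) for fn, arg in early])  # [0, 1, 2]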