@@ -758,7 +758,7 @@ def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
         except BulkIndexError as e:
             # There are several reasons why this might be raised, but the
             # most likely one is if the deletion has not had enough time to
-            # propogate throughout the index, in which case this would be
+            # propagate throughout the index, in which case this would be
             # raised with some form of "version_conflict_engine_exception
             # version conflict, document already exists" messaging.
             # Refresh the index and try one more time. We do not refresh
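The retry-after-refresh pattern described in the comment above can be sketched roughly as follows. This is a minimal illustration assuming an opensearch-py style client; the helper name and index handling are hypothetical, not the project's actual code:

from opensearchpy import OpenSearch
from opensearchpy.helpers import BulkIndexError, bulk

def bulk_index_with_retry(client: OpenSearch, actions: list[dict], index_name: str) -> None:
    # A recent deletion may not have propagated through the index yet, which
    # surfaces as a BulkIndexError wrapping "version_conflict_engine_exception".
    try:
        bulk(client, actions)
    except BulkIndexError:
        # Refresh to make the deletion visible, then retry exactly once.
        client.indices.refresh(index=index_name)
        bulk(client, actions)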
108 changes: 64 additions & 44 deletions backend/onyx/natural_language_processing/search_nlp_models.py
@@ -6,7 +6,6 @@
 from collections.abc import Callable
 from concurrent.futures import as_completed
 from concurrent.futures import ThreadPoolExecutor
-from functools import partial
 from functools import wraps
 from types import TracebackType
 from typing import Any
@@ -97,7 +96,7 @@ def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
     which was causing memory leaks. Instead, each thread reuses the same loop.

     Returns:
-        asyncio.AbstractEventLoop: The thread-local event loop
+        asyncio.AbstractEventLoop: The thread-local event loop.
     """
     if (
         not hasattr(_thread_local, "loop")
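The function body is truncated in this view; a minimal sketch of the thread-local pattern the docstring describes, reconstructed from the visible lines and intended as illustration only:

import asyncio
import threading

_thread_local = threading.local()

def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
    # Create one event loop per thread and reuse it on later calls, instead
    # of allocating a fresh loop per request, which leaks memory over time.
    if not hasattr(_thread_local, "loop") or _thread_local.loop.is_closed():
        _thread_local.loop = asyncio.new_event_loop()
    return _thread_local.loop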
@@ -706,8 +705,6 @@ def __init__(
     async def _make_direct_api_call(
         self,
         embed_request: EmbedRequest,
-        tenant_id: str | None = None,  # noqa: ARG002
-        request_id: str | None = None,  # noqa: ARG002
     ) -> EmbedResponse:
         """Make direct API call to cloud provider, bypassing model server."""
         if self.provider_type is None:
@@ -846,15 +843,16 @@ def _batch_encode_texts(
         request_id: str | None = None,
     ) -> list[Embedding]:
         text_batches = batch_list(texts, batch_size)
+        num_of_batches = len(text_batches)

-        logger.debug(f"Encoding {len(texts)} texts in {len(text_batches)} batches")
+        logger.debug(f"Encoding {len(texts)} texts in {num_of_batches} batches.")

         embeddings: list[Embedding] = []

         @_cleanup_thread_local
         def process_batch(
             batch_idx: int,
-            batch_len: int,
+            num_of_batches: int,
             text_batch: list[str],
             tenant_id: str | None = None,
             request_id: str | None = None,
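batch_list itself is not shown in this diff; a plausible sketch of such a chunking helper (hypothetical, for illustration only):

def batch_list(items: list[str], batch_size: int) -> list[list[str]]:
    # Split items into consecutive chunks of at most batch_size elements.
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]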
@@ -882,74 +880,96 @@ def process_batch(
             )

             start_time = time.monotonic()

-            # Route between direct API calls and model server calls
+            response: EmbedResponse
+            # Route between direct API calls and model server calls.
             if self.provider_type is not None:
-                # For API providers, make direct API call
-                # Use thread-local event loop to prevent memory leaks from creating
-                # thousands of event loops during batch processing
-                loop = _get_or_create_event_loop()
-                response = loop.run_until_complete(
-                    self._make_direct_api_call(
-                        embed_request, tenant_id=tenant_id, request_id=request_id
-                    )
-                )
+                # For API providers, make direct API call.
+                try:
+                    # Detect if this code is being called from an event loop
+                    # or not.
+                    asyncio.get_running_loop()
+                except RuntimeError:
+                    # This code is being called synchronously, safe to use
+                    # run_until_complete.
+                    # Use thread-local event loop to prevent memory leaks
+                    # from creating thousands of event loops during batch
+                    # processing.
+                    loop = _get_or_create_event_loop()
+                    response = loop.run_until_complete(
+                        self._make_direct_api_call(embed_request)
+                    )
+                else:
+                    # This code is being called from an event loop, can't
+                    # block on it from the same thread without deadlocking.
+                    # Run in a separate thread with its own loop.
+                    with ThreadPoolExecutor(max_workers=1) as pool:
+                        response = cast(
+                            EmbedResponse,
+                            pool.submit(
+                                asyncio.run, self._make_direct_api_call(embed_request)
+                            ).result(),
+                        )
             else:
-                # For local models, use model server
+                # For local models, use model server.
                 response = self._make_model_server_request(
                     embed_request, tenant_id=tenant_id, request_id=request_id
                 )

-            end_time = time.monotonic()
-
-            processing_time = end_time - start_time
+            processing_time = time.monotonic() - start_time
             logger.debug(
-                f"EmbeddingModel.process_batch: Batch {batch_idx}/{batch_len} processing time: {processing_time:.2f} seconds"
+                f"process_batch: Batch idx {batch_idx}, total num {num_of_batches}, processing time: {processing_time:.2f}s."
             )

             return batch_idx, response.embeddings

-        # only multi thread if:
-        # 1. num_threads is greater than 1
-        # 2. we are using an API-based embedding model (provider_type is not None)
-        # 3. there are more than 1 batch (no point in threading if only 1)
+        # Only multi-thread if:
+        # 1. num_threads is greater than 1.
+        # 2. we are using an API-based embedding model (provider_type is not
+        #    None).
+        # 3. there is more than 1 batch (no point in threading if only 1).
         if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
             with ThreadPoolExecutor(max_workers=num_threads) as executor:
-                future_to_batch = {
+                # NOTE: Be careful with closures, we explicitly pass in idx and
+                # batch here because if we were to pass them in via enclosing
+                # scope, they would be passed in as references not values and
+                # would be evaluated at lambda execution time, in which case
+                # every lambda would point to the same values for idx and batch.
+                futures = [
                     executor.submit(
-                        partial(
-                            process_batch,
-                            idx,
-                            len(text_batches),
-                            batch,
+                        lambda idx, batch: process_batch(
+                            batch_idx=idx,
+                            num_of_batches=num_of_batches,
+                            text_batch=batch,
                             tenant_id=tenant_id,
                             request_id=request_id,
-                        )
-                    ): idx
-                    for idx, batch in enumerate(text_batches, start=1)
-                }
+                        ),
+                        idx,
+                        batch,
+                    )
+                    for idx, batch in enumerate(text_batches)
+                ]

-                # Collect results in order
+                # Collect results in order.
                 batch_results: list[tuple[int, list[Embedding]]] = []
-                for future in as_completed(future_to_batch):
+                for future in as_completed(futures):
                     try:
                         result = future.result()
                         batch_results.append(result)
                     except Exception as e:
-                        logger.exception("Embedding model failed to process batch")
+                        logger.exception("Embedding model failed to process batch.")
                         raise e

-                # Sort by batch index and extend embeddings
+                # Sort by batch index and extend embeddings.
                 batch_results.sort(key=lambda x: x[0])
                 for _, batch_embeddings in batch_results:
                     embeddings.extend(batch_embeddings)
         else:
-            # Original sequential processing
-            for idx, text_batch in enumerate(text_batches, start=1):
+            # Original sequential processing.
+            for idx, text_batch in enumerate(text_batches):
                 _, batch_embeddings = process_batch(
-                    idx,
-                    len(text_batches),
-                    text_batch,
+                    batch_idx=idx,
+                    num_of_batches=num_of_batches,
+                    text_batch=text_batch,
                     tenant_id=tenant_id,
                     request_id=request_id,
                 )
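The heart of this change is the run-a-coroutine-from-anywhere routing: block on the coroutine directly when no event loop is running in the current thread, and hand it to a dedicated worker thread when one is, since blocking on an already-running loop would deadlock. A self-contained sketch of that pattern (fetch_embedding is a stand-in coroutine; the diff additionally reuses a thread-local loop in the synchronous branch rather than asyncio.run):

import asyncio
from collections.abc import Coroutine
from concurrent.futures import ThreadPoolExecutor
from typing import Any

async def fetch_embedding() -> list[float]:
    # Stand-in for the real provider call.
    await asyncio.sleep(0)
    return [0.0, 1.0]

def run_coroutine_safely(coro: Coroutine[Any, Any, list[float]]) -> list[float]:
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so blocking here is safe.
        return asyncio.run(coro)
    # A loop is already running; block a separate worker thread that owns
    # its own loop instead of deadlocking the current one.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()

print(run_coroutine_safely(fetch_embedding()))  # [0.0, 1.0]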
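The NOTE about closures in the diff refers to Python's late binding: a lambda looks up free variables when it is called, not when it is defined. A quick demonstration of the bug the diff avoids by passing idx and batch to executor.submit as arguments:

# Late binding: each lambda reads i at call time, so all three see the
# final loop value.
broken = [lambda: i for i in range(3)]
print([f() for f in broken])  # [2, 2, 2]

# Binding the current value as a default argument (or passing it in
# explicitly, as the diff does) captures it at definition time.
fixed = [lambda i=i: i for i in range(3)]
print([f() for f in fixed])  # [0, 1, 2]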
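Finally, the collect-then-sort idiom used above: as_completed yields futures in completion order, so each worker returns its batch index alongside its result and the list is re-sorted afterwards. A compact, runnable sketch of the same idea:

from concurrent.futures import ThreadPoolExecutor, as_completed

def work(idx: int, text: str) -> tuple[int, str]:
    # Return the index with the result so submission order can be restored.
    return idx, text.upper()

texts = ["alpha", "beta", "gamma"]
with ThreadPoolExecutor(max_workers=3) as executor:
    futures = [executor.submit(work, i, t) for i, t in enumerate(texts)]
    results = [f.result() for f in as_completed(futures)]  # completion order

results.sort(key=lambda pair: pair[0])  # restore submission order
print([value for _, value in results])  # ['ALPHA', 'BETA', 'GAMMA']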