diff --git a/src/apify/storage_clients/_apify/_request_queue_shared_client.py b/src/apify/storage_clients/_apify/_request_queue_shared_client.py
index 496ec4f1..057b2918 100644
--- a/src/apify/storage_clients/_apify/_request_queue_shared_client.py
+++ b/src/apify/storage_clients/_apify/_request_queue_shared_client.py
@@ -8,7 +8,12 @@
 
 from cachetools import LRUCache
 
-from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, RequestQueueMetadata
+from crawlee.storage_clients.models import (
+    AddRequestsResponse,
+    ProcessedRequest,
+    RequestQueueMetadata,
+    UnprocessedRequest,
+)
 
 from ._models import ApifyRequestQueueMetadata, CachedRequest, RequestQueueHead
 from ._utils import to_crawlee_request, unique_key_to_request_id
@@ -71,6 +76,14 @@ def __init__(
         self._requests_cache: LRUCache[str, CachedRequest] = LRUCache(maxsize=cache_size)
         """LRU cache storing request objects, keyed by request ID."""
 
+        self._requests_being_added: dict[str, asyncio.Future[bool]] = {}
+        """In-flight `add_batch_of_requests` markers, keyed by request ID.
+
+        Each future resolves once the platform call that is adding the request settles: `True` if the request was
+        committed, `False` otherwise. Concurrent producers of the same request await it instead of re-sending,
+        which preserves deduplication while still avoiding false success when the original add fails.
+        """
+
         self._queue_has_locked_requests: bool | None = None
         """Whether the queue contains requests currently locked by other clients."""
 
@@ -87,9 +100,13 @@ async def add_batch_of_requests(
         forefront: bool = False,
     ) -> AddRequestsResponse:
         """Specific implementation of this method for the RQ shared access mode."""
+        loop = asyncio.get_running_loop()
         # Do not try to add previously added requests to avoid pointless expensive calls to API
         new_requests: list[Request] = []
         already_present_requests: list[ProcessedRequest] = []
+        # Requests a concurrent `add_batch_of_requests` call is already sending. We await its outcome instead of
+        # re-sending them, as (request, that call's in-flight future) pairs.
+        awaited_in_flight: list[tuple[Request, asyncio.Future[bool]]] = []
 
         for request in requests:
             request_id = unique_key_to_request_id(request.unique_key)
@@ -106,46 +123,68 @@ async def add_batch_of_requests(
                     )
                 )
 
+            elif request_id in self._requests_being_added:
+                # A concurrent call is already adding this request; await its outcome rather than re-sending it.
+                awaited_in_flight.append((request, self._requests_being_added[request_id]))
+
             else:
-                # Add new request to the cache.
-                processed_request = ProcessedRequest(
-                    id=request_id,
-                    unique_key=request.unique_key,
-                    was_already_present=True,
-                    was_already_handled=request.was_already_handled,
-                )
-                self._cache_request(
-                    request_id,
-                    processed_request,
-                )
+                # Register an in-flight marker so concurrent producers dedupe against it; caching is deferred
+                # until the platform confirms the request was accepted (see below).
                 new_requests.append(request)
+                self._requests_being_added[request_id] = loop.create_future()
 
         if new_requests:
             # Prepare requests for API by converting to dictionaries.
             requests_dict = [request.model_dump(by_alias=True) for request in new_requests]
 
-            # Send requests to API.
-            batch_response = await self._api_client.batch_add_requests(
-                requests=requests_dict,
-                forefront=forefront,
-            )
-
-            batch_response_dict = batch_response.model_dump(by_alias=True)
-            api_response = AddRequestsResponse.model_validate(batch_response_dict)
-
-            # Add the locally known already present processed requests based on the local cache.
-            api_response.processed_requests.extend(already_present_requests)
+            committed_request_ids: set[str] = set()
+            try:
+                # Send requests to API.
+                batch_response = await self._api_client.batch_add_requests(
+                    requests=requests_dict,
+                    forefront=forefront,
+                )
 
-            # Remove unprocessed requests from the cache
-            for unprocessed_request in api_response.unprocessed_requests:
-                unprocessed_request_id = unique_key_to_request_id(unprocessed_request.unique_key)
-                self._requests_cache.pop(unprocessed_request_id, None)
+                batch_response_dict = batch_response.model_dump(by_alias=True)
+                api_response = AddRequestsResponse.model_validate(batch_response_dict)
+
+                # Commit only the requests the platform actually accepted to the local dedup cache. Caching after
+                # the call succeeds (not before) keeps a failed call from poisoning the cache and silently
+                # deduplicating a later retry of the same request.
+                unprocessed_unique_keys = {request.unique_key for request in api_response.unprocessed_requests}
+                for request in new_requests:
+                    if request.unique_key in unprocessed_unique_keys:
+                        continue
+                    request_id = unique_key_to_request_id(request.unique_key)
+                    self._cache_request(
+                        request_id,
+                        ProcessedRequest(
+                            id=request_id,
+                            unique_key=request.unique_key,
+                            was_already_present=True,
+                            was_already_handled=request.was_already_handled,
+                        ),
+                    )
+                    committed_request_ids.add(request_id)
+
+                # Add the locally known already present processed requests based on the local cache.
+                api_response.processed_requests.extend(already_present_requests)
+            finally:
+                # Release the in-flight markers we registered. Committed requests tell concurrent producers the
+                # request reached the platform; everything else (unprocessed, API error, cancellation) tells them
+                # it did not, so they retry instead of reporting false success.
+                for request in new_requests:
+                    request_id = unique_key_to_request_id(request.unique_key)
+                    self._settle_pending_addition(request_id, committed=request_id in committed_request_ids)
 
         else:
             api_response = AddRequestsResponse.model_validate(
                 {'unprocessedRequests': [], 'processedRequests': already_present_requests}
             )
 
+        # Fold in requests a concurrent call was already adding.
+        await self._resolve_awaited_in_flight(awaited_in_flight, api_response)
+
         logger.debug(
             f'Tried to add new requests: {len(new_requests)}, '
             f'succeeded to add new requests: {len(api_response.processed_requests) - len(already_present_requests)}, '
@@ -163,6 +202,44 @@ async def add_batch_of_requests(
 
         return api_response
 
+    def _settle_pending_addition(self, request_id: str, *, committed: bool) -> None:
+        """Resolve the in-flight add marker for a request, unblocking any concurrent producers awaiting it.
+
+        Args:
+            request_id: ID of the request whose in-flight add has settled.
+            committed: Whether the request was committed to the platform.
+        """
+        future = self._requests_being_added.pop(request_id, None)
+        if future is not None and not future.done():
+            future.set_result(committed)
+
+    @staticmethod
+    async def _resolve_awaited_in_flight(
+        awaited_in_flight: list[tuple[Request, asyncio.Future[bool]]],
+        api_response: AddRequestsResponse,
+    ) -> None:
+        """Await concurrent in-flight adds of these requests and fold the outcome into `api_response`.
+
+        Requests the concurrent add committed are reported as already present; the rest are reported unprocessed
+        so the caller retries them rather than receiving false success.
+        """
+        for request, future in awaited_in_flight:
+            # Shield the shared future: a bare `await future` would let cancellation of this waiter cancel the
+            # future itself, raising a spurious `CancelledError` in every other producer awaiting the same add.
+            if await asyncio.shield(future):
+                api_response.processed_requests.append(
+                    ProcessedRequest(
+                        id=unique_key_to_request_id(request.unique_key),
+                        unique_key=request.unique_key,
+                        was_already_present=True,
+                        was_already_handled=request.was_already_handled,
+                    )
+                )
+            else:
+                api_response.unprocessed_requests.append(
+                    UnprocessedRequest(unique_key=request.unique_key, url=request.url, method=request.method)
+                )
+
     async def get_request(self, unique_key: str) -> Request | None:
         """Specific implementation of this method for the RQ shared access mode."""
         return await self._get_request_by_id(unique_key_to_request_id(unique_key))
diff --git a/src/apify/storage_clients/_apify/_request_queue_single_client.py b/src/apify/storage_clients/_apify/_request_queue_single_client.py
index a42f7252..28b72141 100644
--- a/src/apify/storage_clients/_apify/_request_queue_single_client.py
+++ b/src/apify/storage_clients/_apify/_request_queue_single_client.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 from collections import deque
 from datetime import UTC, datetime
 from logging import getLogger
@@ -7,7 +8,12 @@
 
 from cachetools import LRUCache
 
-from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, RequestQueueMetadata
+from crawlee.storage_clients.models import (
+    AddRequestsResponse,
+    ProcessedRequest,
+    RequestQueueMetadata,
+    UnprocessedRequest,
+)
 
 from ._utils import to_crawlee_request, unique_key_to_request_id
 
@@ -90,6 +96,14 @@ def __init__(
         Tracked locally to accurately determine when the queue is empty for this single consumer.
         """
 
+        self._requests_being_added: dict[str, asyncio.Future[bool]] = {}
+        """In-flight `add_batch_of_requests` markers, keyed by request ID.
+
+        Each future resolves once the platform call that is adding the request settles: `True` if the request was
+        committed, `False` otherwise. Concurrent producers of the same request await it instead of re-sending,
+        which preserves deduplication while still avoiding false success when the original add fails.
+        """
+
         self._initialized_caches = False
         """Flag indicating whether local caches have been populated from existing queue contents.
 
@@ -108,8 +122,12 @@ async def add_batch_of_requests(
             await self._init_caches()
             self._initialized_caches = True
 
+        loop = asyncio.get_running_loop()
         new_requests: list[Request] = []
         already_present_requests: list[ProcessedRequest] = []
+        # Requests a concurrent `add_batch_of_requests` call is already sending. We await its outcome instead of
+        # re-sending them, as (request, that call's in-flight future) pairs.
+        awaited_in_flight: list[tuple[Request, asyncio.Future[bool]]] = []
 
         for request in requests:
             # Calculate id for request
@@ -135,33 +153,52 @@ async def add_batch_of_requests(
                         was_already_handled=request.was_already_handled,
                     )
                 )
+            # Check if a concurrent call is already adding this request, and await its outcome rather than
+            # re-sending it.
+            elif request_id in self._requests_being_added:
+                awaited_in_flight.append((request, self._requests_being_added[request_id]))
             else:
-                # Push the request to the platform. Probably not there, or we are not aware of it
+                # Push the request to the platform. Probably not there, or we are not aware of it. Register an
+                # in-flight marker so concurrent producers dedupe against it; caching is deferred until the
+                # platform confirms the request was accepted (see below).
                 new_requests.append(request)
-
-                # Update local caches
-                self._requests_cache[request_id] = request
-                if forefront:
-                    self._head_requests.append(request_id)
-                else:
-                    self._head_requests.appendleft(request_id)
+                self._requests_being_added[request_id] = loop.create_future()
 
         if new_requests:
             # Prepare requests for API by converting to dictionaries.
             requests_dict = [request.model_dump(by_alias=True) for request in new_requests]
 
-            # Send requests to API.
-            batch_response = await self._api_client.batch_add_requests(requests=requests_dict, forefront=forefront)
-            batch_response_dict = batch_response.model_dump(by_alias=True)
-            api_response = AddRequestsResponse.model_validate(batch_response_dict)
-
-            # Add the locally known already present processed requests based on the local cache.
-            api_response.processed_requests.extend(already_present_requests)
-
-            # Remove unprocessed requests from the cache
-            for unprocessed_request in api_response.unprocessed_requests:
-                request_id = unique_key_to_request_id(unprocessed_request.unique_key)
-                self._requests_cache.pop(request_id, None)
+            committed_request_ids: set[str] = set()
+            try:
+                # Send requests to API.
+                batch_response = await self._api_client.batch_add_requests(requests=requests_dict, forefront=forefront)
+                batch_response_dict = batch_response.model_dump(by_alias=True)
+                api_response = AddRequestsResponse.model_validate(batch_response_dict)
+
+                # Commit only the requests the platform actually accepted to the local caches. Caching after the
+                # call succeeds (not before) keeps a failed call from poisoning the cache and silently
+                # deduplicating a later retry of the same request.
+                unprocessed_unique_keys = {request.unique_key for request in api_response.unprocessed_requests}
+                for request in new_requests:
+                    if request.unique_key in unprocessed_unique_keys:
+                        continue
+                    request_id = unique_key_to_request_id(request.unique_key)
+                    self._requests_cache[request_id] = request
+                    if forefront:
+                        self._head_requests.append(request_id)
+                    else:
+                        self._head_requests.appendleft(request_id)
+                    committed_request_ids.add(request_id)
+
+                # Add the locally known already present processed requests based on the local cache.
+                api_response.processed_requests.extend(already_present_requests)
+            finally:
+                # Release the in-flight markers we registered. Committed requests tell concurrent producers the
+                # request reached the platform; everything else (unprocessed, API error, cancellation) tells them
+                # it did not, so they retry instead of reporting false success.
+                for request in new_requests:
+                    request_id = unique_key_to_request_id(request.unique_key)
+                    self._settle_pending_addition(request_id, committed=request_id in committed_request_ids)
 
         else:
             api_response = AddRequestsResponse(
@@ -169,6 +206,9 @@ async def add_batch_of_requests(
                 processed_requests=already_present_requests,
             )
 
+        # Fold in requests a concurrent call was already adding.
+        await self._resolve_awaited_in_flight(awaited_in_flight, api_response)
+
         # Update assumed total count for newly added requests.
         new_request_count = 0
         for processed_request in api_response.processed_requests:
@@ -179,6 +219,44 @@ async def add_batch_of_requests(
 
         return api_response
 
+    def _settle_pending_addition(self, request_id: str, *, committed: bool) -> None:
+        """Resolve the in-flight add marker for a request, unblocking any concurrent producers awaiting it.
+
+        Args:
+            request_id: ID of the request whose in-flight add has settled.
+            committed: Whether the request was committed to the platform.
+        """
+        future = self._requests_being_added.pop(request_id, None)
+        if future is not None and not future.done():
+            future.set_result(committed)
+
+    @staticmethod
+    async def _resolve_awaited_in_flight(
+        awaited_in_flight: list[tuple[Request, asyncio.Future[bool]]],
+        api_response: AddRequestsResponse,
+    ) -> None:
+        """Await concurrent in-flight adds of these requests and fold the outcome into `api_response`.
+
+        Requests the concurrent add committed are reported as already present; the rest are reported unprocessed
+        so the caller retries them rather than receiving false success.
+        """
+        for request, future in awaited_in_flight:
+            # Shield the shared future: a bare `await future` would let cancellation of this waiter cancel the
+            # future itself, raising a spurious `CancelledError` in every other producer awaiting the same add.
+            if await asyncio.shield(future):
+                api_response.processed_requests.append(
+                    ProcessedRequest(
+                        id=unique_key_to_request_id(request.unique_key),
+                        unique_key=request.unique_key,
+                        was_already_present=True,
+                        was_already_handled=request.was_already_handled,
+                    )
+                )
+            else:
+                api_response.unprocessed_requests.append(
+                    UnprocessedRequest(unique_key=request.unique_key, url=request.url, method=request.method)
+                )
+
     async def get_request(self, unique_key: str) -> Request | None:
         """Specific implementation of this method for the RQ single access mode."""
         return await self._get_request_by_id(id=unique_key_to_request_id(unique_key))
diff --git a/tests/unit/storage_clients/test_apify_request_queue_client.py b/tests/unit/storage_clients/test_apify_request_queue_client.py
index cfdc0ed1..a11609e8 100644
--- a/tests/unit/storage_clients/test_apify_request_queue_client.py
+++ b/tests/unit/storage_clients/test_apify_request_queue_client.py
@@ -1,25 +1,29 @@
 from __future__ import annotations
 
+import asyncio
 from datetime import UTC, datetime
+from types import SimpleNamespace
+from typing import TYPE_CHECKING
 from unittest.mock import AsyncMock
 
 import pytest
 
-from apify_client._models import RequestQueueHead, RequestQueueStats
-from crawlee.storage_clients.models import RequestQueueMetadata
+from apify_client._models import AddedRequest, BatchAddResult, RequestQueueHead, RequestQueueStats
+from crawlee.storage_clients.models import AddRequestsResponse, RequestQueueMetadata
 
+from apify import Request
 from apify.storage_clients._apify._models import ApifyRequestQueueMetadata
+from apify.storage_clients._apify._request_queue_shared_client import ApifyRequestQueueSharedClient
 from apify.storage_clients._apify._request_queue_single_client import ApifyRequestQueueSingleClient
 from apify.storage_clients._apify._utils import unique_key_to_request_id
 
+if TYPE_CHECKING:
+    from collections.abc import Sequence
 
-def _make_single_client(
-    api_client: AsyncMock | None = None,
-) -> tuple[ApifyRequestQueueSingleClient, AsyncMock]:
-    if api_client is None:
-        api_client = AsyncMock()
+
+def _make_metadata() -> RequestQueueMetadata:
     now = datetime.now(tz=UTC)
-    metadata = RequestQueueMetadata(
+    return RequestQueueMetadata(
         id='test-rq-id',
         name='test-rq',
         accessed_at=now,
@@ -30,7 +34,45 @@ def _make_single_client(
         pending_request_count=0,
         total_request_count=0,
     )
-    client = ApifyRequestQueueSingleClient(api_client=api_client, metadata=metadata, cache_size=100)
+
+
+def _batch_result_all_processed(requests: Sequence[Request]) -> BatchAddResult:
+    """Build a `batch_add_requests` response marking every request as newly processed."""
+    return BatchAddResult.model_construct(
+        processed_requests=[
+            AddedRequest.model_construct(
+                request_id=unique_key_to_request_id(request.unique_key),
+                unique_key=request.unique_key,
+                was_already_present=False,
+                was_already_handled=False,
+            )
+            for request in requests
+        ],
+        unprocessed_requests=[],
+    )
+
+
+def _make_single_client(
+    api_client: AsyncMock | None = None,
+) -> tuple[ApifyRequestQueueSingleClient, AsyncMock]:
+    if api_client is None:
+        api_client = AsyncMock()
+    client = ApifyRequestQueueSingleClient(api_client=api_client, metadata=_make_metadata(), cache_size=100)
+    return client, api_client
+
+
+def _make_shared_client(
+    api_client: AsyncMock | None = None,
+) -> tuple[ApifyRequestQueueSharedClient, AsyncMock]:
+    if api_client is None:
+        api_client = AsyncMock()
+    metadata = _make_metadata()
+    client = ApifyRequestQueueSharedClient(
+        api_client=api_client,
+        metadata=metadata,
+        cache_size=100,
+        metadata_getter=AsyncMock(return_value=metadata),
+    )
     return client, api_client
 
 
@@ -119,3 +161,166 @@ async def test_list_head_limit(in_progress_count: int, expected_limit: int) -> N
     await client._list_head()
 
     api_client.list_head.assert_awaited_once_with(limit=expected_limit)
+
+
+# Adding a request through `batch_add_requests` must never poison the local dedup cache or report false
+# success. A failed add leaves nothing cached, so a later add (sequential or concurrent) still reaches the
+# platform; a concurrent producer of the same request deduplicates against the in-flight add instead of
+# re-sending it, yet is only told the request is present once the platform actually accepts it.
+
+
+@pytest.mark.parametrize('access', ['single', 'shared'])
+async def test_failed_batch_add_does_not_poison_dedup_cache(access: str) -> None:
+    """A failed `batch_add_requests` leaves no cached entry, so a retry still reaches the platform."""
+    client, api_client = _make_single_client() if access == 'single' else _make_shared_client()
+    # The single client lazily initializes its caches via `list_requests`; harmless for the shared client.
+    api_client.list_requests = AsyncMock(return_value=SimpleNamespace(items=[]))
+    request = Request.from_url('https://example.com/1')
+    request_id = unique_key_to_request_id(request.unique_key)
+
+    # First attempt: the platform call fails.
+    api_client.batch_add_requests = AsyncMock(side_effect=RuntimeError('network down'))
+    with pytest.raises(RuntimeError):
+        await client.add_batch_of_requests([request])
+    assert request_id not in client._requests_cache
+
+    # Retry: the platform call succeeds. The request must be sent again, not deduped away.
+    api_client.batch_add_requests = AsyncMock(return_value=_batch_result_all_processed([request]))
+    await client.add_batch_of_requests([request])
+
+    api_client.batch_add_requests.assert_awaited_once()
+    assert api_client.batch_add_requests.await_args is not None
+    assert len(api_client.batch_add_requests.await_args.kwargs['requests']) == 1
+
+
+async def _start_first_then_concurrent_producer(
+    client: ApifyRequestQueueSingleClient | ApifyRequestQueueSharedClient,
+    request: Request,
+    *,
+    in_flight: asyncio.Event,
+) -> tuple[asyncio.Task[AddRequestsResponse], asyncio.Task[AddRequestsResponse]]:
+    """Start one producer, wait until its `batch_add_requests` is in flight, then start a concurrent producer
+    of the same request and let it park on the in-flight add. Returns the (first, second) tasks."""
+    if isinstance(client, ApifyRequestQueueSingleClient):
+        # Skip the lazy `list_requests` init so the concurrent producer's only suspension point is the
+        # in-flight future, which makes the scheduling below deterministic.
+        client._initialized_caches = True
+
+    first = asyncio.create_task(client.add_batch_of_requests([request]))
+    await in_flight.wait()
+    second = asyncio.create_task(client.add_batch_of_requests([request]))
+    await asyncio.sleep(0)  # let the concurrent producer classify and park on the in-flight future
+    return first, second
+
+
+@pytest.mark.parametrize('access', ['single', 'shared'])
+async def test_concurrent_add_failure_does_not_falsely_dedupe(access: str) -> None:
+    """While one producer's `batch_add_requests` is in flight and then fails, a concurrent producer of the same
+    request must not report false success: the request is returned unprocessed (so the caller retries it)."""
+    client, api_client = _make_single_client() if access == 'single' else _make_shared_client()
+    request = Request.from_url('https://example.com/1')
+    request_id = unique_key_to_request_id(request.unique_key)
+
+    in_flight = asyncio.Event()
+    release = asyncio.Event()
+    call_count = 0
+
+    async def batch_add(*, requests: list, forefront: bool = False) -> BatchAddResult:  # noqa: ARG001
+        nonlocal call_count
+        call_count += 1
+        in_flight.set()
+        await release.wait()
+        raise RuntimeError('network down')
+
+    api_client.batch_add_requests = AsyncMock(side_effect=batch_add)
+
+    first, second = await _start_first_then_concurrent_producer(client, request, in_flight=in_flight)
+
+    # Nothing is committed while the first call is still in flight, so the concurrent producer cannot observe a
+    # false "already present" entry.
+    assert request_id not in client._requests_cache
+
+    # Let the first producer fail.
+    release.set()
+    with pytest.raises(RuntimeError):
+        await first
+
+    # The concurrent producer deduplicated against the in-flight add (no second API call), but because that add
+    # failed it must be told the request is unprocessed rather than receiving false success.
+    response = await second
+    assert call_count == 1
+    assert [unprocessed.unique_key for unprocessed in response.unprocessed_requests] == [request.unique_key]
+    assert all(processed.unique_key != request.unique_key for processed in response.processed_requests)
+    assert request_id not in client._requests_cache
+
+
+@pytest.mark.parametrize('access', ['single', 'shared'])
+async def test_concurrent_add_deduplicates_against_in_flight(access: str) -> None:
+    """A concurrent producer of an in-flight request deduplicates against it: only one `batch_add_requests` call
+    is made, and once it succeeds the concurrent producer is told the request is already present."""
+    client, api_client = _make_single_client() if access == 'single' else _make_shared_client()
+    request = Request.from_url('https://example.com/1')
+    request_id = unique_key_to_request_id(request.unique_key)
+
+    in_flight = asyncio.Event()
+    release = asyncio.Event()
+    call_count = 0
+
+    async def batch_add(*, requests: list, forefront: bool = False) -> BatchAddResult:  # noqa: ARG001
+        nonlocal call_count
+        call_count += 1
+        in_flight.set()
+        await release.wait()
+        return _batch_result_all_processed([request])
+
+    api_client.batch_add_requests = AsyncMock(side_effect=batch_add)
+
+    first, second = await _start_first_then_concurrent_producer(client, request, in_flight=in_flight)
+
+    # Let the first producer succeed.
+    release.set()
+    first_response = await first
+    second_response = await second
+
+    assert call_count == 1
+    assert request_id in client._requests_cache
+    # The first producer added the request, the concurrent one deduplicated against the in-flight add.
+    assert [processed.unique_key for processed in first_response.processed_requests] == [request.unique_key]
+    assert [processed.unique_key for processed in second_response.processed_requests] == [request.unique_key]
+    assert second_response.processed_requests[0].was_already_present is True
+    assert second_response.unprocessed_requests == []
+
+
+@pytest.mark.parametrize('access', ['single', 'shared'])
+async def test_cancelled_concurrent_producer_does_not_cancel_in_flight_add(access: str) -> None:
+    """Cancelling one producer parked on an in-flight add must not cancel the shared future: other producers
+    parked on the same add still observe its real outcome instead of a spurious `CancelledError`."""
+    client, api_client = _make_single_client() if access == 'single' else _make_shared_client()
+    request = Request.from_url('https://example.com/1')
+
+    in_flight = asyncio.Event()
+    release = asyncio.Event()
+
+    async def batch_add(*, requests: list, forefront: bool = False) -> BatchAddResult:  # noqa: ARG001
+        in_flight.set()
+        await release.wait()
+        return _batch_result_all_processed([request])
+
+    api_client.batch_add_requests = AsyncMock(side_effect=batch_add)
+
+    first, cancelled = await _start_first_then_concurrent_producer(client, request, in_flight=in_flight)
+    survivor = asyncio.create_task(client.add_batch_of_requests([request]))
+    await asyncio.sleep(0)  # let the second concurrent producer park on the in-flight future
+
+    cancelled.cancel()
+    with pytest.raises(asyncio.CancelledError):
+        await cancelled
+
+    # Let the first producer succeed; the surviving producer must see that success.
+    release.set()
+    await first
+    survivor_response = await survivor
+
+    assert [processed.unique_key for processed in survivor_response.processed_requests] == [request.unique_key]
+    assert survivor_response.processed_requests[0].was_already_present is True
+    assert survivor_response.unprocessed_requests == []