Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 29 additions & 27 deletions src/apify/storage_clients/_apify/_alias_resolving.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
from logging import getLogger
from typing import TYPE_CHECKING, ClassVar, Literal, overload

from apify_client import ApifyClientAsync

from ._utils import hash_api_base_url_and_token

if TYPE_CHECKING:
from collections.abc import Callable
from types import TracebackType

from apify_client import ApifyClientAsync
from apify_client._resource_clients import (
DatasetClientAsync,
DatasetCollectionClientAsync,
Expand All @@ -35,6 +34,7 @@ async def open_by_alias(
storage_type: Literal['Dataset'],
collection_client: DatasetCollectionClientAsync,
get_resource_client_by_id: Callable[[str], DatasetClientAsync],
api_client: ApifyClientAsync,
configuration: Configuration,
) -> DatasetClientAsync: ...

Expand All @@ -46,6 +46,7 @@ async def open_by_alias(
storage_type: Literal['KeyValueStore'],
collection_client: KeyValueStoreCollectionClientAsync,
get_resource_client_by_id: Callable[[str], KeyValueStoreClientAsync],
api_client: ApifyClientAsync,
configuration: Configuration,
) -> KeyValueStoreClientAsync: ...

Expand All @@ -57,6 +58,7 @@ async def open_by_alias(
storage_type: Literal['RequestQueue'],
collection_client: RequestQueueCollectionClientAsync,
get_resource_client_by_id: Callable[[str], RequestQueueClientAsync],
api_client: ApifyClientAsync,
configuration: Configuration,
) -> RequestQueueClientAsync: ...

Expand All @@ -69,6 +71,7 @@ async def open_by_alias(
KeyValueStoreCollectionClientAsync | RequestQueueCollectionClientAsync | DatasetCollectionClientAsync
),
get_resource_client_by_id: Callable[[str], KeyValueStoreClientAsync | RequestQueueClientAsync | DatasetClientAsync],
api_client: ApifyClientAsync,
configuration: Configuration,
) -> KeyValueStoreClientAsync | RequestQueueClientAsync | DatasetClientAsync:
"""Open storage by alias, creating it if necessary.
Expand All @@ -81,6 +84,8 @@ async def open_by_alias(
storage_type: The type of storage to open.
collection_client: The Apify API collection client for the storage type.
get_resource_client_by_id: A callable that takes a storage ID and returns the resource client.
api_client: The Apify API client used for the storage operation. Reused to access the default KVS that
holds the alias mapping, so alias resolution does not spin up its own client.
configuration: Configuration object containing API credentials and settings.

Returns:
Expand All @@ -94,6 +99,7 @@ async def open_by_alias(
storage_type=storage_type,
alias=alias,
configuration=configuration,
api_client=api_client,
) as alias_resolver:
storage_id = await alias_resolver.resolve_id()

Expand Down Expand Up @@ -142,10 +148,12 @@ def __init__(
storage_type: Literal['Dataset', 'KeyValueStore', 'RequestQueue'],
alias: str,
configuration: Configuration,
api_client: ApifyClientAsync,
) -> None:
self._storage_type = storage_type
self._alias = alias
self._configuration = configuration
self._api_client = api_client

async def __aenter__(self) -> AliasResolver:
"""Context manager to prevent race condition in alias creation."""
Expand Down Expand Up @@ -173,26 +181,22 @@ async def _get_alias_init_lock(cls) -> Lock:
cls._alias_init_lock = Lock()
return cls._alias_init_lock

@classmethod
async def _get_alias_map(cls, configuration: Configuration) -> dict[str, str]:
async def _get_alias_map(self) -> dict[str, str]:
"""Get the aliases and storage ids mapping from the default kvs.

Mapping is loaded from kvs only once and is shared for all instances of the _AliasResolver class.

Args:
configuration: Configuration object to use for accessing the default KVS.
Mapping is loaded from kvs only once and is shared for all instances of the `AliasResolver` class.

Returns:
Map of aliases and storage ids.
"""
if not cls._alias_map_loaded and configuration.is_at_home:
default_kvs_client = await cls._get_default_kvs_client(configuration)
if not AliasResolver._alias_map_loaded and self._configuration.is_at_home:
default_kvs_client = self._get_default_kvs_client()

record = await default_kvs_client.get_record(cls._ALIAS_MAPPING_KEY)
cls._alias_map = record.get('value', {}) if record else {}
cls._alias_map_loaded = True
record = await default_kvs_client.get_record(self._ALIAS_MAPPING_KEY)
AliasResolver._alias_map = record.get('value', {}) if record else {}
AliasResolver._alias_map_loaded = True

return cls._alias_map
return AliasResolver._alias_map

async def resolve_id(self) -> str | None:
"""Get id of the aliased storage.
Expand All @@ -212,12 +216,12 @@ async def resolve_id(self) -> str | None:
return storage_id

# Fallback to the mapping saved in the default KVS
return (await self._get_alias_map(self._configuration)).get(self._storage_key, None)
return (await self._get_alias_map()).get(self._storage_key, None)

async def store_mapping(self, storage_id: str) -> None:
"""Add alias and related storage id to the mapping in default kvs and local in-memory mapping."""
# Update in-memory mapping
alias_map = await self._get_alias_map(self._configuration)
alias_map = await self._get_alias_map()
alias_map[self._storage_key] = storage_id

if not self._configuration.is_at_home:
Expand All @@ -226,7 +230,7 @@ async def store_mapping(self, storage_id: str) -> None:
)
return

default_kvs_client = await self._get_default_kvs_client(self._configuration)
default_kvs_client = self._get_default_kvs_client()
await default_kvs_client.get()

try:
Expand All @@ -250,16 +254,14 @@ def _storage_key(self) -> str:
]
)

@staticmethod
async def _get_default_kvs_client(configuration: Configuration) -> KeyValueStoreClientAsync:
"""Get a client for the default key-value store."""
apify_client_async = ApifyClientAsync(
token=configuration.token,
api_url=configuration.api_base_url,
max_retries=8,
)
def _get_default_kvs_client(self) -> KeyValueStoreClientAsync:
"""Get a client for the default key-value store.

if not configuration.default_key_value_store_id:
Derived from the injected `ApifyClientAsync`, so alias resolution shares the same HTTP client (and its
connection pool and event loop affinity) as the storage operation that triggered it, instead of creating
and leaking its own.
"""
if not self._configuration.default_key_value_store_id:
raise ValueError("'Configuration.default_key_value_store_id' must be set.")

return apify_client_async.key_value_store(key_value_store_id=configuration.default_key_value_store_id)
return self._api_client.key_value_store(key_value_store_id=self._configuration.default_key_value_store_id)
2 changes: 2 additions & 0 deletions src/apify/storage_clients/_apify/_api_client_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def get_resource_client(storage_id: str) -> DatasetClientAsync:
storage_type=storage_type,
collection_client=collection_client,
get_resource_client_by_id=get_resource_client,
api_client=apify_client,
configuration=configuration,
) # ty:ignore[no-matching-overload]

Expand All @@ -127,6 +128,7 @@ def get_resource_client(storage_id: str) -> DatasetClientAsync:
storage_type=storage_type,
collection_client=collection_client,
get_resource_client_by_id=get_resource_client,
api_client=apify_client,
configuration=configuration,
) # ty:ignore[no-matching-overload]

Expand Down
102 changes: 89 additions & 13 deletions tests/unit/storage_clients/test_alias_resolver.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
from __future__ import annotations

from unittest.mock import AsyncMock, patch
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

from apify_client import ApifyClientAsync

from apify._configuration import Configuration
from apify.storage_clients._apify._alias_resolving import AliasResolver


def _api_client() -> ApifyClientAsync:
    """Return a fresh API client with a dummy token, for tests that construct a resolver but send no requests."""
    throwaway_client = ApifyClientAsync(token='test-token')
    return throwaway_client


def test_storage_key_format() -> None:
"""Test that _storage_key has the expected format: type,alias,hash."""
config = Configuration(token='test-token', api_base_url='https://api.apify.com')
resolver = AliasResolver(storage_type='Dataset', alias='my-alias', configuration=config)
resolver = AliasResolver(storage_type='Dataset', alias='my-alias', configuration=config, api_client=_api_client())
key = resolver._storage_key
parts = key.split(',')
assert len(parts) == 3
Expand All @@ -22,15 +30,19 @@ async def test_resolve_id_returns_none_for_unknown() -> None:
"""Test that resolve_id returns None for an alias not in the map."""
AliasResolver._alias_map = {}
config = Configuration(token='test-token')
resolver = AliasResolver(storage_type='Dataset', alias='unknown-alias', configuration=config)
resolver = AliasResolver(
storage_type='Dataset', alias='unknown-alias', configuration=config, api_client=_api_client()
)
result = await resolver.resolve_id()
assert result is None


async def test_resolve_id_returns_stored_id() -> None:
"""Test that resolve_id returns the ID if it was previously stored."""
config = Configuration(token='test-token', api_base_url='https://api.apify.com')
resolver = AliasResolver(storage_type='KeyValueStore', alias='test-alias', configuration=config)
resolver = AliasResolver(
storage_type='KeyValueStore', alias='test-alias', configuration=config, api_client=_api_client()
)
storage_key = resolver._storage_key
AliasResolver._alias_map = {storage_key: 'stored-id-123'}

Expand All @@ -42,7 +54,9 @@ async def test_store_mapping_local_only() -> None:
"""Test that store_mapping only updates in-memory map when not at home."""
AliasResolver._alias_map = {}
config = Configuration(is_at_home=False, token='test-token')
resolver = AliasResolver(storage_type='RequestQueue', alias='test-alias', configuration=config)
resolver = AliasResolver(
storage_type='RequestQueue', alias='test-alias', configuration=config, api_client=_api_client()
)

await resolver.store_mapping(storage_id='new-id-456')

Expand All @@ -55,7 +69,7 @@ async def test_concurrent_alias_creation_uses_lock() -> None:
AliasResolver._alias_init_lock = None
AliasResolver._alias_map = {}
config = Configuration(token='test-token')
resolver = AliasResolver(storage_type='Dataset', alias='test', configuration=config)
resolver = AliasResolver(storage_type='Dataset', alias='test', configuration=config, api_client=_api_client())

async with resolver:
# Lock should be acquired
Expand All @@ -71,26 +85,28 @@ async def test_get_alias_map_returns_in_memory_map() -> None:
"""Test that _get_alias_map returns the in-memory map when not at home."""
AliasResolver._alias_map = {'existing_key': 'existing_id'}
config = Configuration(is_at_home=False, token='test-token')
resolver = AliasResolver(storage_type='Dataset', alias='test', configuration=config, api_client=_api_client())

result = await AliasResolver._get_alias_map(config)
result = await resolver._get_alias_map()
assert result == {'existing_key': 'existing_id'}
# Also verify that an empty map is returned without fetching from KVS when not at home
AliasResolver._alias_map = {}
result = await AliasResolver._get_alias_map(config)
result = await resolver._get_alias_map()
assert result == {}


async def test_get_alias_map_loads_from_kvs_only_once_when_empty() -> None:
"""An empty KVS response must not trigger repeat fetches on subsequent calls."""
config = Configuration(is_at_home=True, token='test-token', default_key_value_store_id='default-kvs-id')
resolver = AliasResolver(storage_type='Dataset', alias='test', configuration=config, api_client=_api_client())

fake_kvs_client = AsyncMock()
fake_kvs_client.get_record = AsyncMock(return_value=None)

with patch.object(AliasResolver, '_get_default_kvs_client', return_value=fake_kvs_client):
await AliasResolver._get_alias_map(config)
await AliasResolver._get_alias_map(config)
await AliasResolver._get_alias_map(config)
await resolver._get_alias_map()
await resolver._get_alias_map()
await resolver._get_alias_map()

assert fake_kvs_client.get_record.await_count == 1
assert AliasResolver._alias_map == {}
Expand All @@ -100,7 +116,7 @@ async def test_store_mapping_uses_injected_configuration_is_at_home() -> None:
"""`store_mapping` gates on the injected configuration's `is_at_home`, not the global one."""
# Global `is_at_home` defaults to False; injected config says True — the KVS write must still happen.
config = Configuration(is_at_home=True, token='test-token', default_key_value_store_id='default-kvs-id')
resolver = AliasResolver(storage_type='Dataset', alias='test-alias', configuration=config)
resolver = AliasResolver(storage_type='Dataset', alias='test-alias', configuration=config, api_client=_api_client())

fake_kvs_client = AsyncMock()
fake_kvs_client.get_record = AsyncMock(return_value=None)
Expand Down Expand Up @@ -130,6 +146,66 @@ async def test_configuration_storages_alias_resolving() -> None:
# Check that id of each non-default storage saved in the mapping is resolved
for storage_type in ('Dataset', 'KeyValueStore', 'RequestQueue'):
assert (
await AliasResolver(storage_type=storage_type, alias='custom', configuration=configuration).resolve_id()
await AliasResolver(
storage_type=storage_type, alias='custom', configuration=configuration, api_client=_api_client()
).resolve_id()
== f'custom_{storage_type}_id'
)


def test_default_kvs_client_derives_from_injected_client() -> None:
    """The resolver builds its default-KVS client from the injected API client instead of creating a new one."""
    injected_client = _api_client()
    resolver = AliasResolver(
        storage_type='Dataset',
        alias='a',
        configuration=Configuration(token='test-token', default_key_value_store_id='default-kvs-id'),
        api_client=injected_client,
    )

    default_kvs_client = resolver._get_default_kvs_client()

    # The KVS client must point at the store named in the configuration.
    assert default_kvs_client.resource_id == 'default-kvs-id'
    # Identity of the underlying HTTP client shows the KVS client was derived from the injected API client.
    assert default_kvs_client._http_client is injected_client.http_client


def test_resolvers_use_their_own_injected_client() -> None:
    """Every resolver derives its KVS client from the API client it was given; nothing is cached process-wide."""
    config = Configuration(token='test-token', default_key_value_store_id='default-kvs-id')
    clients = {'a': _api_client(), 'b': _api_client()}
    resolvers = {
        alias: AliasResolver(storage_type='Dataset', alias=alias, configuration=config, api_client=client)
        for alias, client in clients.items()
    }

    # Each resolver's KVS client must wrap the HTTP client of its own injected API client.
    for alias, client in clients.items():
        assert resolvers[alias]._get_default_kvs_client()._http_client is client.http_client

    # The two injected clients really are distinct, so the checks above cannot pass by accident.
    assert clients['a'].http_client is not clients['b'].http_client


def test_alias_resolution_runs_across_event_loops_with_shared_client() -> None:
    """A single injected client can drive alias resolution from more than one event loop without loop-bound state.

    Two fresh event loops each run `store_mapping` through a resolver built around the same mocked API client.
    The class-level `AliasResolver` state that the test overwrites is restored afterwards, so the mapping, the
    "loaded" flag and the lock created on these short-lived loops do not leak into other tests.
    """
    config = Configuration(is_at_home=True, token='test-token', default_key_value_store_id='default-kvs-id')

    kvs_client = AsyncMock()
    kvs_client.get_record = AsyncMock(return_value={'value': {}})
    kvs_client.set_record = AsyncMock(return_value=None)
    api_client = MagicMock()
    api_client.key_value_store = MagicMock(return_value=kvs_client)

    async def store_on_current_loop(alias: str, storage_id: str) -> None:
        # Each loop starts from clean class state and builds its own lock on the running loop.
        AliasResolver._alias_map = {}
        AliasResolver._alias_map_loaded = False
        AliasResolver._alias_init_lock = None
        resolver = AliasResolver(storage_type='Dataset', alias=alias, configuration=config, api_client=api_client)
        async with resolver:
            await resolver.store_mapping(storage_id=storage_id)

    # Snapshot the shared class state before the helper above replaces it.
    saved_alias_map = AliasResolver._alias_map
    saved_alias_map_loaded = AliasResolver._alias_map_loaded
    saved_alias_init_lock = AliasResolver._alias_init_lock

    loop_a = asyncio.new_event_loop()
    loop_b = asyncio.new_event_loop()
    try:
        loop_a.run_until_complete(store_on_current_loop('alias-a', 'id-a'))
        loop_b.run_until_complete(store_on_current_loop('alias-b', 'id-b'))
    finally:
        loop_a.close()
        loop_b.close()
        # Put the class state back so later tests do not see this test's mapping or a lock from a closed loop.
        AliasResolver._alias_map = saved_alias_map
        AliasResolver._alias_map_loaded = saved_alias_map_loaded
        AliasResolver._alias_init_lock = saved_alias_init_lock

    # The resolver obtained its KVS client from the injected client rather than building its own.
    assert api_client.key_value_store.called
    # The same injected client served both event loops.
    assert kvs_client.set_record.await_count == 2
Loading