diff --git a/pinecone/_client.py b/pinecone/_client.py index 4e72568a..48c2d784 100644 --- a/pinecone/_client.py +++ b/pinecone/_client.py @@ -8,7 +8,12 @@ from pinecone._internal.config import PineconeConfig, RetryConfig from pinecone._internal.constants import CONTROL_PLANE_API_VERSION, DEFAULT_BASE_URL -from pinecone._internal.indexes_helpers import _LegacyIndexKwargs, poll_index_until_ready +from pinecone._internal.indexes_helpers import ( + IndexKwargs, + _LegacyIndexKwargs, + apply_index_kwargs_overrides, + poll_index_until_ready, +) from pinecone._internal.validation import require_non_empty from pinecone.errors.exceptions import ValidationError @@ -874,7 +879,7 @@ def IndexAsyncio(self, host: str, **kwargs: Any) -> Any: # noqa: N802 """ from pinecone.async_client.async_index import AsyncIndex as _AsyncIndex - return _AsyncIndex( + index_kwargs = IndexKwargs( host=host, api_key=self._config.api_key, additional_headers=dict(self._config.additional_headers), @@ -886,6 +891,10 @@ def IndexAsyncio(self, host: str, **kwargs: Any) -> Any: # noqa: N802 source_tag=self._config.source_tag, connection_pool_maxsize=self._config.connection_pool_maxsize, ) + index_kwargs = apply_index_kwargs_overrides( + index_kwargs, kwargs, caller="Pinecone.IndexAsyncio()" + ) + return _AsyncIndex(**index_kwargs) def close(self) -> None: """Close all open HTTP connections. diff --git a/pinecone/_internal/indexes_helpers.py b/pinecone/_internal/indexes_helpers.py index cf53f68a..aa6780af 100644 --- a/pinecone/_internal/indexes_helpers.py +++ b/pinecone/_internal/indexes_helpers.py @@ -60,6 +60,34 @@ class _LegacyIndexKwargs(IndexKwargs): pool_threads: NotRequired[int] +_INDEX_KWARG_KEYS = frozenset(IndexKwargs.__annotations__) + + +def apply_index_kwargs_overrides( + base: IndexKwargs, overrides: Mapping[str, Any], *, caller: str +) -> IndexKwargs: + """Apply explicit data-plane client kwargs to factory defaults.""" + allowed = _INDEX_KWARG_KEYS - {"host"} + unexpected = sorted(set(overrides) - allowed) + if unexpected: + raise TypeError(f"{caller} got unexpected keyword arguments: {unexpected!r}") + + return IndexKwargs( + host=base["host"], + api_key=overrides.get("api_key", base["api_key"]), + additional_headers=dict(overrides.get("additional_headers", base["additional_headers"])), + timeout=overrides.get("timeout", base["timeout"]), + proxy_url=overrides.get("proxy_url", base["proxy_url"]), + proxy_headers=dict(overrides.get("proxy_headers", base["proxy_headers"])), + ssl_ca_certs=overrides.get("ssl_ca_certs", base["ssl_ca_certs"]), + ssl_verify=overrides.get("ssl_verify", base["ssl_verify"]), + source_tag=overrides.get("source_tag", base["source_tag"]), + connection_pool_maxsize=overrides.get( + "connection_pool_maxsize", base["connection_pool_maxsize"] + ), + ) + + def resolve_enum_value(value: Any) -> Any: """Extract ``.value`` from enum-like objects, pass through otherwise.""" return value.value if hasattr(value, "value") else value diff --git a/pinecone/async_client/pinecone.py b/pinecone/async_client/pinecone.py index 31b3daae..ba04f9df 100644 --- a/pinecone/async_client/pinecone.py +++ b/pinecone/async_client/pinecone.py @@ -8,7 +8,11 @@ from pinecone._internal.config import PineconeConfig, RetryConfig from pinecone._internal.constants import CONTROL_PLANE_API_VERSION, DEFAULT_BASE_URL -from pinecone._internal.indexes_helpers import IndexKwargs, async_poll_index_until_ready +from pinecone._internal.indexes_helpers import ( + IndexKwargs, + apply_index_kwargs_overrides, + async_poll_index_until_ready, +) from pinecone._internal.validation import require_non_empty from pinecone.errors.exceptions import ValidationError @@ -750,7 +754,7 @@ def IndexAsyncio(self, host: str, **kwargs: Any) -> AsyncIndex: # noqa: N802 """ from pinecone.async_client.async_index import AsyncIndex as _AsyncIndex - return _AsyncIndex( + index_kwargs = IndexKwargs( host=host, api_key=self._config.api_key, additional_headers=dict(self._config.additional_headers), @@ -762,6 +766,10 @@ def IndexAsyncio(self, host: str, **kwargs: Any) -> AsyncIndex: # noqa: N802 source_tag=self._config.source_tag, connection_pool_maxsize=self._config.connection_pool_maxsize, ) + index_kwargs = apply_index_kwargs_overrides( + index_kwargs, kwargs, caller="AsyncPinecone.IndexAsyncio()" + ) + return _AsyncIndex(**index_kwargs) def _build_index_kwargs(self, host: str) -> IndexKwargs: """Return the kwargs dict for constructing an AsyncIndex.""" diff --git a/tests/unit/test_async_pinecone_backcompat.py b/tests/unit/test_async_pinecone_backcompat.py index 2db62278..feb4013a 100644 --- a/tests/unit/test_async_pinecone_backcompat.py +++ b/tests/unit/test_async_pinecone_backcompat.py @@ -4,6 +4,8 @@ from unittest.mock import AsyncMock, MagicMock +import pytest + from pinecone.async_client.pinecone import AsyncPinecone from pinecone.inference.models.index_embed import IndexEmbed from pinecone.models.enums import CloudProvider @@ -333,6 +335,22 @@ def test_async_index_asyncio_delegate_returns_async_index() -> None: assert isinstance(idx, AsyncIndex) +def test_async_index_asyncio_delegate_forwards_explicit_ssl_verify() -> None: + from unittest.mock import patch + + pc = AsyncPinecone(api_key="test-key", ssl_verify=True) + with patch("pinecone.async_client.async_index.AsyncIndex") as mock_async_index: + pc.IndexAsyncio(host="my-index.svc.pinecone.io", ssl_verify=False) + _, kwargs = mock_async_index.call_args + assert kwargs["ssl_verify"] is False + + +def test_async_index_asyncio_delegate_rejects_unknown_kwargs() -> None: + pc = AsyncPinecone(api_key="test-key") + with pytest.raises(TypeError, match="unexpected keyword arguments"): + pc.IndexAsyncio(host="my-index.svc.pinecone.io", bogus=True) + + # --------------------------------------------------------------------------- # __repr__ masking # --------------------------------------------------------------------------- diff --git a/tests/unit/test_pinecone_class.py b/tests/unit/test_pinecone_class.py index 9b343c91..4fcf6d0d 100644 --- a/tests/unit/test_pinecone_class.py +++ b/tests/unit/test_pinecone_class.py @@ -470,6 +470,18 @@ def test_constructs_async_index(self) -> None: _, kwargs = mock_async_index.call_args assert kwargs["host"] == "my-index.svc.pinecone.io" + def test_forwards_explicit_ssl_verify(self) -> None: + pc = Pinecone(api_key="test-key", ssl_verify=True) + with patch("pinecone.async_client.async_index.AsyncIndex") as mock_async_index: + pc.IndexAsyncio(host="my-index.svc.pinecone.io", ssl_verify=False) + _, kwargs = mock_async_index.call_args + assert kwargs["ssl_verify"] is False + + def test_rejects_unknown_kwargs(self) -> None: + pc = Pinecone(api_key="test-key") + with pytest.raises(TypeError, match="unexpected keyword arguments"): + pc.IndexAsyncio(host="my-index.svc.pinecone.io", bogus=True) + # --------------------------------------------------------------------------- # Lazy namespace property first-access and caching