diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index 39954ef561..56fc5fed69 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -31,7 +31,15 @@ from pyiceberg import __version__ from pyiceberg.catalog import BOTOCORE_SESSION, TOKEN, URI, WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary -from pyiceberg.catalog.rest.auth import AUTH_MANAGER, AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager +from pyiceberg.catalog.rest.auth import ( + AUTH_MANAGER, + AuthManager, + AuthManagerAdapter, + AuthManagerFactory, + LegacyOAuth2AuthManager, + NoopAuthManager, + SigV4AuthManager, +) from pyiceberg.catalog.rest.response import _handle_non_200_response from pyiceberg.catalog.rest.scan_planning import ( FetchScanTasksRequest, @@ -251,11 +259,11 @@ class ScanPlanningMode(Enum): CA_BUNDLE = "cabundle" SSL = "ssl" SIGV4 = "rest.sigv4-enabled" +SIGV4_AUTH_TYPE = "sigv4" SIGV4_REGION = "rest.signing-region" SIGV4_SERVICE = "rest.signing-name" SIGV4_MAX_RETRIES = "rest.sigv4.max-retries" SIGV4_MAX_RETRIES_DEFAULT = 10 -EMPTY_BODY_SHA256: str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" OAUTH2_SERVER_URI = "oauth2-server-uri" SNAPSHOT_LOADING_MODE = "snapshot-loading-mode" AUTH = "auth" @@ -431,10 +439,49 @@ def _create_session(self) -> Session: elif ssl_client_cert := ssl_client.get(CERT): session.cert = ssl_client_cert + self._auth_manager = self._build_auth_manager(session) + session.auth = AuthManagerAdapter(self._auth_manager) + + # SigV4 retry is decoupled from signing: mount a plain retry adapter. + if self._is_sigv4_enabled(): + from requests.adapters import HTTPAdapter + + max_retries = property_as_int(self.properties, SIGV4_MAX_RETRIES, SIGV4_MAX_RETRIES_DEFAULT) + session.mount(self.uri, HTTPAdapter(max_retries=max_retries)) + + return session + + def _is_sigv4_enabled(self) -> bool: + """Return True if SigV4 signing is requested via either config path.""" + if property_as_bool(self.properties, SIGV4, False): + return True + auth_config = self.properties.get(AUTH) + return auth_config is not None and auth_config.get("type") == SIGV4_AUTH_TYPE + + def _build_auth_manager(self, session: Session) -> AuthManager: + """Build the AuthManager, wrapping the delegate in SigV4 when enabled.""" + delegate = self._build_delegate_auth_manager(session) + if self._is_sigv4_enabled(): + return self._build_sigv4_auth_manager(delegate) + return delegate + + def _build_delegate_auth_manager(self, session: Session) -> AuthManager: + """Build the header-based AuthManager (the SigV4 delegate, or the manager used directly).""" if auth_config := self.properties.get(AUTH): auth_type = auth_config.get("type") if auth_type is None: raise ValueError("auth.type must be defined") + + if auth_type == SIGV4_AUTH_TYPE: + # The delegate is configured under auth.sigv4.delegate.* + sigv4_config = auth_config.get(SIGV4_AUTH_TYPE, {}) + delegate_config = sigv4_config.get("delegate") + if not delegate_config or "type" not in delegate_config: + # No delegate configured: SigV4-only auth, with no header-based delegate. + return NoopAuthManager() + delegate_type = delegate_config["type"] + return AuthManagerFactory.create(delegate_type, delegate_config.get(delegate_type, {})) + auth_type_config = auth_config.get(auth_type, {}) auth_impl = auth_config.get("impl") @@ -444,17 +491,28 @@ def _create_session(self) -> Session: if auth_type != CUSTOM and auth_impl: raise ValueError("auth.impl can only be specified when using custom auth.type") - self._auth_manager = AuthManagerFactory.create(auth_impl or auth_type, auth_type_config) - session.auth = AuthManagerAdapter(self._auth_manager) - else: - self._auth_manager = self._create_legacy_oauth2_auth_manager(session) - session.auth = AuthManagerAdapter(self._auth_manager) + return AuthManagerFactory.create(auth_impl or auth_type, auth_type_config) - # Configure SigV4 Request Signing - if property_as_bool(self.properties, SIGV4, False): - self._init_sigv4(session) + return self._create_legacy_oauth2_auth_manager(session) - return session + def _build_sigv4_auth_manager(self, delegate: AuthManager) -> AuthManager: + """Wrap the delegate AuthManager in a SigV4AuthManager.""" + import boto3 + + boto_session = boto3.Session( + profile_name=get_first_property_value(self.properties, AWS_PROFILE_NAME), + region_name=get_first_property_value(self.properties, AWS_REGION), + botocore_session=self.properties.get(BOTOCORE_SESSION), + aws_access_key_id=get_first_property_value(self.properties, AWS_ACCESS_KEY_ID), + aws_secret_access_key=get_first_property_value(self.properties, AWS_SECRET_ACCESS_KEY), + aws_session_token=get_first_property_value(self.properties, AWS_SESSION_TOKEN), + ) + return SigV4AuthManager( + delegate=delegate, + boto_session=boto_session, + region=self.properties.get(SIGV4_REGION), + service=self.properties.get(SIGV4_SERVICE, "execute-api"), + ) @staticmethod def _resolve_storage_credentials(storage_credentials: list[StorageCredential], location: str | None) -> Properties: @@ -757,64 +815,6 @@ def _split_identifier_for_json(self, identifier: str | Identifier) -> dict[str, identifier_tuple = self._identifier_to_validated_tuple(identifier) return {"namespace": identifier_tuple[:-1], "name": identifier_tuple[-1]} - def _init_sigv4(self, session: Session) -> None: - from urllib import parse - - import boto3 - from botocore.auth import SigV4Auth - from botocore.awsrequest import AWSRequest - from requests import PreparedRequest - from requests.adapters import HTTPAdapter - - class SigV4Adapter(HTTPAdapter): - def __init__(self, **properties: str): - self._properties = properties - max_retries = property_as_int(self._properties, SIGV4_MAX_RETRIES, SIGV4_MAX_RETRIES_DEFAULT) - super().__init__(max_retries=max_retries) - self._boto_session = boto3.Session( - profile_name=get_first_property_value(self._properties, AWS_PROFILE_NAME), - region_name=get_first_property_value(self._properties, AWS_REGION), - botocore_session=self._properties.get(BOTOCORE_SESSION), - aws_access_key_id=get_first_property_value(self._properties, AWS_ACCESS_KEY_ID), - aws_secret_access_key=get_first_property_value(self._properties, AWS_SECRET_ACCESS_KEY), - aws_session_token=get_first_property_value(self._properties, AWS_SESSION_TOKEN), - ) - - def add_headers(self, request: PreparedRequest, **kwargs: Any) -> None: # pylint: disable=W0613 - credentials = self._boto_session.get_credentials().get_frozen_credentials() - region = self._properties.get(SIGV4_REGION, self._boto_session.region_name) - service = self._properties.get(SIGV4_SERVICE, "execute-api") - - url = str(request.url).split("?")[0] - query = str(parse.urlsplit(request.url).query) - params = dict(parse.parse_qsl(query)) - - # remove the connection header as it will be updated after signing - if "connection" in request.headers: - del request.headers["connection"] - # For empty bodies, explicitly set the content hash header to the SHA256 of an empty string - if not request.body: - request.headers["x-amz-content-sha256"] = EMPTY_BODY_SHA256 - - aws_request = AWSRequest( - method=request.method, url=url, params=params, data=request.body, headers=dict(request.headers) - ) - - SigV4Auth(credentials, service, region).add_auth(aws_request) - original_header = request.headers - signed_headers = aws_request.headers - relocated_headers = {} - - # relocate headers if there is a conflict with signed headers - for header, value in original_header.items(): - if header in signed_headers and signed_headers[header] != value: - relocated_headers[f"Original-{header}"] = value - - request.headers.update(relocated_headers) - request.headers.update(signed_headers) - - session.mount(self.uri, SigV4Adapter(**self.properties)) - def _response_to_table(self, identifier_tuple: tuple[str, ...], table_response: TableResponse) -> Table: # Per Iceberg spec: storage-credentials take precedence over config credential_config = self._resolve_storage_credentials( diff --git a/pyiceberg/catalog/rest/auth.py b/pyiceberg/catalog/rest/auth.py index 602074282c..3f42708dd7 100644 --- a/pyiceberg/catalog/rest/auth.py +++ b/pyiceberg/catalog/rest/auth.py @@ -21,7 +21,7 @@ import threading import time from abc import ABC, abstractmethod -from functools import cached_property +from functools import cache, cached_property from typing import Any import requests @@ -36,6 +36,37 @@ COLON = ":" logger = logging.getLogger(__name__) +# SHA-256 of an empty payload. Used as the x-amz-content-sha256 header value for +# empty-body requests, matching Iceberg Java's RESTSigV4AuthSession workaround. +EMPTY_BODY_SHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + + +@cache +def _iceberg_sigv4_auth_class() -> type: + """Lazily build the botocore SigV4Auth subclass (botocore is an optional dependency).""" + from urllib import parse + + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + + class _IcebergSigV4Auth(SigV4Auth): + def canonical_request(self, request: AWSRequest) -> str: + # Override forces the hex payload hash in the canonical request even when + # the x-amz-content-sha256 header is base64 (see SigV4AuthManager.sign_request). + # Mirrors botocore <=1.42.x SigV4Auth.canonical_request layout: + # https://github.com/boto/botocore/blob/1.42.85/botocore/auth.py#L622-L637 + cr = [request.method.upper()] + path = self._normalize_url_path(parse.urlsplit(request.url).path) + cr.append(path) + cr.append(self.canonical_query_string(request)) + headers_to_sign = self.headers_to_sign(request) + cr.append(self.canonical_headers(headers_to_sign) + "\n") + cr.append(self.signed_headers(headers_to_sign)) + cr.append(self.payload(request)) + return "\n".join(cr) + + return _IcebergSigV4Auth + class AuthManager(ABC): """ @@ -48,6 +79,14 @@ class AuthManager(ABC): def auth_header(self) -> str | None: """Return the Authorization header value, or None if not applicable.""" + def sign_request(self, request: PreparedRequest) -> PreparedRequest: + """Optionally sign or otherwise modify the prepared request. + + The default implementation is a no-op. Override for request-signing + schemes such as SigV4 that must inspect the full request. + """ + return request + class NoopAuthManager(AuthManager): """Auth Manager implementation with no auth.""" @@ -311,6 +350,91 @@ def auth_header(self) -> str: return f"Bearer {self._get_token()}" +class SigV4AuthManager(AuthManager): + """AuthManager that signs requests with AWS SigV4, wrapping a delegate AuthManager. + + Mirrors Iceberg Java's RESTSigV4AuthManager: the delegate AuthManager handles + header-based auth (e.g. OAuth2), then SigV4 signs the resulting request. + """ + + def __init__( + self, + delegate: AuthManager, + boto_session: Any, + region: str | None, + service: str = "execute-api", + ): + """Initialize SigV4AuthManager. + + Args: + delegate: AuthManager that supplies header-based auth before signing. + boto_session: A boto3.Session used to resolve AWS credentials. + region: SigV4 signing region; falls back to the boto session's region. + service: SigV4 signing service name. + """ + self._delegate = delegate + self._boto_session = boto_session + self._region = region + self._service = service + + def auth_header(self) -> str | None: + return self._delegate.auth_header() + + def sign_request(self, request: PreparedRequest) -> PreparedRequest: + import hashlib + from urllib import parse + + from botocore.awsrequest import AWSRequest + + credentials = self._boto_session.get_credentials().get_frozen_credentials() + region = self._region or self._boto_session.region_name + + url = str(request.url).split("?")[0] + query = str(parse.urlsplit(request.url).query) + params = dict(parse.parse_qsl(query)) + + # remove the connection header as it will be updated after signing + if "connection" in request.headers: + del request.headers["connection"] + + # Match Iceberg Java's AWS SDK v2 flexible-checksum signing: + # x-amz-content-sha256 header is base64 for non-empty bodies, hex for empty. + # The SigV4 canonical request still uses hex (enforced in _iceberg_sigv4_auth_class). + # Ref: https://github.com/apache/iceberg/blob/main/aws/src/main/java/org/apache/iceberg/aws/RESTSigV4AuthSession.java + if request.body: + if isinstance(request.body, str): + body_bytes = request.body.encode("utf-8") + elif isinstance(request.body, (bytes, bytearray)): + body_bytes = bytes(request.body) + else: + raise TypeError( + f"Unsupported request body type for SigV4 signing: {type(request.body).__name__}; expected str or bytes." + ) + content_sha256_header = base64.b64encode(hashlib.sha256(body_bytes).digest()).decode() + else: + content_sha256_header = EMPTY_BODY_SHA256 + + signing_headers = dict(request.headers) + signing_headers["x-amz-content-sha256"] = content_sha256_header + + aws_request = AWSRequest(method=request.method, url=url, params=params, data=request.body, headers=signing_headers) + + _iceberg_sigv4_auth_class()(credentials, self._service, region).add_auth(aws_request) + + original_header = dict(request.headers) + signed_headers = dict(aws_request.headers) + relocated_headers = {} + + # relocate headers if there is a conflict with signed headers + for header, value in original_header.items(): + if header in signed_headers and signed_headers[header] != value: + relocated_headers[f"Original-{header}"] = value + + request.headers.update(relocated_headers) + request.headers.update(signed_headers) + return request + + class AuthManagerAdapter(AuthBase): """A `requests.auth.AuthBase` adapter for integrating an `AuthManager` into a `requests.Session`. @@ -332,17 +456,19 @@ def __init__(self, auth_manager: AuthManager): def __call__(self, request: PreparedRequest) -> PreparedRequest: """ - Modify the outgoing request to include the Authorization header. + Modify the outgoing request to include the Authorization header and any signature. Args: request (requests.PreparedRequest): The HTTP request being prepared. Returns: - requests.PreparedRequest: The modified request with Authorization header. + requests.PreparedRequest: The modified request. """ if auth_header := self.auth_manager.auth_header(): request.headers["Authorization"] = auth_header - return request + # Header first, then sign: a request-signing AuthManager (e.g. SigV4) must + # see the Authorization header so it can relocate it before signing. + return self.auth_manager.sign_request(request) class AuthManagerFactory: diff --git a/pyproject.toml b/pyproject.toml index 96118f8451..4293e057a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,7 @@ sql-postgres = [ ] sql-sqlite = ["sqlalchemy>=2.0.18,<3"] gcsfs = ["gcsfs>=2023.1.0"] -rest-sigv4 = ["boto3>=1.24.59"] +rest-sigv4 = ["boto3>=1.24.59", "botocore<2"] hf = ["huggingface-hub>=0.24.0"] pyiceberg-core = ["pyiceberg-core>=0.5.1,<0.10.0"] datafusion = ["datafusion>=52,<53"] diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index df2f96a392..d0d49cd461 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -24,7 +24,6 @@ from unittest import mock import pytest -from requests import Request from requests.adapters import HTTPAdapter from requests.exceptions import HTTPError from requests_mock import Mocker @@ -33,7 +32,6 @@ from pyiceberg.catalog import PropertiesUpdateSummary, load_catalog from pyiceberg.catalog.rest import ( DEFAULT_ENDPOINTS, - EMPTY_BODY_SHA256, OAUTH2_SERVER_URI, SIGV4_MAX_RETRIES, SIGV4_MAX_RETRIES_DEFAULT, @@ -482,73 +480,26 @@ def test_list_tables_200(rest_mock: Mocker) -> None: def test_list_tables_200_sigv4(rest_mock: Mocker) -> None: namespace = "examples" + # SigV4 signing replaces the bearer Authorization header with an AWS4-HMAC-SHA256 + # signature, so the request headers are not matched against TEST_HEADERS here. rest_mock.get( f"{TEST_URI}v1/namespaces/{namespace}/tables", json={"identifiers": [{"namespace": ["examples"], "name": "fooshare"}]}, status_code=200, - request_headers=TEST_HEADERS, - ) - - assert RestCatalog("rest", **{"uri": TEST_URI, "token": TEST_TOKEN, "rest.sigv4-enabled": "true"}).list_tables(namespace) == [ - ("examples", "fooshare") - ] - assert rest_mock.called - - -def test_sigv4_sign_request_without_body(rest_mock: Mocker) -> None: - existing_token = "existing_token" - - catalog = RestCatalog( - "rest", - **{ - "uri": TEST_URI, - "token": existing_token, - "rest.sigv4-enabled": "true", - "rest.signing-region": "us-west-2", - "client.access-key-id": "id", - "client.secret-access-key": "secret", - }, ) - prepared = catalog._session.prepare_request(Request("GET", f"{TEST_URI}v1/config")) - adapter = catalog._session.adapters[catalog.uri] - assert isinstance(adapter, HTTPAdapter) - adapter.add_headers(prepared) - - assert prepared.headers["Authorization"].startswith("AWS4-HMAC-SHA256") - assert prepared.headers["Original-Authorization"] == f"Bearer {existing_token}" - assert prepared.headers["x-amz-content-sha256"] == EMPTY_BODY_SHA256 - - -def test_sigv4_sign_request_with_body(rest_mock: Mocker) -> None: - existing_token = "existing_token" - - catalog = RestCatalog( + assert RestCatalog( "rest", **{ "uri": TEST_URI, - "token": existing_token, + "token": TEST_TOKEN, "rest.sigv4-enabled": "true", "rest.signing-region": "us-west-2", "client.access-key-id": "id", "client.secret-access-key": "secret", }, - ) - - prepared = catalog._session.prepare_request( - Request( - "POST", - f"{TEST_URI}v1/namespaces", - data={"namespace": "asdfasd"}, - ) - ) - adapter = catalog._session.adapters[catalog.uri] - assert isinstance(adapter, HTTPAdapter) - adapter.add_headers(prepared) - - assert prepared.headers["Authorization"].startswith("AWS4-HMAC-SHA256") - assert prepared.headers["Original-Authorization"] == f"Bearer {existing_token}" - assert prepared.headers.get("x-amz-content-sha256") != EMPTY_BODY_SHA256 + ).list_tables(namespace) == [("examples", "fooshare")] + assert rest_mock.called def test_sigv4_adapter_default_retry_config(rest_mock: Mocker) -> None: @@ -588,29 +539,6 @@ def test_sigv4_adapter_override_retry_config(rest_mock: Mocker) -> None: assert adapter.max_retries.total == 3 -def test_sigv4_uses_client_profile_name(rest_mock: Mocker) -> None: - with mock.patch("boto3.Session") as mock_session: - RestCatalog( - "rest", - **{ - "uri": TEST_URI, - "token": TEST_TOKEN, - "rest.sigv4-enabled": "true", - "rest.signing-region": "us-west-2", - "client.profile-name": "rest-profile", - }, - ) - - mock_session.assert_called_with( - profile_name="rest-profile", - region_name=None, - botocore_session=None, - aws_access_key_id=None, - aws_secret_access_key=None, - aws_session_token=None, - ) - - def test_list_tables_404(rest_mock: Mocker) -> None: namespace = "examples" rest_mock.get( @@ -728,16 +656,25 @@ def test_list_views_paginated_200_none_next_page_token(rest_mock: Mocker) -> Non def test_list_views_200_sigv4(rest_mock: Mocker) -> None: namespace = "examples" + # SigV4 signing replaces the bearer Authorization header with an AWS4-HMAC-SHA256 + # signature, so the request headers are not matched against TEST_HEADERS here. rest_mock.get( f"{TEST_URI}v1/namespaces/{namespace}/views", json={"identifiers": [{"namespace": ["examples"], "name": "fooshare"}]}, status_code=200, - request_headers=TEST_HEADERS, ) - assert RestCatalog("rest", **{"uri": TEST_URI, "token": TEST_TOKEN, "rest.sigv4-enabled": "true"}).list_views(namespace) == [ - ("examples", "fooshare") - ] + assert RestCatalog( + "rest", + **{ + "uri": TEST_URI, + "token": TEST_TOKEN, + "rest.sigv4-enabled": "true", + "rest.signing-region": "us-west-2", + "client.access-key-id": "id", + "client.secret-access-key": "secret", + }, + ).list_views(namespace) == [("examples", "fooshare")] assert rest_mock.called @@ -2494,7 +2431,17 @@ def test_rest_catalog_close_sigv4(self, rest_mock: Mocker) -> None: status_code=200, ) - catalog = RestCatalog("rest", **{"uri": TEST_URI, "token": TEST_TOKEN, "rest.sigv4-enabled": "true"}) + catalog = RestCatalog( + "rest", + **{ + "uri": TEST_URI, + "token": TEST_TOKEN, + "rest.sigv4-enabled": "true", + "rest.signing-region": "us-west-2", + "client.access-key-id": "id", + "client.secret-access-key": "secret", + }, + ) catalog.close() assert hasattr(catalog, "_session") assert len(catalog._session.adapters) == self.EXPECTED_ADAPTERS_SIGV4 @@ -2528,7 +2475,17 @@ def test_rest_catalog_context_manager_with_exception_sigv4(self, rest_mock: Mock ) try: - with RestCatalog("rest", **{"uri": TEST_URI, "token": TEST_TOKEN, "rest.sigv4-enabled": "true"}) as cat: + with RestCatalog( + "rest", + **{ + "uri": TEST_URI, + "token": TEST_TOKEN, + "rest.sigv4-enabled": "true", + "rest.signing-region": "us-west-2", + "client.access-key-id": "id", + "client.secret-access-key": "secret", + }, + ) as cat: catalog = cat raise ValueError("Test exception") except ValueError: diff --git a/tests/catalog/test_rest_auth.py b/tests/catalog/test_rest_auth.py index ae5d40f5aa..29f6612ff3 100644 --- a/tests/catalog/test_rest_auth.py +++ b/tests/catalog/test_rest_auth.py @@ -16,12 +16,14 @@ # under the License. import base64 +import hashlib from unittest.mock import MagicMock, patch import pytest import requests from requests_mock import Mocker +from pyiceberg.catalog.rest import RestCatalog from pyiceberg.catalog.rest.auth import AuthManagerAdapter, BasicAuthManager, EntraAuthManager, GoogleAuthManager, NoopAuthManager TEST_URI = "https://iceberg-test-catalog/" @@ -35,6 +37,11 @@ def rest_mock(requests_mock: Mocker) -> Mocker: json={}, status_code=200, ) + requests_mock.get( + f"{TEST_URI}v1/config", + json={"defaults": {}, "overrides": {}}, + status_code=200, + ) return requests_mock @@ -249,3 +256,374 @@ def test_entra_auth_manager_token_failure(mock_default_cred: MagicMock, rest_moc # Verify no requests were made with a blank/missing auth header history = rest_mock.request_history assert len(history) == 0 + + +def test_sign_request_default_is_noop() -> None: + """AuthManager.sign_request default implementation must not modify the request.""" + manager = NoopAuthManager() + prepared = requests.Request("GET", TEST_URI).prepare() + original_headers = dict(prepared.headers) + + result = manager.sign_request(prepared) + + assert result is prepared + assert dict(result.headers) == original_headers + + +def test_sigv4_auth_manager_signs_with_java_reference_values() -> None: + """SigV4AuthManager.sign_request must match Iceberg Java reference header values.""" + import boto3 + + from pyiceberg.catalog.rest.auth import SigV4AuthManager + + boto_session = boto3.Session( + aws_access_key_id="id", + aws_secret_access_key="secret", + region_name="us-east-1", + ) + manager = SigV4AuthManager( + delegate=NoopAuthManager(), + boto_session=boto_session, + region="us-east-1", + service="execute-api", + ) + + # Non-empty body: base64 SHA-256 (Iceberg Java TestRESTSigV4AuthSession.java L177) + body = b'{"namespace":["ns"],"properties":{}}' + prepared = requests.Request("POST", "https://example.com/v1/namespaces", data=body).prepare() + manager.sign_request(prepared) + assert prepared.headers["x-amz-content-sha256"] == base64.b64encode(hashlib.sha256(body).digest()).decode() + assert prepared.headers["x-amz-content-sha256"] == "yc5oAKPWjHY4sW8XQq0l/3aNrrXJKBycVFNnDEGMfww=" + assert prepared.headers["Authorization"].startswith("AWS4-HMAC-SHA256 Credential=") + + # Empty body: hex EMPTY_BODY_SHA256 (Iceberg Java TestRESTSigV4AuthSession.java L121) + prepared_empty = requests.Request("GET", "https://example.com/v1/config").prepare() + manager.sign_request(prepared_empty) + assert prepared_empty.headers["x-amz-content-sha256"] == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + + +def test_sigv4_auth_manager_relocates_delegate_authorization() -> None: + """When the delegate sets Authorization, SigV4 relocates it to Original-Authorization.""" + import boto3 + + from pyiceberg.catalog.rest.auth import SigV4AuthManager + + boto_session = boto3.Session(aws_access_key_id="id", aws_secret_access_key="secret", region_name="us-east-1") + manager = SigV4AuthManager( + delegate=BasicAuthManager(username="user", password="pass"), + boto_session=boto_session, + region="us-east-1", + service="execute-api", + ) + adapter = AuthManagerAdapter(manager) + + prepared = requests.Request("GET", "https://example.com/v1/config").prepare() + adapter(prepared) + + # SigV4 owns Authorization; the delegate's Basic header is relocated. + assert prepared.headers["Authorization"].startswith("AWS4-HMAC-SHA256 Credential=") + assert prepared.headers["Original-Authorization"].startswith("Basic ") + + +def test_sigv4_legacy_config_builds_sigv4_auth_manager(rest_mock: Mocker) -> None: + """Legacy rest.sigv4-enabled config produces a SigV4AuthManager.""" + from pyiceberg.catalog.rest.auth import SigV4AuthManager + + catalog = RestCatalog( + "rest", + **{ + "uri": TEST_URI, + "rest.sigv4-enabled": "true", + "rest.signing-region": "us-east-1", + "client.access-key-id": "id", + "client.secret-access-key": "secret", + }, + ) + assert isinstance(catalog._auth_manager, SigV4AuthManager) + + +def test_sigv4_auth_type_config_builds_sigv4_auth_manager(rest_mock: Mocker) -> None: + """New auth.type=sigv4 config produces a SigV4AuthManager wrapping the delegate.""" + from pyiceberg.catalog.rest.auth import SigV4AuthManager + + catalog = RestCatalog( + "rest", + **{ # type: ignore + "uri": TEST_URI, + "auth": {"type": "sigv4", "sigv4": {"delegate": {"type": "noop"}}}, + "rest.signing-region": "us-east-1", + "client.access-key-id": "id", + "client.secret-access-key": "secret", + }, + ) + assert isinstance(catalog._auth_manager, SigV4AuthManager) + + +def test_sigv4_sign_request_without_body(rest_mock: Mocker) -> None: + from pyiceberg.catalog.rest.auth import EMPTY_BODY_SHA256 + + existing_token = "existing_token" + + catalog = RestCatalog( + "rest", + **{ + "uri": TEST_URI, + "token": existing_token, + "rest.sigv4-enabled": "true", + "rest.signing-region": "us-west-2", + "client.access-key-id": "id", + "client.secret-access-key": "secret", + }, + ) + + # prepare_request applies session.auth, which signs via SigV4AuthManager. + prepared = catalog._session.prepare_request(requests.Request("GET", f"{TEST_URI}v1/config")) + + auth_header = prepared.headers["Authorization"] + assert auth_header.startswith("AWS4-HMAC-SHA256 Credential=") + assert prepared.headers["Original-Authorization"] == f"Bearer {existing_token}" + assert prepared.headers["x-amz-content-sha256"] == EMPTY_BODY_SHA256 + # Verify the signature format: Credential, SignedHeaders, Signature + assert "Credential=" in auth_header + assert "SignedHeaders=" in auth_header + assert "Signature=" in auth_header + # x-amz-content-sha256 should be in signed headers + assert "x-amz-content-sha256" in auth_header + + +def test_sigv4_sign_request_with_body(rest_mock: Mocker) -> None: + existing_token = "existing_token" + + catalog = RestCatalog( + "rest", + **{ + "uri": TEST_URI, + "token": existing_token, + "rest.sigv4-enabled": "true", + "rest.signing-region": "us-west-2", + "client.access-key-id": "id", + "client.secret-access-key": "secret", + }, + ) + + prepared = catalog._session.prepare_request( + requests.Request( + "POST", + f"{TEST_URI}v1/namespaces", + data={"namespace": "asdfasd"}, + ) + ) + + auth_header = prepared.headers["Authorization"] + assert auth_header.startswith("AWS4-HMAC-SHA256 Credential=") + assert "SignedHeaders=" in auth_header + # Conflicting Authorization header is relocated + assert prepared.headers["Original-Authorization"] == f"Bearer {existing_token}" + # Non-empty body should have base64-encoded SHA256 + content_sha256 = prepared.headers["x-amz-content-sha256"] + assert prepared.body is not None + body_bytes = prepared.body.encode("utf-8") if isinstance(prepared.body, str) else prepared.body + expected_sha256 = base64.b64encode(hashlib.sha256(body_bytes).digest()).decode() + assert content_sha256 == expected_sha256 + # x-amz-content-sha256 should be in signed headers + assert "x-amz-content-sha256" in auth_header + + +def test_sigv4_content_sha256_with_bytes_body(rest_mock: Mocker) -> None: + existing_token = "existing_token" + + catalog = RestCatalog( + "rest", + **{ + "uri": TEST_URI, + "token": existing_token, + "rest.sigv4-enabled": "true", + "rest.signing-region": "us-west-2", + "client.access-key-id": "id", + "client.secret-access-key": "secret", + }, + ) + + body_content = b'{"namespace": "test_namespace"}' + prepared = catalog._session.prepare_request( + requests.Request( + "POST", + f"{TEST_URI}v1/namespaces", + data=body_content, + ) + ) + + assert prepared.headers["Authorization"].startswith("AWS4-HMAC-SHA256 Credential=") + assert "SignedHeaders=" in prepared.headers["Authorization"] + content_sha256 = prepared.headers["x-amz-content-sha256"] + expected_sha256 = base64.b64encode(hashlib.sha256(body_content).digest()).decode() + assert content_sha256 == expected_sha256 + + +def test_sigv4_conflicting_sigv4_headers(rest_mock: Mocker) -> None: + from pyiceberg.catalog.rest.auth import EMPTY_BODY_SHA256 + + catalog = RestCatalog( + "rest", + **{ + "uri": TEST_URI, + "rest.sigv4-enabled": "true", + "rest.signing-region": "us-west-2", + "client.access-key-id": "id", + "client.secret-access-key": "secret", + }, + ) + + # Build an unsigned prepared request, then inject conflicting SigV4 headers. + prepared = requests.Request("GET", f"{TEST_URI}v1/config").prepare() + prepared.headers["x-amz-content-sha256"] = "fake" + prepared.headers["X-Amz-Date"] = "fake" + + # session.auth is the AuthManagerAdapter; calling it signs the request. + auth = catalog._session.auth + assert isinstance(auth, AuthManagerAdapter) + auth(prepared) + + # Matching Java SDK: conflicting headers are relocated with "Original-" prefix + assert prepared.headers.get("Original-x-amz-content-sha256") == "fake" + assert prepared.headers.get("Original-X-Amz-Date") == "fake" + # SigV4 headers are set correctly after signing + assert prepared.headers["Authorization"].startswith("AWS4-HMAC-SHA256 Credential=") + assert prepared.headers["x-amz-content-sha256"] == EMPTY_BODY_SHA256 + assert "X-Amz-Date" in prepared.headers + + +def test_sigv4_canonical_request_uses_hex_payload(rest_mock: Mocker) -> None: + """Verify that the canonical request uses hex-encoded payload hash, not the base64 header value.""" + from typing import Any + + from botocore.auth import SigV4Auth + + catalog = RestCatalog( + "rest", + **{ + "uri": TEST_URI, + "token": "token", + "rest.sigv4-enabled": "true", + "rest.signing-region": "us-west-2", + "client.access-key-id": "id", + "client.secret-access-key": "secret", + }, + ) + + body_content = b'{"namespace": "test"}' + + # Capture the canonical request string during signing + captured_canonical = [] + original_add_auth = SigV4Auth.add_auth + + def capturing_add_auth(self: Any, request: Any) -> None: + captured_canonical.append(self.canonical_request(request)) + original_add_auth(self, request) + + # Signing now happens inside prepare_request (via session.auth). + with patch.object(SigV4Auth, "add_auth", capturing_add_auth): + prepared = catalog._session.prepare_request( + requests.Request( + "POST", + f"{TEST_URI}v1/namespaces", + data=body_content, + ) + ) + + assert len(captured_canonical) == 1 + canonical_lines = captured_canonical[0].split("\n") + # Last line of canonical request is the payload hash + payload_hash = canonical_lines[-1] + # Must be hex-encoded (64 hex chars), not base64 + assert len(payload_hash) == 64 + assert payload_hash == hashlib.sha256(body_content).hexdigest() + # Meanwhile the header is base64-encoded + assert prepared.headers["x-amz-content-sha256"] == base64.b64encode(hashlib.sha256(body_content).digest()).decode() + + +def test_sigv4_content_sha256_matches_iceberg_java_reference(rest_mock: Mocker) -> None: + """Pin byte-for-byte equivalence with Iceberg Java TestRESTSigV4AuthSession (L121, L177).""" + java_reference_body = b'{"namespace":["ns"],"properties":{}}' + java_reference_base64 = "yc5oAKPWjHY4sW8XQq0l/3aNrrXJKBycVFNnDEGMfww=" + java_reference_empty_hex = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + + catalog = RestCatalog( + "rest", + **{ + "uri": TEST_URI, + "rest.sigv4-enabled": "true", + "rest.signing-region": "us-east-1", + "client.access-key-id": "id", + "client.secret-access-key": "secret", + }, + ) + + # Non-empty body: must match Java's base64 reference value exactly + prepared_with_body = catalog._session.prepare_request( + requests.Request("POST", f"{TEST_URI}v1/namespaces", data=java_reference_body) + ) + assert prepared_with_body.headers["x-amz-content-sha256"] == java_reference_base64 + + # Empty body: must match Java's hex reference value exactly + prepared_empty = catalog._session.prepare_request(requests.Request("GET", f"{TEST_URI}v1/config")) + assert prepared_empty.headers["x-amz-content-sha256"] == java_reference_empty_hex + + +def test_sigv4_unsupported_body_type_raises() -> None: + """Unsupported body types (e.g. file-like) raise a clear error rather than crashing in hashlib.""" + import boto3 + + from pyiceberg.catalog.rest.auth import NoopAuthManager, SigV4AuthManager + + boto_session = boto3.Session( + aws_access_key_id="id", + aws_secret_access_key="secret", + region_name="us-east-1", + ) + manager = SigV4AuthManager( + delegate=NoopAuthManager(), + boto_session=boto_session, + region="us-east-1", + service="execute-api", + ) + + prepared = requests.Request("POST", f"{TEST_URI}v1/namespaces").prepare() + # Inject an unsupported body type (a list — not str/bytes) + prepared.body = ["not", "a", "valid", "body"] # type: ignore[assignment] + + with pytest.raises(TypeError, match="Unsupported request body type for SigV4 signing"): + manager.sign_request(prepared) + + +def test_sigv4_uses_client_profile_name(rest_mock: Mocker) -> None: + import boto3 + + # Use a real boto3.Session for credential resolution (signing runs during + # config fetch), but spy on the constructor to assert the profile is honored. + real_session = boto3.Session( + aws_access_key_id="id", + aws_secret_access_key="secret", + region_name="us-west-2", + ) + + with patch("boto3.Session", return_value=real_session) as mock_session: + RestCatalog( + "rest", + **{ + "uri": TEST_URI, + "token": "token", + "rest.sigv4-enabled": "true", + "rest.signing-region": "us-west-2", + "client.profile-name": "rest-profile", + }, + ) + + mock_session.assert_called_with( + profile_name="rest-profile", + region_name=None, + botocore_session=None, + aws_access_key_id=None, + aws_secret_access_key=None, + aws_session_token=None, + ) diff --git a/uv.lock b/uv.lock index e9515cbebc..485021af8b 100644 --- a/uv.lock +++ b/uv.lock @@ -4673,6 +4673,7 @@ ray = [ ] rest-sigv4 = [ { name = "boto3" }, + { name = "botocore" }, ] s3fs = [ { name = "s3fs" }, @@ -4739,6 +4740,7 @@ requires-dist = [ { name = "boto3", marker = "extra == 'dynamodb'", specifier = ">=1.24.59" }, { name = "boto3", marker = "extra == 'glue'", specifier = ">=1.24.59" }, { name = "boto3", marker = "extra == 'rest-sigv4'", specifier = ">=1.24.59" }, + { name = "botocore", marker = "extra == 'rest-sigv4'", specifier = "<2" }, { name = "cachetools", specifier = ">=5.5,<8.0" }, { name = "click", specifier = ">=7.1.1,<9.0.0" }, { name = "daft", marker = "extra == 'daft'", specifier = ">=0.7.10" },