diff --git a/src/google/adk/integrations/agent_registry/agent_registry.py b/src/google/adk/integrations/agent_registry/agent_registry.py index a486215151..fc696fcc2c 100644 --- a/src/google/adk/integrations/agent_registry/agent_registry.py +++ b/src/google/adk/integrations/agent_registry/agent_registry.py @@ -19,6 +19,7 @@ from collections.abc import Generator from enum import Enum import logging +import os import re from typing import Any from typing import Callable @@ -39,9 +40,11 @@ from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams from google.adk.tools.mcp_tool.mcp_toolset import McpToolset import google.auth -import google.auth.transport.requests +from google.auth.transport import mtls +from google.auth.transport import requests as requests_auth import httpx from mcp import StdioServerParameters +import requests from typing_extensions import override # pylint: disable=g-import-not-at-top @@ -61,6 +64,9 @@ logger = logging.getLogger("google_adk." + __name__) AGENT_REGISTRY_BASE_URL = "https://agentregistry.googleapis.com/v1alpha" +AGENT_REGISTRY_MTLS_BASE_URL = ( + "https://agentregistry.mtls.googleapis.com/v1alpha" +) _TRANSPORT_MAPPING = { "HTTP_JSON": A2ATransport.http_json, @@ -120,6 +126,14 @@ async def get_tools( return tools +class _MtlsEndpoint(Enum): + """The mTLS endpoint setting.""" + + AUTO = "auto" + ALWAYS = "always" + NEVER = "never" + + class _ProtocolType(str, Enum): """Supported agent protocol types.""" @@ -198,6 +212,22 @@ def __init__( raise RuntimeError( f"Failed to get default Google Cloud credentials: {e}" ) from e + # Instantiate and configure AuthorizedSession once during initialization + self._session = requests_auth.AuthorizedSession( + credentials=self._credentials + ) + + use_client_cert = _use_client_cert_effective() + client_cert_source = None + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + self._session.configure_mtls_channel(client_cert_source) + + self._base_url = _get_agent_registry_base_url(client_cert_source) def _get_auth_headers(self) -> Dict[str, str]: """Refreshes credentials and returns authorization headers.""" @@ -224,23 +254,29 @@ def _make_request( self, path: str, params: Dict[str, Any] | None = None ) -> Dict[str, Any]: """Helper function to make GET requests to the Agent Registry API.""" + if path.startswith("projects/"): - url = f"{AGENT_REGISTRY_BASE_URL}/{path}" + url = f"{self._base_url}/{path}" else: - url = f"{AGENT_REGISTRY_BASE_URL}/{self._base_path}/{path}" + url = f"{self._base_url}/{self._base_path}/{path}" + headers = {} + quota_project_id = ( + getattr(self._credentials, "quota_project_id", None) or self.project_id + ) + if quota_project_id: + headers["x-goog-user-project"] = quota_project_id try: - headers = self._get_auth_headers() - with httpx.Client() as client: - response = client.get(url, headers=headers, params=params) - response.raise_for_status() - return response.json() - except httpx.HTTPStatusError as e: + # Using AuthorizedSession for internal API calls to handle mTLS/Auth. + response = self._session.get(url, headers=headers, params=params) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: raise RuntimeError( f"API request failed with status {e.response.status_code}:" f" {e.response.text}" ) from e - except httpx.RequestError as e: + except requests.exceptions.RequestException as e: raise RuntimeError(f"API request failed (network error): {e}") from e except Exception as e: raise RuntimeError(f"API request failed: {e}") from e @@ -520,3 +556,33 @@ def get_remote_a2a_agent( description=description, httpx_client=httpx_client, ) + + +def _use_client_cert_effective() -> bool: + """Returns whether client certificate should be used for mTLS.""" + try: + return bool(mtls.should_use_client_cert()) + except (ImportError, AttributeError): + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + return use_client_cert_str == "true" + + +def _get_agent_registry_base_url(client_cert_source: Any | None = None) -> str: + """Returns the base URL based on mTLS configuration and cert availability.""" + use_mtls_endpoint_str = os.getenv( + "GOOGLE_API_USE_MTLS_ENDPOINT", _MtlsEndpoint.AUTO.value + ).lower() + + try: + use_mtls_endpoint = _MtlsEndpoint(use_mtls_endpoint_str) + except ValueError: + use_mtls_endpoint = _MtlsEndpoint.AUTO + + if (use_mtls_endpoint is _MtlsEndpoint.ALWAYS) or ( + use_mtls_endpoint is _MtlsEndpoint.AUTO and client_cert_source + ): + return AGENT_REGISTRY_MTLS_BASE_URL + + return AGENT_REGISTRY_BASE_URL diff --git a/tests/unittests/integrations/agent_registry/test_agent_registry.py b/tests/unittests/integrations/agent_registry/test_agent_registry.py index f4ba47cf25..f19e52ec04 100644 --- a/tests/unittests/integrations/agent_registry/test_agent_registry.py +++ b/tests/unittests/integrations/agent_registry/test_agent_registry.py @@ -13,6 +13,7 @@ # limitations under the License. +import os from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch @@ -26,28 +27,36 @@ from google.adk.integrations.agent_registry.agent_registry import _ProtocolType from google.adk.telemetry.tracing import GCP_MCP_SERVER_DESTINATION_ID from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +from google.auth.transport import requests as requests_auth import httpx from mcp import ClientSession from mcp.types import ListToolsResult from mcp.types import Tool import pytest +import requests class TestAgentRegistry: @pytest.fixture def registry(self): - with patch("google.auth.default", return_value=(MagicMock(), "project-id")): - return AgentRegistry(project_id="test-project", location="global") + mock_creds = MagicMock() + mock_creds.quota_project_id = None + with patch( + "google.auth.default", return_value=(mock_creds, "project-id") + ), patch( + "google.auth.transport.requests.AuthorizedSession" + ) as mock_session_class: + registry = AgentRegistry(project_id="test-project", location="global") + return registry @pytest.mark.asyncio - @patch("httpx.Client") @patch( "google.adk.tools.mcp_tool.mcp_session_manager.MCPSessionManager.create_session", new_callable=AsyncMock, ) async def test_get_mcp_toolset_adds_destination_id( - self, mock_create_session, mock_httpx, registry + self, mock_create_session, registry ): """Test that tools from get_mcp_toolset have the destination ID.""" # Arrange @@ -63,9 +72,7 @@ async def test_get_mcp_toolset_adds_destination_id( "protocolBinding": "JSONRPC", }], } - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_api_response - ) + registry._session.get.return_value = mock_api_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -109,13 +116,12 @@ async def test_get_mcp_toolset_adds_destination_id( ) @pytest.mark.asyncio - @patch("httpx.Client") @patch( "google.adk.tools.mcp_tool.mcp_session_manager.MCPSessionManager.create_session", new_callable=AsyncMock, ) async def test_get_mcp_toolset_handles_missing_destination_id( - self, mock_create_session, mock_httpx, registry + self, mock_create_session, registry ): """Test get_mcp_toolset when the destination ID is missing.""" # Arrange @@ -129,9 +135,7 @@ async def test_get_mcp_toolset_handles_missing_destination_id( "protocolBinding": "JSONRPC", }], } - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_api_response - ) + registry._session.get.return_value = mock_api_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -258,14 +262,11 @@ def test_get_connection_uri_returns_none_if_no_url_in_interfaces( assert version is None assert binding is None - @patch("httpx.Client") - def test_list_agents(self, mock_httpx, registry): + def test_list_agents(self, registry): mock_response = MagicMock() mock_response.json.return_value = {"agents": []} mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + registry._session.get.return_value = mock_response # Mock auth refresh registry._credentials.token = "token" @@ -274,14 +275,11 @@ def test_list_agents(self, mock_httpx, registry): agents = registry.list_agents() assert agents == {"agents": []} - @patch("httpx.Client") - def test_get_mcp_server(self, mock_httpx, registry): + def test_get_mcp_server(self, registry): mock_response = MagicMock() mock_response.json.return_value = {"name": "test-mcp"} mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + registry._session.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -289,14 +287,11 @@ def test_get_mcp_server(self, mock_httpx, registry): server = registry.get_mcp_server("test-mcp") assert server == {"name": "test-mcp"} - @patch("httpx.Client") - def test_list_endpoints(self, mock_httpx, registry): + def test_list_endpoints(self, registry): mock_response = MagicMock() mock_response.json.return_value = {"endpoints": []} mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + registry._session.get.return_value = mock_response # Mock auth refresh registry._credentials.token = "token" @@ -305,14 +300,11 @@ def test_list_endpoints(self, mock_httpx, registry): endpoints = registry.list_endpoints() assert endpoints == {"endpoints": []} - @patch("httpx.Client") - def test_get_endpoint(self, mock_httpx, registry): + def test_get_endpoint(self, registry): mock_response = MagicMock() mock_response.json.return_value = {"name": "test-endpoint"} mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + registry._session.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -329,36 +321,40 @@ def test_get_endpoint(self, mock_httpx, registry): ("https://mcp.googleapis.com/v1", True, True), ], ) - @patch("httpx.Client") def test_get_mcp_toolset_auth_headers( - self, mock_httpx, registry, url, expected_auth, use_custom_provider + self, + registry, + url, + expected_auth, + use_custom_provider, ): - mock_response = MagicMock() - mock_response.json.return_value = { + mock_api_response = MagicMock() + mock_api_response.json.return_value = { "displayName": "TestPrefix", "interfaces": [{ "url": url, "protocolBinding": "JSONRPC", }], } - mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_api_response.raise_for_status = MagicMock() + + mock_creds = MagicMock() + mock_creds.quota_project_id = None if use_custom_provider: custom_header_provider = lambda context: { "Authorization": "Bearer custom_token" } with patch( - "google.auth.default", return_value=(MagicMock(), "project-id") - ): + "google.auth.default", return_value=(mock_creds, "project-id") + ), patch("google.auth.transport.requests.AuthorizedSession"): registry = AgentRegistry( project_id="test-project", location="global", header_provider=custom_header_provider, ) + registry._session.get.return_value = mock_api_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -375,8 +371,7 @@ def test_get_mcp_toolset_auth_headers( else: assert "Authorization" not in headers - @patch("httpx.Client") - def test_get_mcp_toolset_with_auth(self, mock_httpx, registry): + def test_get_mcp_toolset_with_auth(self, registry): mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestPrefix", @@ -386,9 +381,7 @@ def test_get_mcp_toolset_with_auth(self, mock_httpx, registry): }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + registry._session.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -408,9 +401,8 @@ def test_get_mcp_toolset_with_auth(self, mock_httpx, registry): assert auth_config.auth_scheme == auth_scheme assert auth_config.raw_auth_credential == auth_credential - @patch("httpx.Client") def test_get_mcp_toolset_with_auth_blocks_gcp_headers( - self, mock_httpx, registry + self, registry ): mock_response = MagicMock() mock_response.json.return_value = { @@ -421,9 +413,7 @@ def test_get_mcp_toolset_with_auth_blocks_gcp_headers( }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + registry._session.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -442,8 +432,7 @@ def test_get_mcp_toolset_with_auth_blocks_gcp_headers( headers = toolset._header_provider(MagicMock()) assert "Authorization" not in headers - @patch("httpx.Client") - def test_get_remote_a2a_agent(self, mock_httpx, registry): + def test_get_remote_a2a_agent(self, registry): mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestAgent", @@ -460,9 +449,7 @@ def test_get_remote_a2a_agent(self, mock_httpx, registry): "skills": [{"id": "s1", "name": "Skill 1", "description": "Desc 1"}], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + registry._session.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -478,8 +465,7 @@ def test_get_remote_a2a_agent(self, mock_httpx, registry): assert agent._agent_card.preferred_transport == A2ATransport.http_json assert agent._agent_card.protocol_version == "0.4.0" - @patch("httpx.Client") - def test_get_remote_a2a_agent_defaults(self, mock_httpx, registry): + def test_get_remote_a2a_agent_defaults(self, registry): mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestAgent", @@ -493,9 +479,7 @@ def test_get_remote_a2a_agent_defaults(self, mock_httpx, registry): }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + registry._session.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -505,8 +489,7 @@ def test_get_remote_a2a_agent_defaults(self, mock_httpx, registry): assert agent._agent_card.preferred_transport == A2ATransport.http_json assert agent._agent_card.protocol_version == "0.3.0" - @patch("httpx.Client") - def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry): + def test_get_remote_a2a_agent_with_card(self, registry): mock_response = MagicMock() mock_response.json.return_value = { "name": "projects/p/locations/l/agents/a", @@ -530,9 +513,7 @@ def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry): }, } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + registry._session.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -547,8 +528,9 @@ def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry): assert len(agent._agent_card.skills) == 1 assert agent._agent_card.skills[0].name == "S1" - @patch("httpx.Client") - def test_get_remote_a2a_agent_with_httpx_client(self, mock_httpx, registry): + def test_get_remote_a2a_agent_with_httpx_client( + self, registry + ): mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestAgent", @@ -562,9 +544,7 @@ def test_get_remote_a2a_agent_with_httpx_client(self, mock_httpx, registry): }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + registry._session.get.return_value = mock_response custom_client = httpx.AsyncClient() agent = registry.get_remote_a2a_agent( @@ -572,9 +552,8 @@ def test_get_remote_a2a_agent_with_httpx_client(self, mock_httpx, registry): ) assert agent._httpx_client is custom_client - @patch("httpx.Client") def test_get_remote_a2a_agent_configures_transports( - self, mock_httpx, registry + self, registry ): mock_response = MagicMock() mock_response.json.return_value = { @@ -588,9 +567,7 @@ def test_get_remote_a2a_agent_configures_transports( }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + registry._session.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -616,15 +593,16 @@ def test_get_auth_headers_fallback_to_project_id(self, registry): assert headers["Authorization"] == "Bearer fake-token" assert headers["x-goog-user-project"] == "test-project" - @patch("httpx.Client") - def test_make_request_raises_http_status_error(self, mock_httpx, registry): + def test_make_request_raises_http_status_error( + self, registry + ): mock_response = MagicMock() mock_response.status_code = 404 mock_response.text = "Not Found" - error = httpx.HTTPStatusError( + error = requests.exceptions.HTTPError( "Error", request=MagicMock(), response=mock_response ) - mock_httpx.return_value.__enter__.return_value.get.side_effect = error + registry._session.get.side_effect = error registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -634,10 +612,13 @@ def test_make_request_raises_http_status_error(self, mock_httpx, registry): ): registry._make_request("test-path") - @patch("httpx.Client") - def test_make_request_raises_request_error(self, mock_httpx, registry): - error = httpx.RequestError("Connection failed", request=MagicMock()) - mock_httpx.return_value.__enter__.return_value.get.side_effect = error + def test_make_request_raises_request_error( + self, registry + ): + error = requests.exceptions.RequestException( + "Connection failed", request=MagicMock() + ) + registry._session.get.side_effect = error registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -647,11 +628,10 @@ def test_make_request_raises_request_error(self, mock_httpx, registry): ): registry._make_request("test-path") - @patch("httpx.Client") - def test_make_request_raises_generic_exception(self, mock_httpx, registry): - mock_httpx.return_value.__enter__.return_value.get.side_effect = Exception( - "Generic error" - ) + def test_make_request_raises_generic_exception( + self, registry + ): + registry._session.get.side_effect = Exception("Generic error") registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -741,3 +721,114 @@ def side_effect(*args, **kwargs): == "projects/123/locations/l/authProviders/ap-789" ) assert toolset._auth_scheme.continue_uri == "https://override.com/continue" + + +class TestAgentRegistryMtls: + + @pytest.fixture + def registry(self): + with patch( + "google.auth.default", return_value=(MagicMock(), "test-project") + ), patch("google.auth.transport.requests.AuthorizedSession"), patch( + "google.adk.integrations.agent_registry.agent_registry._use_client_cert_effective", + return_value=False, + ): + return AgentRegistry(project_id="test-project", location="global") + + @patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ) + def test_make_request_uses_authorized_session_no_mtls( + self, mock_has_cert, registry + ): + """Verifies that AuthorizedSession is used for standard requests.""" + mock_session = registry._session + mock_response = MagicMock() + mock_response.json.return_value = {"key": "value"} + mock_session.get.return_value = mock_response + + result = registry._make_request("test-path") + + # Assert session usage + mock_session.get.assert_called_once() + assert mock_session.configure_mtls_channel.call_count == 0 + assert result == {"key": "value"} + + @patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ) + @patch("google.auth.transport.mtls.default_client_cert_source") + @patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) + def test_make_request_configures_mtls(self, mock_cert_source, registry): + """Verifies that mTLS is configured when supported and enabled.""" + mock_cert_source.return_value = lambda: (b"cert", b"key") + + with patch( + "google.auth.default", return_value=(MagicMock(), "test-project") + ), patch( + "google.adk.integrations.agent_registry.agent_registry._use_client_cert_effective", + return_value=True, + ), patch( + "google.auth.transport.requests.AuthorizedSession" + ) as mock_session_class: + # Instantiate inside the test after enabling mTLS patches + registry = AgentRegistry(project_id="test-project", location="global") + mock_session = registry._session + + # Mock successful response + mock_response = MagicMock() + mock_response.json.return_value = {"key": "value"} + mock_session.get.return_value = mock_response + + registry._make_request("test-path") + + # Verify mTLS configuration and endpoint + mock_session.configure_mtls_channel.assert_called_once() + args, kwargs = mock_session.get.call_args + assert "agentregistry.mtls.googleapis.com" in args[0] + + @pytest.mark.parametrize( + "env_val, has_cert, expected", + [ + ("true", True, True), + ("true", False, True), + ("false", True, False), + ("false", False, False), + ], + ) + def test_use_client_cert_effective( + self, env_val, has_cert, expected, registry + ): + """Tests the logic for enabling mTLS based on env vars and cert availability.""" + with patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": env_val}): + with patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=has_cert, + ): + from google.adk.integrations.agent_registry.agent_registry import _use_client_cert_effective + + assert _use_client_cert_effective() == expected + + def test_get_agent_registry_base_url(self, registry): + """Verifies correct base URL selection for mTLS vs non-mTLS.""" + from google.adk.integrations.agent_registry.agent_registry import _get_agent_registry_base_url + + # Non-mTLS + assert "agentregistry.googleapis.com" in _get_agent_registry_base_url(None) + + # mTLS + assert "agentregistry.mtls.googleapis.com" in _get_agent_registry_base_url( + lambda: True + ) + + def test_make_request_error_handling(self, registry): + """Ensures exceptions from AuthorizedSession are handled gracefully.""" + mock_session = registry._session + mock_session.get.side_effect = Exception("Connection error") + + with pytest.raises( + RuntimeError, match="API request failed: Connection error" + ): + registry._make_request("test-path")