From 6168c49dfe551956272728d3f6e08ffe3357a56d Mon Sep 17 00:00:00 2001 From: ecantn Date: Mon, 1 Jun 2026 11:51:44 +0200 Subject: [PATCH 1/3] fix(agent_tool): wrap input_schema JSON in ReAct prompt; propagate tool_choice to LiteLLM When AgentTool uses input_schema, the inner agent receives a raw JSON blob that causes Claude models to skip the tool-calling loop (ReAct). Fix by wrapping the payload in a natural-language instruction. Also propagate tool_config.function_calling_config.mode to LiteLLM's tool_choice parameter so callers can enforce tool_choice='required'. Addresses #773. Fixes: #5926 --- src/google/adk/models/lite_llm.py | 26 ++++++++++++++++++++++---- src/google/adk/tools/agent_tool.py | 8 +++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 4d4f93c88f..a8cc267b98 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -1883,6 +1883,7 @@ async def _get_completion_inputs( Optional[List[Dict]], Optional[Dict[str, Any]], Optional[Dict], + Optional[str], ]: """Converts an LlmRequest to litellm inputs and extracts generation params. @@ -1891,8 +1892,8 @@ async def _get_completion_inputs( model: The model string to use for determining provider-specific behavior. Returns: - The litellm inputs (message list, tool dictionary, response format and - generation params). + The litellm inputs (message list, tool dictionary, response format, + generation params, and tool_choice). """ _ensure_litellm_imported() @@ -1967,7 +1968,21 @@ async def _get_completion_inputs( if not generation_params: generation_params = None - return messages, tools, response_format, generation_params + # 5. 
Extract tool_choice from tool_config + tool_choice: Optional[str] = None + if ( + llm_request.config + and llm_request.config.tool_config + and llm_request.config.tool_config.function_calling_config + ): + mode = llm_request.config.tool_config.function_calling_config.mode + if mode == types.FunctionCallingConfigMode.ANY: + tool_choice = "required" + elif mode == types.FunctionCallingConfigMode.NONE: + tool_choice = "none" + # AUTO → None (provider default) + + return messages, tools, response_format, generation_params, tool_choice def _build_function_declaration_log( @@ -2228,7 +2243,7 @@ async def generate_content_async( logger.debug(_build_request_log(llm_request)) effective_model = llm_request.model or self.model - messages, tools, response_format, generation_params = ( + messages, tools, response_format, generation_params, tool_choice = ( await _get_completion_inputs(llm_request, effective_model) ) normalized_messages = _normalize_ollama_chat_messages( @@ -2260,6 +2275,9 @@ async def generate_content_async( if generation_params: completion_args.update(generation_params) + if tool_choice is not None: + completion_args["tool_choice"] = tool_choice + if stream: text = "" reasoning_parts: List[types.Part] = [] diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 1768861dba..8efd6baf01 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -217,11 +217,17 @@ async def run_async( input_schema = _get_input_schema(self.agent) if input_schema: input_value = input_schema.model_validate(args) + json_payload = input_value.model_dump_json(exclude_none=True) content = types.Content( role='user', parts=[ types.Part.from_text( - text=input_value.model_dump_json(exclude_none=True) + text=( + 'Process the following structured request. 
Use your' + ' available tools as needed to gather information or' + ' perform actions before producing the final' + ' response.\n\nRequest:\n' + json_payload + ) ) ], ) From 44ba0dea620adbb50c96d0a55952c75ce2b51611 Mon Sep 17 00:00:00 2001 From: ecantn Date: Mon, 1 Jun 2026 12:04:33 +0200 Subject: [PATCH 2/3] test: add unit tests for AgentTool input_schema wrapping and LiteLLM tool_choice propagation - Update existing _get_completion_inputs call sites to handle new 5-tuple return value (adds tool_choice as 5th element) - Add 3 tests for AgentTool.run_async: verifies message is passed verbatim without input_schema, and is wrapped in a natural-language instruction with input_schema (PR #5924 fix) - Add 8 tests for LiteLLM tool_choice propagation: covers _get_completion_inputs returning correct tool_choice for ANY/NONE/AUTO modes and None when no tool_config; covers generate_content_async correctly including/omitting tool_choice in completion_args --- tests/unittests/models/test_litellm.py | 239 +++++++++++++++++++++-- tests/unittests/tools/test_agent_tool.py | 131 +++++++++++++ 2 files changed, 358 insertions(+), 12 deletions(-) diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 216866602f..094df5111f 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -258,7 +258,7 @@ async def test_get_completion_inputs_formats_pydantic_schema_for_litellm(): config=types.GenerateContentConfig(response_schema=_StructuredOutput) ) - _, _, response_format, _ = await _get_completion_inputs( + _, _, response_format, _, _ = await _get_completion_inputs( llm_request, model="gemini/gemini-2.5-flash" ) @@ -550,7 +550,7 @@ async def test_get_completion_inputs_uses_openai_format_for_openai_model(): config=types.GenerateContentConfig(response_schema=_StructuredOutput), ) - _, _, response_format, _ = await _get_completion_inputs( + _, _, response_format, _, _ = await _get_completion_inputs( 
llm_request, model="gpt-4o-mini" ) @@ -570,7 +570,7 @@ async def test_get_completion_inputs_uses_gemini_format_for_gemini_model(): config=types.GenerateContentConfig(response_schema=_StructuredOutput), ) - _, _, response_format, _ = await _get_completion_inputs( + _, _, response_format, _, _ = await _get_completion_inputs( llm_request, model="gemini/gemini-2.5-flash" ) @@ -590,7 +590,7 @@ async def test_get_completion_inputs_uses_passed_model_for_response_format(): ) # Pass OpenAI model explicitly - should use json_schema format - _, _, response_format, _ = await _get_completion_inputs( + _, _, response_format, _, _ = await _get_completion_inputs( llm_request, model="gpt-4o-mini" ) @@ -615,7 +615,7 @@ async def test_get_completion_inputs_uses_passed_model_for_gemini_format(): ) # Pass Gemini model explicitly - should use response_schema format - _, _, response_format, _ = await _get_completion_inputs( + _, _, response_format, _, _ = await _get_completion_inputs( llm_request, model="gemini/gemini-2.5-flash" ) @@ -645,7 +645,7 @@ async def test_get_completion_inputs_inserts_missing_tool_results(): llm_request = LlmRequest( contents=[user_content, assistant_content, followup_user] ) - messages, _, _, _ = await _get_completion_inputs( + messages, _, _, _, _ = await _get_completion_inputs( llm_request, model="openai/gpt-4o" ) @@ -4195,7 +4195,7 @@ async def test_get_completion_inputs_generation_params(): ), ) - _, _, _, generation_params = await _get_completion_inputs( + _, _, _, generation_params, _ = await _get_completion_inputs( req, model="gpt-4o-mini" ) assert generation_params["temperature"] == 0.33 @@ -4220,7 +4220,7 @@ async def test_get_completion_inputs_empty_generation_params(): config=types.GenerateContentConfig(), ) - _, _, _, generation_params = await _get_completion_inputs( + _, _, _, generation_params, _ = await _get_completion_inputs( req, model="gpt-4o-mini" ) assert generation_params is None @@ -4238,7 +4238,7 @@ async def 
test_get_completion_inputs_minimal_config(): ), ) - _, _, _, generation_params = await _get_completion_inputs( + _, _, _, generation_params, _ = await _get_completion_inputs( req, model="gpt-4o-mini" ) assert generation_params is None @@ -4257,7 +4257,7 @@ async def test_get_completion_inputs_partial_generation_params(): ), ) - _, _, _, generation_params = await _get_completion_inputs( + _, _, _, generation_params, _ = await _get_completion_inputs( req, model="gpt-4o-mini" ) assert generation_params is not None @@ -4620,7 +4620,7 @@ async def test_get_completion_inputs_openai_file_upload(mocker): config=types.GenerateContentConfig(tools=[]), ) - messages, tools, response_format, generation_params = ( + messages, tools, response_format, generation_params, _ = ( await _get_completion_inputs(llm_request, model="openai/gpt-4o") ) @@ -4659,7 +4659,7 @@ async def test_get_completion_inputs_non_openai_no_file_upload(mocker): config=types.GenerateContentConfig(tools=[]), ) - messages, tools, response_format, generation_params = ( + messages, tools, response_format, generation_params, _ = ( await _get_completion_inputs(llm_request, model="anthropic/claude-3-opus") ) @@ -4987,3 +4987,218 @@ async def test_generate_content_async_skips_request_log_build_above_debug( assert mock_build.called is should_call finally: litellm_logger.setLevel(original_level) + + +# --------------------------------------------------------------------------- +# Tests for tool_choice propagation (PR #5924) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_completion_inputs_tool_choice_none_without_tool_config(): + """tool_choice must be None when no tool_config is present.""" + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ) + ], + ) + + _, _, _, _, tool_choice = await _get_completion_inputs( + llm_request, model="openai/gpt-4o" + ) + + assert tool_choice is 
None + + +@pytest.mark.asyncio +async def test_get_completion_inputs_tool_choice_required_for_any_mode(): + """tool_choice must be 'required' when mode=ANY.""" + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ) + ], + config=types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.ANY + ) + ) + ), + ) + + _, _, _, _, tool_choice = await _get_completion_inputs( + llm_request, model="openai/gpt-4o" + ) + + assert tool_choice == "required" + + +@pytest.mark.asyncio +async def test_get_completion_inputs_tool_choice_none_for_none_mode(): + """tool_choice must be 'none' when mode=NONE.""" + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ) + ], + config=types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.NONE + ) + ) + ), + ) + + _, _, _, _, tool_choice = await _get_completion_inputs( + llm_request, model="openai/gpt-4o" + ) + + assert tool_choice == "none" + + +@pytest.mark.asyncio +async def test_get_completion_inputs_tool_choice_none_for_auto_mode(): + """tool_choice must be None (provider default) when mode=AUTO.""" + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ) + ], + config=types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.AUTO + ) + ) + ), + ) + + _, _, _, _, tool_choice = await _get_completion_inputs( + llm_request, model="openai/gpt-4o" + ) + + assert tool_choice is None + + +@pytest.mark.asyncio +async def test_generate_content_async_propagates_tool_choice_required( + mock_acompletion, mock_completion +): + """generate_content_async must pass 
tool_choice='required' to acompletion.""" + llm_client = MockLLMClient(mock_acompletion, mock_completion) + lite_llm_instance = LiteLlm(model="openai/gpt-4o", llm_client=llm_client) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Call a tool")] + ) + ], + config=types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.ANY + ) + ) + ), + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert kwargs.get("tool_choice") == "required" + + +@pytest.mark.asyncio +async def test_generate_content_async_propagates_tool_choice_none_mode( + mock_acompletion, mock_completion +): + """generate_content_async must pass tool_choice='none' to acompletion for NONE mode.""" + llm_client = MockLLMClient(mock_acompletion, mock_completion) + lite_llm_instance = LiteLlm(model="openai/gpt-4o", llm_client=llm_client) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="No tools please")] + ) + ], + config=types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.NONE + ) + ) + ), + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert kwargs.get("tool_choice") == "none" + + +@pytest.mark.asyncio +async def test_generate_content_async_omits_tool_choice_for_auto_mode( + mock_acompletion, mock_completion +): + """generate_content_async must NOT include tool_choice in completion_args for AUTO.""" + llm_client = MockLLMClient(mock_acompletion, mock_completion) + lite_llm_instance = LiteLlm(model="openai/gpt-4o", llm_client=llm_client) + + llm_request = 
LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Hi")] + ) + ], + config=types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.AUTO + ) + ) + ), + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert "tool_choice" not in kwargs + + +@pytest.mark.asyncio +async def test_generate_content_async_omits_tool_choice_without_tool_config( + mock_acompletion, mock_completion +): + """generate_content_async must NOT include tool_choice when no tool_config.""" + llm_client = MockLLMClient(mock_acompletion, mock_completion) + lite_llm_instance = LiteLlm(model="openai/gpt-4o", llm_client=llm_client) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Hi")] + ) + ], + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert "tool_choice" not in kwargs diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index 4c664ae822..6624bdeb91 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -1431,3 +1431,134 @@ def test_empty_sequential_agent_falls_back_to_request(self): } else: assert declaration.parameters.properties['request'].type == 'STRING' + + +# --------------------------------------------------------------------------- +# Tests for input_schema message wrapping (PR #5924) +# --------------------------------------------------------------------------- + + +async def _run_agent_tool_and_capture_content( + args: dict, + input_schema=None, +) -> types.Content: + """Drives AgentTool and captures the Content passed to the inner agent. 
+ + This uses a stub Runner (same pattern as test_agent_tool_inherits_parent_app_name) + to intercept the new_message without executing the actual agent pipeline. + """ + from google.adk.agents.llm_agent import LlmAgent + from google.adk.plugins.plugin_manager import PluginManager + from unittest.mock import patch + import google.adk.runners as _runners_module + + if input_schema is not None: + inner = LlmAgent( + name='inner_agent', + description='captures input', + model=testing_utils.MockModel.create(responses=['done']), + input_schema=input_schema, + ) + else: + inner = Agent(name='inner_agent', model='test-model') + + new_message_holder: list = [] + + async def _empty_async_generator(): + if False: + yield None + + class _StubRunner: + + def __init__(self, *, app_name, agent, artifact_service, + session_service, memory_service, credential_service, plugins): + del artifact_service, memory_service, credential_service + self.agent = agent + self.session_service = session_service + self.plugin_manager = PluginManager(plugins=plugins) + self.app_name = app_name + + def run_async(self, *, user_id, session_id, invocation_id=None, + new_message=None, state_delta=None, run_config=None): + new_message_holder.append(new_message) + return _empty_async_generator() + + async def close(self): + pass + + with patch.object(_runners_module, 'Runner', _StubRunner): + agent_tool = AgentTool(agent=inner) + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + invocation_context = InvocationContext( + invocation_id='invocation_id', + agent=inner, + session=session, + session_service=session_service, + ) + tool_context = ToolContext(invocation_context=invocation_context) + await agent_tool.run_async(args=args, tool_context=tool_context) + + return new_message_holder[0] if new_message_holder else None + + +@mark.asyncio +async def test_run_async_no_input_schema_passes_request_unchanged(): + 
"""Without input_schema, the message is args['request'] verbatim.""" + content = await _run_agent_tool_and_capture_content( + args={'request': 'hello world'}, + input_schema=None, + ) + + assert content is not None + assert len(content.parts) == 1 + assert content.parts[0].text == 'hello world' + + +@mark.asyncio +async def test_run_async_with_input_schema_wraps_in_natural_language(): + """With input_schema, the message starts with a natural-language instruction.""" + + class MyInput(BaseModel): + custom_input: str + + content = await _run_agent_tool_and_capture_content( + args={'custom_input': 'test_value'}, + input_schema=MyInput, + ) + + assert content is not None + assert len(content.parts) == 1 + text = content.parts[0].text + # Must start with the natural-language prompt, not with raw JSON + assert text.startswith('Process the following structured request') + # Must contain the JSON payload after "Request:\n" + assert 'Request:\n' in text + json_part = text.split('Request:\n', 1)[1] + import json as _json + payload = _json.loads(json_part) + assert payload['custom_input'] == 'test_value' + # The full text must NOT be just the raw JSON blob + assert text != json_part + + +@mark.asyncio +async def test_run_async_with_input_schema_text_not_raw_json(): + """The content text must not be a bare JSON string when input_schema is set.""" + + class MyInput(BaseModel): + value: int + + content = await _run_agent_tool_and_capture_content( + args={'value': 42}, + input_schema=MyInput, + ) + + assert content is not None + text = content.parts[0].text + # A bare JSON blob would start with '{'; the wrapped version must not + assert not text.startswith('{'), ( + 'Content text is raw JSON instead of a natural-language instruction' + ) From 67d087a031639d92395b8af48cf7420eae373536 Mon Sep 17 00:00:00 2001 From: ecantn Date: Tue, 2 Jun 2026 11:35:00 +0200 Subject: [PATCH 3/3] fix(agent_tool): only apply ReAct wrapper when output_schema is not set - Fix: only apply the ReAct 
wrapper in agent_tool.py when output_schema is not set on the inner agent, so that single-shot structured-output mode is not broken - Add regression test test_run_async_with_input_and_output_schema_passes_raw_json documenting that raw JSON is passed when both input_schema and output_schema are set - Apply pre-commit formatting fixes (isort + pyink) --- src/google/adk/tools/agent_tool.py | 36 ++++++---- tests/unittests/models/test_litellm.py | 24 ++----- tests/unittests/tools/test_agent_tool.py | 89 ++++++++++++++++++++++-- 3 files changed, 111 insertions(+), 38 deletions(-) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 8efd6baf01..00fcdb8e9f 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -218,19 +218,29 @@ async def run_async( if input_schema: input_value = input_schema.model_validate(args) json_payload = input_value.model_dump_json(exclude_none=True) - content = types.Content( - role='user', - parts=[ - types.Part.from_text( - text=( - 'Process the following structured request. Use your' - ' available tools as needed to gather information or' - ' perform actions before producing the final' - ' response.\n\nRequest:\n' + json_payload - ) - ) - ], - ) + output_schema = _get_output_schema(self.agent) + if output_schema: + # Single-shot structured output mode: pass raw JSON, no ReAct wrapper. + content = types.Content( + role='user', + parts=[types.Part.from_text(text=json_payload)], + ) + else: + # Tool-calling mode: wrap with ReAct-style prompt. + content = types.Content( + role='user', + parts=[ + types.Part.from_text( + text=( + 'Process the following structured request. 
Use your' + ' available tools as needed to gather information or' + ' perform actions before producing the final' + ' response.\n\nRequest:\n' + + json_payload + ) + ) + ], + ) else: content = types.Content( role='user', diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 094df5111f..f3c3282bdb 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -4999,9 +4999,7 @@ async def test_get_completion_inputs_tool_choice_none_without_tool_config(): """tool_choice must be None when no tool_config is present.""" llm_request = LlmRequest( contents=[ - types.Content( - role="user", parts=[types.Part.from_text(text="Hello")] - ) + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) ], ) @@ -5017,9 +5015,7 @@ async def test_get_completion_inputs_tool_choice_required_for_any_mode(): """tool_choice must be 'required' when mode=ANY.""" llm_request = LlmRequest( contents=[ - types.Content( - role="user", parts=[types.Part.from_text(text="Hello")] - ) + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) ], config=types.GenerateContentConfig( tool_config=types.ToolConfig( @@ -5042,9 +5038,7 @@ async def test_get_completion_inputs_tool_choice_none_for_none_mode(): """tool_choice must be 'none' when mode=NONE.""" llm_request = LlmRequest( contents=[ - types.Content( - role="user", parts=[types.Part.from_text(text="Hello")] - ) + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) ], config=types.GenerateContentConfig( tool_config=types.ToolConfig( @@ -5067,9 +5061,7 @@ async def test_get_completion_inputs_tool_choice_none_for_auto_mode(): """tool_choice must be None (provider default) when mode=AUTO.""" llm_request = LlmRequest( contents=[ - types.Content( - role="user", parts=[types.Part.from_text(text="Hello")] - ) + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) ], config=types.GenerateContentConfig( 
tool_config=types.ToolConfig( @@ -5159,9 +5151,7 @@ async def test_generate_content_async_omits_tool_choice_for_auto_mode( llm_request = LlmRequest( contents=[ - types.Content( - role="user", parts=[types.Part.from_text(text="Hi")] - ) + types.Content(role="user", parts=[types.Part.from_text(text="Hi")]) ], config=types.GenerateContentConfig( tool_config=types.ToolConfig( @@ -5190,9 +5180,7 @@ async def test_generate_content_async_omits_tool_choice_without_tool_config( llm_request = LlmRequest( contents=[ - types.Content( - role="user", parts=[types.Part.from_text(text="Hi")] - ) + types.Content(role="user", parts=[types.Part.from_text(text="Hi")]) ], ) diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index 6624bdeb91..2ed4525904 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -1441,15 +1441,17 @@ def test_empty_sequential_agent_falls_back_to_request(self): async def _run_agent_tool_and_capture_content( args: dict, input_schema=None, + output_schema=None, ) -> types.Content: """Drives AgentTool and captures the Content passed to the inner agent. This uses a stub Runner (same pattern as test_agent_tool_inherits_parent_app_name) to intercept the new_message without executing the actual agent pipeline. 
""" + from unittest.mock import patch + from google.adk.agents.llm_agent import LlmAgent from google.adk.plugins.plugin_manager import PluginManager - from unittest.mock import patch import google.adk.runners as _runners_module if input_schema is not None: @@ -1458,6 +1460,7 @@ async def _run_agent_tool_and_capture_content( description='captures input', model=testing_utils.MockModel.create(responses=['done']), input_schema=input_schema, + output_schema=output_schema, ) else: inner = Agent(name='inner_agent', model='test-model') @@ -1470,16 +1473,33 @@ async def _empty_async_generator(): class _StubRunner: - def __init__(self, *, app_name, agent, artifact_service, - session_service, memory_service, credential_service, plugins): + def __init__( + self, + *, + app_name, + agent, + artifact_service, + session_service, + memory_service, + credential_service, + plugins, + ): del artifact_service, memory_service, credential_service self.agent = agent self.session_service = session_service self.plugin_manager = PluginManager(plugins=plugins) self.app_name = app_name - def run_async(self, *, user_id, session_id, invocation_id=None, - new_message=None, state_delta=None, run_config=None): + def run_async( + self, + *, + user_id, + session_id, + invocation_id=None, + new_message=None, + state_delta=None, + run_config=None, + ): new_message_holder.append(new_message) return _empty_async_generator() @@ -1538,6 +1558,7 @@ class MyInput(BaseModel): assert 'Request:\n' in text json_part = text.split('Request:\n', 1)[1] import json as _json + payload = _json.loads(json_part) assert payload['custom_input'] == 'test_value' # The full text must NOT be just the raw JSON blob @@ -1559,6 +1580,60 @@ class MyInput(BaseModel): assert content is not None text = content.parts[0].text # A bare JSON blob would start with '{'; the wrapped version must not - assert not text.startswith('{'), ( - 'Content text is raw JSON instead of a natural-language instruction' + assert not text.startswith( + 
'{' + ), 'Content text is raw JSON instead of a natural-language instruction' + + +@mark.asyncio +async def test_run_async_with_input_and_output_schema_passes_raw_json(): + """With both input_schema AND output_schema, the raw JSON payload is passed + directly to the inner runner WITHOUT the ReAct wrapper prefix. + + The wrapper ('Process the following structured request...') is only added + when input_schema is set and output_schema is NOT set (tool-calling mode). + When output_schema is also present the agent operates in single-shot + structured-output mode, so the runner receives the bare JSON string that the + inner agent can parse deterministically — adding the prose prefix would + corrupt the structured input. + """ + import json as _json + + class MyInput(BaseModel): + query: str + limit: int + + class MyOutput(BaseModel): + result: str + + content = await _run_agent_tool_and_capture_content( + args={'query': 'hello', 'limit': 5}, + input_schema=MyInput, + output_schema=MyOutput, ) + + assert content is not None + assert len(content.parts) == 1 + text = content.parts[0].text + + # output_schema mode is single-shot; wrapper must not be applied + assert not text.startswith('Process'), ( + 'output_schema mode is single-shot; wrapper must not be applied,' + f' but text starts with: {text[:60]!r}' + ) + + # The payload must be valid JSON + try: + payload = _json.loads(text) + except _json.JSONDecodeError as exc: + raise AssertionError( + f'Content text is not valid JSON in output_schema mode: {text!r}' + ) from exc + + # The JSON must match the input args + assert ( + payload['query'] == 'hello' + ), f"Expected query='hello', got {payload.get('query')!r}" + assert ( + payload['limit'] == 5 + ), f"Expected limit=5, got {payload.get('limit')!r}"