diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 4d4f93c88f..a8cc267b98 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -1883,6 +1883,7 @@ async def _get_completion_inputs( Optional[List[Dict]], Optional[Dict[str, Any]], Optional[Dict], + Optional[str], ]: """Converts an LlmRequest to litellm inputs and extracts generation params. @@ -1891,8 +1892,8 @@ async def _get_completion_inputs( model: The model string to use for determining provider-specific behavior. Returns: - The litellm inputs (message list, tool dictionary, response format and - generation params). + The litellm inputs (message list, tool dictionary, response format, + generation params, and tool_choice). """ _ensure_litellm_imported() @@ -1967,7 +1968,21 @@ async def _get_completion_inputs( if not generation_params: generation_params = None - return messages, tools, response_format, generation_params + # 5. Extract tool_choice from tool_config + tool_choice: Optional[str] = None + if ( + llm_request.config + and llm_request.config.tool_config + and llm_request.config.tool_config.function_calling_config + ): + mode = llm_request.config.tool_config.function_calling_config.mode + if mode == types.FunctionCallingConfigMode.ANY: + tool_choice = "required" + elif mode == types.FunctionCallingConfigMode.NONE: + tool_choice = "none" + # AUTO → None (provider default) + + return messages, tools, response_format, generation_params, tool_choice def _build_function_declaration_log( @@ -2228,7 +2243,7 @@ async def generate_content_async( logger.debug(_build_request_log(llm_request)) effective_model = llm_request.model or self.model - messages, tools, response_format, generation_params = ( + messages, tools, response_format, generation_params, tool_choice = ( await _get_completion_inputs(llm_request, effective_model) ) normalized_messages = _normalize_ollama_chat_messages( @@ -2260,6 +2275,9 @@ async def generate_content_async( if generation_params: completion_args.update(generation_params) + if tool_choice is not None: + completion_args["tool_choice"] = tool_choice + if stream: text = "" reasoning_parts: List[types.Part] = [] diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 1768861dba..00fcdb8e9f 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -217,14 +217,30 @@ async def run_async( input_schema = _get_input_schema(self.agent) if input_schema: input_value = input_schema.model_validate(args) - content = types.Content( - role='user', - parts=[ - types.Part.from_text( - text=input_value.model_dump_json(exclude_none=True) - ) - ], - ) + json_payload = input_value.model_dump_json(exclude_none=True) + output_schema = _get_output_schema(self.agent) + if output_schema: + # Single-shot structured output mode: pass raw JSON, no ReAct wrapper. + content = types.Content( + role='user', + parts=[types.Part.from_text(text=json_payload)], + ) + else: + # Tool-calling mode: wrap with ReAct-style prompt. + content = types.Content( + role='user', + parts=[ + types.Part.from_text( + text=( + 'Process the following structured request. Use your' + ' available tools as needed to gather information or' + ' perform actions before producing the final' + ' response.\n\nRequest:\n' + + json_payload + ) + ) + ], + ) else: content = types.Content( role='user', diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 216866602f..f3c3282bdb 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -258,7 +258,7 @@ async def test_get_completion_inputs_formats_pydantic_schema_for_litellm(): config=types.GenerateContentConfig(response_schema=_StructuredOutput) ) - _, _, response_format, _ = await _get_completion_inputs( + _, _, response_format, _, _ = await _get_completion_inputs( llm_request, model="gemini/gemini-2.5-flash" ) @@ -550,7 +550,7 @@ async def test_get_completion_inputs_uses_openai_format_for_openai_model(): config=types.GenerateContentConfig(response_schema=_StructuredOutput), ) - _, _, response_format, _ = await _get_completion_inputs( + _, _, response_format, _, _ = await _get_completion_inputs( llm_request, model="gpt-4o-mini" ) @@ -570,7 +570,7 @@ async def test_get_completion_inputs_uses_gemini_format_for_gemini_model(): config=types.GenerateContentConfig(response_schema=_StructuredOutput), ) - _, _, response_format, _ = await _get_completion_inputs( + _, _, response_format, _, _ = await _get_completion_inputs( llm_request, model="gemini/gemini-2.5-flash" ) @@ -590,7 +590,7 @@ async def test_get_completion_inputs_uses_passed_model_for_response_format(): ) # Pass OpenAI model explicitly - should use json_schema format - _, _, response_format, _ = await _get_completion_inputs( + _, _, response_format, _, _ = await _get_completion_inputs( llm_request, model="gpt-4o-mini" ) @@ -615,7 +615,7 @@ async def test_get_completion_inputs_uses_passed_model_for_gemini_format(): ) # Pass Gemini model explicitly - should use response_schema format - _, _, response_format, _ = await _get_completion_inputs( + _, _, response_format, _, _ = await _get_completion_inputs( llm_request, model="gemini/gemini-2.5-flash" ) @@ -645,7 +645,7 @@ async def test_get_completion_inputs_inserts_missing_tool_results(): llm_request = LlmRequest( contents=[user_content, assistant_content, followup_user] ) - messages, _, _, _ = await _get_completion_inputs( + messages, _, _, _, _ = await _get_completion_inputs( llm_request, model="openai/gpt-4o" ) @@ -4195,7 +4195,7 @@ async def test_get_completion_inputs_generation_params(): ), ) - _, _, _, generation_params = await _get_completion_inputs( + _, _, _, generation_params, _ = await _get_completion_inputs( req, model="gpt-4o-mini" ) assert generation_params["temperature"] == 0.33 @@ -4220,7 +4220,7 @@ async def test_get_completion_inputs_empty_generation_params(): config=types.GenerateContentConfig(), ) - _, _, _, generation_params = await _get_completion_inputs( + _, _, _, generation_params, _ = await _get_completion_inputs( req, model="gpt-4o-mini" ) assert generation_params is None @@ -4238,7 +4238,7 @@ async def test_get_completion_inputs_minimal_config(): ), ) - _, _, _, generation_params = await _get_completion_inputs( + _, _, _, generation_params, _ = await _get_completion_inputs( req, model="gpt-4o-mini" ) assert generation_params is None @@ -4257,7 +4257,7 @@ async def test_get_completion_inputs_partial_generation_params(): ), ) - _, _, _, generation_params = await _get_completion_inputs( + _, _, _, generation_params, _ = await _get_completion_inputs( req, model="gpt-4o-mini" ) assert generation_params is not None @@ -4620,7 +4620,7 @@ async def test_get_completion_inputs_openai_file_upload(mocker): config=types.GenerateContentConfig(tools=[]), ) - messages, tools, response_format, generation_params = ( + messages, tools, response_format, generation_params, _ = ( await _get_completion_inputs(llm_request, model="openai/gpt-4o") ) @@ -4659,7 +4659,7 @@ async def test_get_completion_inputs_non_openai_no_file_upload(mocker): config=types.GenerateContentConfig(tools=[]), ) - messages, tools, response_format, generation_params = ( + messages, tools, response_format, generation_params, _ = ( await _get_completion_inputs(llm_request, model="anthropic/claude-3-opus") ) @@ -4987,3 +4987,206 @@ async def test_generate_content_async_skips_request_log_build_above_debug( assert mock_build.called is should_call finally: litellm_logger.setLevel(original_level) + + +# --------------------------------------------------------------------------- +# Tests for tool_choice propagation (PR #5924) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_completion_inputs_tool_choice_none_without_tool_config(): + """tool_choice must be None when no tool_config is present.""" + llm_request = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) + ], + ) + + _, _, _, _, tool_choice = await _get_completion_inputs( + llm_request, model="openai/gpt-4o" + ) + + assert tool_choice is None + + +@pytest.mark.asyncio +async def test_get_completion_inputs_tool_choice_required_for_any_mode(): + """tool_choice must be 'required' when mode=ANY.""" + llm_request = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) + ], + config=types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.ANY + ) + ) + ), + ) + + _, _, _, _, tool_choice = await _get_completion_inputs( + llm_request, model="openai/gpt-4o" + ) + + assert tool_choice == "required" + + +@pytest.mark.asyncio +async def test_get_completion_inputs_tool_choice_none_for_none_mode(): + """tool_choice must be 'none' when mode=NONE.""" + llm_request = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) + ], + config=types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.NONE + ) + ) + ), + ) + + _, _, _, _, tool_choice = await _get_completion_inputs( + llm_request, model="openai/gpt-4o" + ) + + assert tool_choice == "none" + + +@pytest.mark.asyncio +async def test_get_completion_inputs_tool_choice_none_for_auto_mode(): + """tool_choice must be None (provider default) when mode=AUTO.""" + llm_request = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) + ], + config=types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.AUTO + ) + ) + ), + ) + + _, _, _, _, tool_choice = await _get_completion_inputs( + llm_request, model="openai/gpt-4o" + ) + + assert tool_choice is None + + +@pytest.mark.asyncio +async def test_generate_content_async_propagates_tool_choice_required( + mock_acompletion, mock_completion +): + """generate_content_async must pass tool_choice='required' to acompletion.""" + llm_client = MockLLMClient(mock_acompletion, mock_completion) + lite_llm_instance = LiteLlm(model="openai/gpt-4o", llm_client=llm_client) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Call a tool")] + ) + ], + config=types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.ANY + ) + ) + ), + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert kwargs.get("tool_choice") == "required" + + +@pytest.mark.asyncio +async def test_generate_content_async_propagates_tool_choice_none_mode( + mock_acompletion, mock_completion +): + """generate_content_async must pass tool_choice='none' to acompletion for NONE mode.""" + llm_client = MockLLMClient(mock_acompletion, mock_completion) + lite_llm_instance = LiteLlm(model="openai/gpt-4o", llm_client=llm_client) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="No tools please")] + ) + ], + config=types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.NONE + ) + ) + ), + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert kwargs.get("tool_choice") == "none" + + +@pytest.mark.asyncio +async def test_generate_content_async_omits_tool_choice_for_auto_mode( + mock_acompletion, mock_completion +): + """generate_content_async must NOT include tool_choice in completion_args for AUTO.""" + llm_client = MockLLMClient(mock_acompletion, mock_completion) + lite_llm_instance = LiteLlm(model="openai/gpt-4o", llm_client=llm_client) + + llm_request = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hi")]) + ], + config=types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.AUTO + ) + ) + ), + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert "tool_choice" not in kwargs + + +@pytest.mark.asyncio +async def test_generate_content_async_omits_tool_choice_without_tool_config( + mock_acompletion, mock_completion +): + """generate_content_async must NOT include tool_choice when no tool_config.""" + llm_client = MockLLMClient(mock_acompletion, mock_completion) + lite_llm_instance = LiteLlm(model="openai/gpt-4o", llm_client=llm_client) + + llm_request = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hi")]) + ], + ) + + async for _ in lite_llm_instance.generate_content_async(llm_request): + pass + + mock_acompletion.assert_called_once() + _, kwargs = mock_acompletion.call_args + assert "tool_choice" not in kwargs diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index 4c664ae822..2ed4525904 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -1431,3 +1431,209 @@ def test_empty_sequential_agent_falls_back_to_request(self): } else: assert declaration.parameters.properties['request'].type == 'STRING' + + +# --------------------------------------------------------------------------- +# Tests for input_schema message wrapping (PR #5924) +# --------------------------------------------------------------------------- + + +async def _run_agent_tool_and_capture_content( + args: dict, + input_schema=None, + output_schema=None, +) -> types.Content: + """Drives AgentTool and captures the Content passed to the inner agent. + + This uses a stub Runner (same pattern as test_agent_tool_inherits_parent_app_name) + to intercept the new_message without executing the actual agent pipeline. + """ + from unittest.mock import patch + + from google.adk.agents.llm_agent import LlmAgent + from google.adk.plugins.plugin_manager import PluginManager + import google.adk.runners as _runners_module + + if input_schema is not None: + inner = LlmAgent( + name='inner_agent', + description='captures input', + model=testing_utils.MockModel.create(responses=['done']), + input_schema=input_schema, + output_schema=output_schema, + ) + else: + inner = Agent(name='inner_agent', model='test-model') + + new_message_holder: list = [] + + async def _empty_async_generator(): + if False: + yield None + + class _StubRunner: + + def __init__( + self, + *, + app_name, + agent, + artifact_service, + session_service, + memory_service, + credential_service, + plugins, + ): + del artifact_service, memory_service, credential_service + self.agent = agent + self.session_service = session_service + self.plugin_manager = PluginManager(plugins=plugins) + self.app_name = app_name + + def run_async( + self, + *, + user_id, + session_id, + invocation_id=None, + new_message=None, + state_delta=None, + run_config=None, + ): + new_message_holder.append(new_message) + return _empty_async_generator() + + async def close(self): + pass + + with patch.object(_runners_module, 'Runner', _StubRunner): + agent_tool = AgentTool(agent=inner) + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + invocation_context = InvocationContext( + invocation_id='invocation_id', + agent=inner, + session=session, + session_service=session_service, + ) + tool_context = ToolContext(invocation_context=invocation_context) + await agent_tool.run_async(args=args, tool_context=tool_context) + + return new_message_holder[0] if new_message_holder else None + + +@mark.asyncio +async def test_run_async_no_input_schema_passes_request_unchanged(): + """Without input_schema, the message is args['request'] verbatim.""" + content = await _run_agent_tool_and_capture_content( + args={'request': 'hello world'}, + input_schema=None, + ) + + assert content is not None + assert len(content.parts) == 1 + assert content.parts[0].text == 'hello world' + + +@mark.asyncio +async def test_run_async_with_input_schema_wraps_in_natural_language(): + """With input_schema, the message starts with a natural-language instruction.""" + + class MyInput(BaseModel): + custom_input: str + + content = await _run_agent_tool_and_capture_content( + args={'custom_input': 'test_value'}, + input_schema=MyInput, + ) + + assert content is not None + assert len(content.parts) == 1 + text = content.parts[0].text + # Must start with the natural-language prompt, not with raw JSON + assert text.startswith('Process the following structured request') + # Must contain the JSON payload after "Request:\n" + assert 'Request:\n' in text + json_part = text.split('Request:\n', 1)[1] + import json as _json + + payload = _json.loads(json_part) + assert payload['custom_input'] == 'test_value' + # The full text must NOT be just the raw JSON blob + assert text != json_part + + +@mark.asyncio +async def test_run_async_with_input_schema_text_not_raw_json(): + """The content text must not be a bare JSON string when input_schema is set.""" + + class MyInput(BaseModel): + value: int + + content = await _run_agent_tool_and_capture_content( + args={'value': 42}, + input_schema=MyInput, + ) + + assert content is not None + text = content.parts[0].text + # A bare JSON blob would start with '{'; the wrapped version must not + assert not text.startswith( + '{' + ), 'Content text is raw JSON instead of a natural-language instruction' + + +@mark.asyncio +async def test_run_async_with_input_and_output_schema_passes_raw_json(): + """With both input_schema AND output_schema, the raw JSON payload is passed + directly to the inner runner WITHOUT the ReAct wrapper prefix. + + The wrapper ('Process the following structured request...') is only added + when input_schema is set and output_schema is NOT set (tool-calling mode). + When output_schema is also present the agent operates in single-shot + structured-output mode, so the runner receives the bare JSON string that the + inner agent can parse deterministically — adding the prose prefix would + corrupt the structured input. + """ + import json as _json + + class MyInput(BaseModel): + query: str + limit: int + + class MyOutput(BaseModel): + result: str + + content = await _run_agent_tool_and_capture_content( + args={'query': 'hello', 'limit': 5}, + input_schema=MyInput, + output_schema=MyOutput, + ) + + assert content is not None + assert len(content.parts) == 1 + text = content.parts[0].text + + # output_schema mode is single-shot; wrapper must not be applied + assert not text.startswith('Process'), ( + 'output_schema mode is single-shot; wrapper must not be applied,' + f' but text starts with: {text[:60]!r}' + ) + + # The payload must be valid JSON + try: + payload = _json.loads(text) + except _json.JSONDecodeError as exc: + raise AssertionError( + f'Content text is not valid JSON in output_schema mode: {text!r}' + ) from exc + + # The JSON must match the input args + assert ( + payload['query'] == 'hello' + ), f"Expected query='hello', got {payload.get('query')!r}" + assert ( + payload['limit'] == 5 + ), f"Expected limit=5, got {payload.get('limit')!r}"