diff --git a/tests/langfuse/prompt/test__manager.py b/tests/langfuse/prompt/test_manager.py similarity index 100% rename from tests/langfuse/prompt/test__manager.py rename to tests/langfuse/prompt/test_manager.py diff --git a/tests/langfuse/tracing/test_langfuse_reporting_fixtures.py b/tests/langfuse/tracing/test_langfuse_reporting_fixtures.py new file mode 100644 index 00000000..acbed5a8 --- /dev/null +++ b/tests/langfuse/tracing/test_langfuse_reporting_fixtures.py @@ -0,0 +1,326 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Langfuse reporting regression tests tied to the active pytest interpreter. + +Expected behaviour (same test file, different venv): +- ``./venv/bin/pytest tests/langfuse/tracing/test_langfuse_reporting_fixtures.py`` → PASS + (dev env: no broken site-packages copy, or fixed source). +- ``./examples/quickstart/venv/bin/pytest ...`` → FAIL + (quickstart venv ships an older ``trpc_agent_sdk`` in site-packages with detached spans). + +The probe reads span-creation code from **site-packages when present**, matching +``run_agent.py`` under ``examples/quickstart/``. Repo-root ``sys.path`` is ignored +for that detection so pytest at repo root still exercises the installed wheel. +""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import Callable +from unittest.mock import MagicMock + +import pytest +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +import trpc_agent_sdk.server.langfuse.tracing.opentelemetry as otel_module +from trpc_agent_sdk.events import Event +from trpc_agent_sdk.server.langfuse.tracing.opentelemetry import LangfuseConfig, _LangfuseMixin +from trpc_agent_sdk.telemetry._trace import trace_agent +from trpc_agent_sdk.telemetry._trace import trace_runner +from trpc_agent_sdk.telemetry._trace import tracer +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import Part + +SPAN_PREFIX = "trpc.python.agent" + +SYSTEM_INSTRUCTION = ( + "You are an agent who's name is [assistant].\n\n" + "You are a helpful assistant for query weather." +) +TOOLS = [ + { + "function_declarations": [ + { + "description": "get weather information for the specified city", + "name": "get_weather_report", + "parameters": { + "properties": {"city": {"type": "STRING"}}, + "type": "OBJECT", + }, + } + ] + } +] +LLM_REQUEST = { + "model": "glm-5.0-w4afp8", + "config": { + "system_instruction": SYSTEM_INSTRUCTION, + "tools": TOOLS, + }, + "contents": [ + { + "parts": [{"text": "What's the weather like today?"}], + "role": "user", + } + ], +} +LLM_RESPONSE = { + "content": { + "parts": [ + {"text": "assistant reply", "thought": False}, + ], + "role": "model", + }, + "partial": False, + "usage_metadata": { + "candidates_token_count": 107, + "prompt_token_count": 185, + "total_token_count": 292, + }, +} + + +def _iter_site_packages() -> list[Path]: + paths: list[Path] = [] + prefix = Path(sys.prefix) + for lib_dir in ("lib64", "lib"): + base = prefix / lib_dir + if not base.is_dir(): + continue + for child in sorted(base.glob("python*/site-packages")): + if child.is_dir(): + paths.append(child) + return paths + + +def _resolve_sdk_file(relative: str) -> Path: + """Resolve a SDK file, preferring site-packages over repo-source import.""" + for site in _iter_site_packages(): + candidate = site / "trpc_agent_sdk" / relative + if candidate.is_file(): + return candidate + module_path = "trpc_agent_sdk." + relative.replace("/", ".").removesuffix(".py") + module = __import__(module_path, fromlist=["_"]) + return Path(module.__file__) + + +def _span_pattern_from_source( + source: str, + *, + detached_needle: str, + current_needle: str, +) -> str: + """Return ``current`` or ``detached`` based on how a span is opened in source text.""" + has_current = current_needle in source + has_detached = detached_needle in source + if has_current and not has_detached: + return "current" + if has_detached and not has_current: + return "detached" + if has_current: + return "current" + if has_detached: + return "detached" + raise AssertionError( + f"cannot detect span pattern (needles detached={detached_needle!r} " + f"current={current_needle!r})" + ) + + +def _agent_run_span_pattern() -> str: + source = _resolve_sdk_file("agents/_base_agent.py").read_text(encoding="utf-8") + return _span_pattern_from_source( + source, + detached_needle='span = tracer.start_span(f"agent_run', + current_needle='with tracer.start_as_current_span(f"agent_run', + ) + + +def _invocation_span_pattern() -> str: + source = _resolve_sdk_file("runners.py").read_text(encoding="utf-8") + return _span_pattern_from_source( + source, + detached_needle='span = tracer.start_span("invocation")', + current_needle='with tracer.start_as_current_span(f"invocation")', + ) + + +_EXPORTER: InMemorySpanExporter | None = None + + +@pytest.fixture(scope="session", autouse=True) +def _init_otel_tracer_once(): + """OpenTelemetry allows only one TracerProvider; share it across this module.""" + global _EXPORTER # noqa: PLW0603 + _EXPORTER = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(_EXPORTER)) + trace.set_tracer_provider(provider) + yield + _EXPORTER = None + + +def _clear_finished_spans() -> None: + assert _EXPORTER is not None + _EXPORTER.clear() + + +def _run_with_span_pattern(pattern: str, span_name: str, callback: Callable[[], None]) -> None: + if pattern == "current": + with tracer.start_as_current_span(span_name): + callback() + return + if pattern == "detached": + span = tracer.start_span(span_name) + try: + callback() + finally: + span.end() + return + raise ValueError(f"unknown span pattern: {pattern}") + + +def _finished_span_attributes(name_substring: str) -> dict: + assert _EXPORTER is not None + spans = _EXPORTER.get_finished_spans() + matched = [span for span in spans if name_substring in span.name] + assert matched, ( + f"no finished span containing {name_substring!r}; " + f"got span names: {[span.name for span in spans]}" + ) + return dict(matched[0].attributes or {}) + + +def _map_to_langfuse(raw_attributes: dict) -> dict: + mixin = _LangfuseMixin() + otel_module._langfuse_config = LangfuseConfig() + return mixin._map_attributes_to_langfuse(raw_attributes) + + +def _make_invocation_context() -> MagicMock: + ctx = MagicMock() + ctx.agent.name = "assistant" + ctx.user_content = Content(role="user", parts=[Part(text="What's the weather like today?")]) + ctx.override_messages = None + ctx.session.id = "a252d252-4b55-4713-80e4-90abb177c433" + ctx.session.user_id = "demo_user" + ctx.invocation_id = "e-d5a9872c-80e3-43ea-b2a8-0091257a1616" + return ctx + + +def probe_agent_run_langfuse_mapping() -> dict: + _clear_finished_spans() + ctx = _make_invocation_context() + pattern = _agent_run_span_pattern() + + def _callback() -> None: + trace_agent( + invocation_context=ctx, + agent_action="Could you please tell me the city you're interested in?", + state_begin={"user_name": "demo_user"}, + state_end={"user_name": "demo_user"}, + ) + + _run_with_span_pattern(pattern, "agent_run [assistant]", _callback) + return _map_to_langfuse(_finished_span_attributes("agent_run")) + + +def probe_invocation_langfuse_mapping() -> dict: + _clear_finished_spans() + ctx = _make_invocation_context() + pattern = _invocation_span_pattern() + user_message = Content(role="user", parts=[Part(text="What's the weather like today?")]) + last_event = Event( + content=Content(role="model", parts=[Part(text="Could you please tell me the city you're interested in?")]), + ) + + def _callback() -> None: + trace_runner( + app_name="weather_agent_demo", + user_id="demo_user", + session_id="a252d252-4b55-4713-80e4-90abb177c433", + invocation_context=ctx, + new_message=user_message, + last_event=last_event, + state_begin={"user_name": "demo_user"}, + state_end={"user_name": "demo_user"}, + ) + + _run_with_span_pattern(pattern, "invocation", _callback) + return _map_to_langfuse(_finished_span_attributes("invocation")) + + +def assert_valid_call_llm_langfuse_mapping(result: dict) -> None: + assert result["langfuse.observation.type"] == "generation", result + llm_input = json.loads(result["langfuse.observation.input"]) + config = llm_input.get("config", {}) + assert config.get("system_instruction"), result + assert config.get("tools"), result + model_params = json.loads(result["langfuse.observation.model.parameters"]) + assert model_params.get("system_instruction"), result + assert model_params.get("tools"), result + assert result["langfuse.observation.output"] != "unknown", result + + +def assert_valid_run_agent_langfuse_mapping(result: dict) -> None: + assert result.get("langfuse.observation.type") == "span", result + assert result.get("langfuse.observation.input") == "What's the weather like today?", result + assert result.get("langfuse.observation.output"), result + + +def assert_valid_run_runner_langfuse_mapping(result: dict) -> None: + assert result.get("langfuse.trace.name") == "[trpc-agent]: weather_agent_demo/assistant", result + assert result.get("langfuse.user.id") == "demo_user", result + assert result.get("langfuse.session.id") == "a252d252-4b55-4713-80e4-90abb177c433", result + assert result.get("langfuse.observation.input") == "What's the weather like today?", result + assert result.get("langfuse.observation.output"), result + assert "langfuse.trace.metadata" in result, result + + +@pytest.fixture(autouse=True) +def _langfuse_config(): + original = otel_module._langfuse_config + otel_module._langfuse_config = LangfuseConfig() + yield + otel_module._langfuse_config = original + + +@pytest.fixture +def mixin(): + return _LangfuseMixin() + + +class TestLangfuseReportingSpanContext: + """End-to-end: telemetry must land on the span Langfuse exports (ok.txt vs error.txt).""" + + def test_trace_agent_reaches_agent_run_span(self): + result = probe_agent_run_langfuse_mapping() + assert_valid_run_agent_langfuse_mapping(result) + + def test_trace_runner_reaches_invocation_span(self): + result = probe_invocation_langfuse_mapping() + assert_valid_run_runner_langfuse_mapping(result) + + +class TestLangfuseReportingCallLlmMapping: + """call_llm mapping must always include system prompt and tools (ok.txt generation).""" + + def test_call_llm_mapping_includes_system_instruction_and_tools(self, mixin): + attrs = { + "gen_ai.operation.name": "call_llm", + f"{SPAN_PREFIX}.llm_request": json.dumps(LLM_REQUEST, ensure_ascii=False), + f"{SPAN_PREFIX}.llm_response": json.dumps(LLM_RESPONSE, ensure_ascii=False), + "gen_ai.usage.input_tokens": 185, + "gen_ai.usage.output_tokens": 107, + "gen_ai.request.model": "glm-5.0-w4afp8", + } + result = mixin._map_attributes_to_langfuse(attrs) + assert_valid_call_llm_langfuse_mapping(result) diff --git a/tests/models/test_openai_model.py b/tests/models/test_openai_model.py index 7f15ea0d..aa9b7a67 100644 --- a/tests/models/test_openai_model.py +++ b/tests/models/test_openai_model.py @@ -240,6 +240,73 @@ def test_model_type_is_model(self): assert model._type == FilterType.MODEL + def test_create_async_client_uses_custom_http_client_factory(self): + """A custom http_client_factory is passed through to AsyncOpenAI.""" + shared_http_client = Mock() + http_client_factory = Mock(return_value=shared_http_client) + model = OpenAIModel( + model_name="gpt-4", + api_key="test_key", + base_url="https://custom.api.com", + client_args={"timeout": 30}, + http_client_factory=http_client_factory, + ) + + with patch("trpc_agent_sdk.models._openai_model.openai.AsyncOpenAI") as mock_async_openai: + client = model._create_async_client() + + assert client is mock_async_openai.return_value + http_client_factory.assert_called_once_with() + mock_async_openai.assert_called_once_with( + api_key="test_key", + max_retries=0, + organization="", + base_url="https://custom.api.com", + timeout=30, + http_client=shared_http_client, + ) + + def test_create_async_client_default_factory_reuses_shared_http_client(self): + """Default factory should reuse one shared httpx.AsyncClient across model calls.""" + from trpc_agent_sdk.models import _openai_model + + _openai_model._shared_http_client = None + shared_http_client = Mock() + model = OpenAIModel(model_name="gpt-4", api_key="test_key") + + try: + with patch("trpc_agent_sdk.models._openai_model.httpx.AsyncClient", + return_value=shared_http_client) as mock_httpx_client: + with patch("trpc_agent_sdk.models._openai_model.openai.AsyncOpenAI") as mock_async_openai: + model._create_async_client() + model._create_async_client() + finally: + _openai_model._shared_http_client = None + + mock_httpx_client.assert_called_once_with() + first_call_kwargs = mock_async_openai.call_args_list[0].kwargs + second_call_kwargs = mock_async_openai.call_args_list[1].kwargs + assert first_call_kwargs["http_client"] is shared_http_client + assert second_call_kwargs["http_client"] is shared_http_client + + def test_create_async_client_overwrites_stale_client_args_http_client(self): + """Factory owns http_client injection even if client_args already has one.""" + stale_http_client = Mock() + fresh_http_client = Mock() + http_client_factory = Mock(return_value=fresh_http_client) + model = OpenAIModel( + model_name="gpt-4", + api_key="test_key", + client_args={"http_client": stale_http_client, "timeout": 30}, + http_client_factory=http_client_factory, + ) + + with patch("trpc_agent_sdk.models._openai_model.openai.AsyncOpenAI") as mock_async_openai: + model._create_async_client() + + assert mock_async_openai.call_args.kwargs["http_client"] is fresh_http_client + assert mock_async_openai.call_args.kwargs["timeout"] == 30 + # ==================== Tests for generate_async method ==================== @pytest.mark.asyncio diff --git a/trpc_agent_sdk/agents/_base_agent.py b/trpc_agent_sdk/agents/_base_agent.py index 7b16cc16..48b1eb9a 100644 --- a/trpc_agent_sdk/agents/_base_agent.py +++ b/trpc_agent_sdk/agents/_base_agent.py @@ -267,12 +267,7 @@ async def run_async( # because __aexit__ of the context manager is not guaranteed to run when # an async generator is cancelled, but try/finally always executes # even under CancelledError (PEP 492). - from opentelemetry import context as context_api - from opentelemetry.trace import set_span_in_context - - span = tracer.start_span(f"agent_run [{self.name}]") - _ctx_token = context_api.attach(set_span_in_context(span, context_api.get_current())) - try: + with tracer.start_as_current_span(f"agent_run [{self.name}]"): ctx = self._create_invocation_context(parent_context) if ctx.agent_context is None: ctx.agent_context = create_agent_context() @@ -333,9 +328,6 @@ async def run_async( # avoid memory leak reset_invocation_ctx(token) - finally: - context_api.detach(_ctx_token) - span.end() @abstractmethod async def _run_async_impl(self, ctx: InvocationContext) -> AsyncGenerator[Event, None]: diff --git a/trpc_agent_sdk/models/_openai_model.py b/trpc_agent_sdk/models/_openai_model.py index 1819da15..2253c1cd 100644 --- a/trpc_agent_sdk/models/_openai_model.py +++ b/trpc_agent_sdk/models/_openai_model.py @@ -19,9 +19,11 @@ from typing import Dict from typing import List from typing import Optional +from typing import Callable from typing_extensions import override import openai +import httpx from pydantic import BaseModel from trpc_agent_sdk.common import check_enum @@ -106,6 +108,19 @@ class ApiParamsKey(str, Enum): PROMPT_CACHE_RETENTION = "prompt_cache_retention" +HttpClientFactory = Callable[[], httpx.AsyncClient] + +_shared_http_client: httpx.AsyncClient | None = None + + +def default_http_client_factory() -> httpx.AsyncClient: + """Create a default HTTP client.""" + global _shared_http_client + if _shared_http_client is None: + _shared_http_client = httpx.AsyncClient() + return _shared_http_client + + @register_model(model_name="OpenAIModel", supported_models=[r"gpt-.*", r"o1-.*", r"deepseek-.*", r"hy3-.*"]) class OpenAIModel(LLMModel): """OpenAI model implementation using the abstract model interface. @@ -171,6 +186,11 @@ def __init__( # Extract OpenAI-specific config self.organization: str = kwargs.get(const.ORGANIZATION, "") self.client_args = kwargs.get(const.CLIENT_ARGS, {}) + # Allow callers to inject a tuned httpx client so the underlying openai.AsyncOpenAI honors connection-pool + # settings such as keepalive_expiry / max_keepalive_connections (avoids reusing stale + # keep-alive sockets that gateways close earlier than httpx). + self._http_client_factory: HttpClientFactory = kwargs.pop("http_client_factory", None) + self._http_client_factory = self._http_client_factory or default_http_client_factory # Tool prompt configuration self.add_tools_to_prompt = add_tools_to_prompt @@ -178,8 +198,8 @@ def __init__( # Default generation config that can be overridden per request self.generate_content_config = generate_content_config - # Optional hard cap for tool-response payload injected into model - # context. Disabled by default; callers (e.g. OpenClaw) can opt in. + # Optional hard cap for tool-response payload injected into model context. + # Disabled by default; callers can opt in. self._tool_response_clip_chars = int(kwargs.get("tool_response_clip_chars", 0) or 0) # Validate tool_prompt parameter @@ -207,7 +227,7 @@ def set_model_name(self, value: str) -> None: super().set_model_name(value) self._refresh_adapter() - def _create_async_client(self): + def _create_async_client(self) -> openai.AsyncOpenAI: """Create a new async client instance.""" # Disable httpx logging to prevent HTTP request logs @@ -215,6 +235,8 @@ def _create_async_client(self): logging.getLogger("httpx").setLevel(logging.WARNING) + self.client_args['http_client'] = self._http_client_factory() + return openai.AsyncOpenAI( api_key=self._api_key, max_retries=0, # disable retries @@ -1090,26 +1112,23 @@ async def _generate_single(self, if http_options is None: http_options = {} client = self._create_async_client() - try: - response = await client.chat.completions.create(**api_params, **http_options) - response_dict: dict = response.model_dump() + response = await client.chat.completions.create(**api_params, **http_options) + response_dict: dict = response.model_dump() - # Check if the response contains valid text content or tool calls - has_text_content = self._verify_text_content_in_openai_message_response(response_dict) - has_tool_calls = False + # Check if the response contains valid text content or tool calls + has_text_content = self._verify_text_content_in_openai_message_response(response_dict) + has_tool_calls = False - # Check for tool calls - choices: list[dict] = response_dict.get(const.CHOICES, [{}]) - if choices and choices[0].get(const.MESSAGE, {}).get(const.TOOL_CALLS): - has_tool_calls = True + # Check for tool calls + choices: list[dict] = response_dict.get(const.CHOICES, [{}]) + if choices and choices[0].get(const.MESSAGE, {}).get(const.TOOL_CALLS): + has_tool_calls = True - # Create response with content if we have text or tool calls - if has_text_content or has_tool_calls: - return self._create_response_with_content(response_dict) - else: - return self._create_response_without_content(response_dict) - finally: - await client.close() + # Create response with content if we have text or tool calls + if has_text_content or has_tool_calls: + return self._create_response_with_content(response_dict) + else: + return self._create_response_without_content(response_dict) def _convert_tools_to_openai_format(self, tools: List[Tool]) -> List[Dict[str, Any]]: """Convert Google GenAI tools format to OpenAI tools format. @@ -1771,5 +1790,3 @@ async def _generate_stream(self, partial=False, custom_metadata={"error": str(ex)}, ) - finally: - await client.close() diff --git a/trpc_agent_sdk/runners.py b/trpc_agent_sdk/runners.py index 7173886d..5d48f26d 100644 --- a/trpc_agent_sdk/runners.py +++ b/trpc_agent_sdk/runners.py @@ -387,12 +387,7 @@ async def run_async( # because __aexit__ of the context manager is not guaranteed to run when # an async generator is cancelled, but try/finally always executes # even under CancelledError (PEP 492). - from opentelemetry import context as context_api - from opentelemetry.trace import set_span_in_context - - span = tracer.start_span("invocation") - _ctx_token = context_api.attach(set_span_in_context(span, context_api.get_current())) - try: + with tracer.start_as_current_span("invocation"): # Create default agent context if not provided if agent_context is None: agent_context = new_agent_context() @@ -629,9 +624,6 @@ async def run_async( user_id=user_id, session_id=session_id, ) - finally: - context_api.detach(_ctx_token) - span.end() async def _append_new_message_to_session( self, diff --git a/trpc_agent_sdk/server/ag_ui/_core/_converters.py b/trpc_agent_sdk/server/ag_ui/_core/_converters.py index 7d22b4da..989395bf 100644 --- a/trpc_agent_sdk/server/ag_ui/_core/_converters.py +++ b/trpc_agent_sdk/server/ag_ui/_core/_converters.py @@ -282,7 +282,7 @@ def convert_message_content_to_parts(content: Optional[Union[str, List[Any]]]) - parts.append(part) else: item_type_name = item.get("type") if isinstance(item, dict) else type(item).__name__ - logger.debug("Ignoring unknown multimodal content item: %s", item_type_name) + logger.debug("Ignoring unknown multi-model content item: %s", item_type_name) return parts