diff --git a/AGENTS.md b/AGENTS.md index 09ab2575fb..d601680f39 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -76,6 +76,11 @@ Apply to packages under `instrumentation/` and `instrumentation-genai/`. - When catching exceptions from the underlying library to record telemetry, always re-raise the original exception unmodified. - Do not raise new exceptions in instrumentation/telemetry code. +- For GenAI streaming wrappers, prefer the shared `SyncStreamWrapper` and `AsyncStreamWrapper` + helpers from `opentelemetry.util.genai.stream` instead of reimplementing iteration, + close/context-manager, and finalization behavior in provider packages. +- Put provider-specific chunk parsing and telemetry finalization in private hook methods or a + narrow mixin. Do not make async stream wrappers inherit from sync stream wrappers. ### Semantic conventions diff --git a/CLAUDE.md b/CLAUDE.md index 43c994c2d3..ce60f10b9b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1 +1,2 @@ @AGENTS.md + diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/CHANGELOG.md b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/CHANGELOG.md index edb75dd2f9..14a219551b 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/CHANGELOG.md +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- Refactor chat completion stream wrappers to use shared GenAI stream lifecycle helpers. 
+ ([#4500](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4500)) ## Version 2.4b0 (2026-05-01) diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/chat_buffers.py b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/chat_buffers.py new file mode 100644 index 0000000000..bfa5d21a57 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/chat_buffers.py @@ -0,0 +1,52 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class ToolCallBuffer: + def __init__(self, index, tool_call_id, function_name): + self.index = index + self.function_name = function_name + self.tool_call_id = tool_call_id + self.arguments = [] + + def append_arguments(self, arguments): + if arguments is not None: + self.arguments.append(arguments) + + +class ChoiceBuffer: + def __init__(self, index): + self.index = index + self.finish_reason = None + self.text_content = [] + self.tool_calls_buffers = [] + + def append_text_content(self, content): + self.text_content.append(content) + + def append_tool_call(self, tool_call): + idx = tool_call.index + for _ in range(len(self.tool_calls_buffers), idx + 1): + self.tool_calls_buffers.append(None) + + function = tool_call.function + if not self.tool_calls_buffers[idx]: + self.tool_calls_buffers[idx] = ToolCallBuffer( + idx, + tool_call.id, + function.name if function else None, + ) + + if function: + self.tool_calls_buffers[idx].append_arguments(function.arguments) diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/chat_wrappers.py b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/chat_wrappers.py new file mode 100644 index 0000000000..dcce2efbac --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/chat_wrappers.py @@ -0,0 +1,219 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from typing import Any, Optional + +from openai import AsyncStream, Stream + +from opentelemetry.semconv._incubating.attributes import ( + openai_attributes as OpenAIAttributes, +) +from opentelemetry.util.genai.invocation import InferenceInvocation +from opentelemetry.util.genai.stream import ( + AsyncStreamWrapper, + SyncStreamWrapper, +) +from opentelemetry.util.genai.types import ( + OutputMessage, + Text, + ToolCallRequest, +) + +from .chat_buffers import ChoiceBuffer + + +class _ChatStreamMixin: + """Chat-specific hooks shared by sync and async stream wrappers.""" + + invocation: InferenceInvocation + capture_content: bool + choice_buffers: list + response_id: Optional[str] = None + response_model: Optional[str] = None + service_tier: Optional[str] = None + prompt_tokens: Optional[int] = None + completion_tokens: Optional[int] = None + + def _set_response_model(self, chunk): + if self.response_model: + return + + if getattr(chunk, "model", None): + self.response_model = chunk.model + + def _set_response_id(self, chunk): + if self.response_id: + return + + if getattr(chunk, "id", None): + self.response_id = chunk.id + + def _set_response_service_tier(self, chunk): + if self.service_tier: + return + + if getattr(chunk, "service_tier", None): + self.service_tier = chunk.service_tier + + def _build_streaming_response(self, chunk): + if getattr(chunk, "choices", None) is None: + return + + choices = chunk.choices + for choice in choices: + if not choice.delta: + continue + + for idx in range(len(self.choice_buffers), choice.index + 1): + self.choice_buffers.append(ChoiceBuffer(idx)) + + if choice.finish_reason: + self.choice_buffers[ + choice.index + ].finish_reason = choice.finish_reason + + if choice.delta.content is not None: + self.choice_buffers[choice.index].append_text_content( + choice.delta.content 
+ ) + + if choice.delta.tool_calls is not None: + for tool_call in choice.delta.tool_calls: + self.choice_buffers[choice.index].append_tool_call( + tool_call + ) + + def _set_usage(self, chunk): + if getattr(chunk, "usage", None): + self.completion_tokens = chunk.usage.completion_tokens + self.prompt_tokens = chunk.usage.prompt_tokens + + def _process_chunk(self, chunk): + self._set_response_id(chunk) + self._set_response_model(chunk) + self._set_response_service_tier(chunk) + self._build_streaming_response(chunk) + self._set_usage(chunk) + + def _set_output_messages(self): + if not self.capture_content: # optimization + return + output_messages = [] + for choice in self.choice_buffers: + message = OutputMessage( + role="assistant", + finish_reason=choice.finish_reason or "error", + parts=[], + ) + if choice.text_content: + message.parts.append( + Text(content="".join(choice.text_content)) + ) + if choice.tool_calls_buffers: + tool_calls = [] + for tool_call in choice.tool_calls_buffers: + arguments = None + arguments_str = "".join(tool_call.arguments) + if arguments_str: + try: + arguments = json.loads(arguments_str) + except json.JSONDecodeError: + arguments = arguments_str + tool_call_part = ToolCallRequest( + name=tool_call.function_name, + id=tool_call.tool_call_id, + arguments=arguments, + ) + tool_calls.append(tool_call_part) + message.parts.extend(tool_calls) + output_messages.append(message) + + self.invocation.output_messages = output_messages + + def _stop_stream(self) -> None: + self._cleanup() + + def _fail_stream(self, error: BaseException) -> None: + self._cleanup(error) + + def parse(self): + """Called when using with_raw_response with stream=True.""" + return self + + def _cleanup(self, error: Optional[BaseException] = None) -> None: + self.invocation.response_model_name = self.response_model + self.invocation.response_id = self.response_id + self.invocation.input_tokens = self.prompt_tokens + self.invocation.output_tokens = self.completion_tokens 
+ finish_reasons = [ + choice.finish_reason + for choice in self.choice_buffers + if choice.finish_reason + ] + if finish_reasons: + self.invocation.finish_reasons = finish_reasons + if self.service_tier: + self.invocation.attributes.update( + { + OpenAIAttributes.OPENAI_RESPONSE_SERVICE_TIER: self.service_tier + }, + ) + + self._set_output_messages() + + if error: + self.invocation.fail(error) + else: + self.invocation.stop() + + +class ChatStreamWrapper( + _ChatStreamMixin, + SyncStreamWrapper[Any], +): + def __init__( + self, + stream: Stream, + invocation: InferenceInvocation, + capture_content: bool, + ): + super().__init__(stream) + self.invocation = invocation + self.choice_buffers = [] + self.capture_content = capture_content + + +class AsyncChatStreamWrapper( + _ChatStreamMixin, + AsyncStreamWrapper[Any], +): + def __init__( + self, + stream: AsyncStream, + invocation: InferenceInvocation, + capture_content: bool, + ): + super().__init__(stream) + self.invocation = invocation + self.choice_buffers = [] + self.capture_content = capture_content + + +__all__ = [ + "AsyncChatStreamWrapper", + "ChatStreamWrapper", +] diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/patch.py b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/patch.py index 267d57e653..8804446470 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/patch.py +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/patch.py @@ -13,7 +13,6 @@ # limitations under the License. 
-import json from timeit import default_timer from typing import Any, Optional @@ -36,11 +35,10 @@ from opentelemetry.util.genai.invocation import InferenceInvocation from opentelemetry.util.genai.types import ( Error, - OutputMessage, - Text, - ToolCallRequest, ) +from .chat_buffers import ChoiceBuffer +from .chat_wrappers import AsyncChatStreamWrapper, ChatStreamWrapper from .instruments import Instruments from .utils import ( _prepare_output_messages, @@ -240,7 +238,7 @@ async def traced_method(wrapped, instance, args, kwargs): else: parsed_result = result if is_streaming(kwargs): - return ChatStreamWrapper( + return AsyncChatStreamWrapper( parsed_result, chat_invocation, capture_content ) @@ -568,46 +566,6 @@ def _set_embeddings_response_attributes( # Don't set output tokens for embeddings as all tokens are input tokens -class ToolCallBuffer: - def __init__(self, index, tool_call_id, function_name): - self.index = index - self.function_name = function_name - self.tool_call_id = tool_call_id - self.arguments = [] - - def append_arguments(self, arguments): - if arguments is not None: - self.arguments.append(arguments) - - -class ChoiceBuffer: - def __init__(self, index): - self.index = index - self.finish_reason = None - self.text_content = [] - self.tool_calls_buffers = [] - - def append_text_content(self, content): - self.text_content.append(content) - - def append_tool_call(self, tool_call): - idx = tool_call.index - # make sure we have enough tool call buffers - for _ in range(len(self.tool_calls_buffers), idx + 1): - self.tool_calls_buffers.append(None) - - function = tool_call.function - if not self.tool_calls_buffers[idx]: - self.tool_calls_buffers[idx] = ToolCallBuffer( - idx, - tool_call.id, - function.name if function else None, - ) - - if function: - self.tool_calls_buffers[idx].append_arguments(function.arguments) - - class BaseStreamWrapper: response_id: Optional[str] = None response_model: Optional[str] = None @@ -859,83 +817,3 @@ def cleanup(self, 
error: Optional[BaseException] = None): else: self.span.end() self._started = False - - -class ChatStreamWrapper(BaseStreamWrapper): - invocation: InferenceInvocation - response_id: Optional[str] = None - response_model: Optional[str] = None - service_tier: Optional[str] = None - finish_reasons: list = [] - prompt_tokens: Optional[int] = None - completion_tokens: Optional[int] = None - - def __init__( - self, - stream: Stream, - invocation: InferenceInvocation, - capture_content: bool, - ): - super().__init__(stream, capture_content=capture_content) - self.stream = stream - self.invocation = invocation - self.choice_buffers = [] - - def _set_output_messages(self): - if not self.capture_content: # optimization - return - output_messages = [] - for choice in self.choice_buffers: - message = OutputMessage( - role="assistant", - finish_reason=choice.finish_reason or "error", - parts=[], - ) - if choice.text_content: - message.parts.append( - Text(content="".join(choice.text_content)) - ) - if choice.tool_calls_buffers: - tool_calls = [] - for tool_call in choice.tool_calls_buffers: - arguments = None - arguments_str = "".join(tool_call.arguments) - if arguments_str: - try: - arguments = json.loads(arguments_str) - except json.JSONDecodeError: - arguments = arguments_str - tool_call_part = ToolCallRequest( - name=tool_call.function_name, - id=tool_call.tool_call_id, - arguments=arguments, - ) - tool_calls.append(tool_call_part) - message.parts.extend(tool_calls) - output_messages.append(message) - - self.invocation.output_messages = output_messages - - def cleanup(self, error: Optional[BaseException] = None): - if not self._started: - return - - self.invocation.response_model_name = self.response_model - self.invocation.response_id = self.response_id - self.invocation.input_tokens = self.prompt_tokens - self.invocation.output_tokens = self.completion_tokens - self.invocation.finish_reasons = self.finish_reasons - if self.service_tier: - 
self.invocation.attributes.update( - { - OpenAIAttributes.OPENAI_RESPONSE_SERVICE_TIER: self.service_tier - }, - ) - - self._set_output_messages() - - if error: - self.invocation.fail(Error(type=type(error), message=str(error))) - else: - self.invocation.stop() - self._started = False diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/utils.py b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/utils.py index 4dab04d977..17794d6d6f 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/utils.py +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/utils.py @@ -399,7 +399,7 @@ def create_chat_invocation( extra_body = get_value(kwargs.get("extra_body")) if isinstance(extra_body, Mapping): service_tier = get_value(extra_body.get("service_tier")) - if service_tier is not None: + if service_tier is not None and service_tier != "auto": invocation.attributes[OpenAIAttributes.OPENAI_REQUEST_SERVICE_TIER] = ( service_tier ) diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_async_chat_completions.py b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_async_chat_completions.py index c19a8b3d5b..2f4c916938 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_async_chat_completions.py +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_async_chat_completions.py @@ -197,6 +197,18 @@ async def test_async_chat_completion_404( assert "NotFoundError" == spans[0].attributes[ErrorAttributes.ERROR_TYPE] +@pytest.mark.asyncio() +async def test_async_chat_completion_api_exception_propagates( + async_openai_client, instrument_no_content, vcr +): + with vcr.use_cassette("test_async_chat_completion_404.yaml"): + with 
pytest.raises(NotFoundError): + await async_openai_client.chat.completions.create( + messages=USER_ONLY_PROMPT, + model="this-model-does-not-exist", + ) + + @pytest.mark.asyncio() async def test_async_chat_completion_extra_params( span_exporter, async_openai_client, instrument_no_content, vcr @@ -883,6 +895,44 @@ async def test_async_chat_completion_streaming( ) +@pytest.mark.asyncio() +async def test_async_chat_completion_streaming_user_exception_propagates( + span_exporter, + async_openai_client, + instrument_with_content, + vcr, +): + latest_experimental_enabled = is_experimental_mode() + llm_model_value = "gpt-4" + kwargs = { + "model": llm_model_value, + "messages": USER_ONLY_PROMPT, + "stream": True, + "stream_options": {"include_usage": True}, + } + response_stream_model = None + response_stream_id = None + + with vcr.use_cassette("test_async_chat_completion_streaming.yaml"): + response = await async_openai_client.chat.completions.create(**kwargs) + with pytest.raises(RuntimeError, match="user failure"): + async with response: + async for chunk in response: + response_stream_model = chunk.model + response_stream_id = chunk.id + raise RuntimeError("user failure") + + spans = span_exporter.get_finished_spans() + assert_all_attributes( + spans[0], + llm_model_value, + latest_experimental_enabled, + response_stream_id, + response_stream_model, + ) + assert "RuntimeError" == spans[0].attributes[ErrorAttributes.ERROR_TYPE] + + @pytest.mark.asyncio() async def test_async_chat_completion_streaming_not_complete( span_exporter, @@ -921,7 +971,10 @@ async def test_async_chat_completion_streaming_not_complete( response_stream_id = chunk.id idx += 1 - response.close() + if latest_experimental_enabled: + await response.close() + else: + response.close() spans = span_exporter.get_finished_spans() assert_all_attributes( spans[0], diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_chat_completions.py 
b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_chat_completions.py index 3e4df914dc..7d25e400f1 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_chat_completions.py +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_chat_completions.py @@ -287,6 +287,17 @@ def test_chat_completion_404( ) +def test_chat_completion_api_exception_propagates( + openai_client, instrument_no_content, vcr +): + with vcr.use_cassette("test_chat_completion_404.yaml"): + with pytest.raises(NotFoundError): + openai_client.chat.completions.create( + messages=USER_ONLY_PROMPT, + model="this-model-does-not-exist", + ) + + def test_chat_completion_extra_params( span_exporter, openai_client, instrument_no_content, vcr ): @@ -997,6 +1008,123 @@ def test_chat_completion_streaming( ) +def test_chat_completion_streaming_user_exception_propagates( + span_exporter, openai_client, instrument_with_content, vcr +): + latest_experimental_enabled = is_experimental_mode() + kwargs = { + "model": DEFAULT_MODEL, + "messages": USER_ONLY_PROMPT, + "stream": True, + "stream_options": {"include_usage": True}, + } + response_stream_model = None + response_stream_id = None + + with vcr.use_cassette("test_chat_completion_streaming.yaml"): + response = openai_client.chat.completions.create(**kwargs) + with pytest.raises(RuntimeError, match="user failure"): + with response: + for chunk in response: + response_stream_model = chunk.model + response_stream_id = chunk.id + raise RuntimeError("user failure") + + spans = span_exporter.get_finished_spans() + assert_all_attributes( + spans[0], + DEFAULT_MODEL, + latest_experimental_enabled, + response_stream_id, + response_stream_model, + ) + assert "RuntimeError" == spans[0].attributes[ErrorAttributes.ERROR_TYPE] + + +def test_chat_completion_streaming_user_exception_wins_over_close_exception( + span_exporter, openai_client, instrument_with_content, vcr, monkeypatch +): + if not 
is_experimental_mode(): + pytest.skip("new stream wrapper only") + + kwargs = { + "model": DEFAULT_MODEL, + "messages": USER_ONLY_PROMPT, + "stream": True, + "stream_options": {"include_usage": True}, + } + + with vcr.use_cassette("test_chat_completion_streaming.yaml"): + response = openai_client.chat.completions.create(**kwargs) + original_close = response.stream.close + + def close_raises(): + original_close() + raise RuntimeError("close failure") + + monkeypatch.setattr(response.stream, "close", close_raises) + with pytest.raises(RuntimeError, match="user failure"): + with response: + raise RuntimeError("user failure") + + spans = span_exporter.get_finished_spans() + assert "RuntimeError" == spans[0].attributes[ErrorAttributes.ERROR_TYPE] + + +def test_chat_completion_streaming_close_exception_propagates_when_first( + span_exporter, openai_client, instrument_with_content, vcr, monkeypatch +): + if not is_experimental_mode(): + pytest.skip("new stream wrapper only") + + kwargs = { + "model": DEFAULT_MODEL, + "messages": USER_ONLY_PROMPT, + "stream": True, + "stream_options": {"include_usage": True}, + } + + with vcr.use_cassette("test_chat_completion_streaming.yaml"): + response = openai_client.chat.completions.create(**kwargs) + original_close = response.stream.close + + def close_raises(): + original_close() + raise RuntimeError("close failure") + + monkeypatch.setattr(response.stream, "close", close_raises) + with pytest.raises(RuntimeError, match="close failure"): + response.close() + + spans = span_exporter.get_finished_spans() + assert "RuntimeError" == spans[0].attributes[ErrorAttributes.ERROR_TYPE] + + +def test_chat_completion_streaming_instrumentation_finalize_errors_swallowed( + span_exporter, openai_client, instrument_with_content, vcr, monkeypatch +): + if not is_experimental_mode(): + pytest.skip("new stream wrapper only") + + kwargs = { + "model": DEFAULT_MODEL, + "messages": USER_ONLY_PROMPT, + "stream": True, + "stream_options": {"include_usage": 
True}, + } + + with vcr.use_cassette("test_chat_completion_streaming.yaml"): + response = openai_client.chat.completions.create(**kwargs) + + def stop_raises(): + raise RuntimeError("instrumentation failure") + + monkeypatch.setattr(response, "_stop_stream", stop_raises) + response.close() + + assert span_exporter.get_finished_spans() == () + + def test_chat_completion_streaming_not_complete( span_exporter, log_exporter, openai_client, instrument_with_content, vcr ): diff --git a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_choice_buffer.py b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_choice_buffer.py index 7717ff73b2..17aa0ea289 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_choice_buffer.py +++ b/instrumentation-genai/opentelemetry-instrumentation-openai-v2/tests/test_choice_buffer.py @@ -19,7 +19,7 @@ ChoiceDeltaToolCallFunction, ) -from opentelemetry.instrumentation.openai_v2.patch import ( +from opentelemetry.instrumentation.openai_v2.chat_buffers import ( ChoiceBuffer, ToolCallBuffer, ) diff --git a/util/opentelemetry-util-genai/CHANGELOG.md b/util/opentelemetry-util-genai/CHANGELOG.md index c8858e66be..be8af56710 100644 --- a/util/opentelemetry-util-genai/CHANGELOG.md +++ b/util/opentelemetry-util-genai/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- Add shared sync and async stream wrapper base classes for GenAI instrumentations. 
+ ([#4500](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4500)) ## Version 0.4b0 (2026-05-01) - Add `AgentInvocation` type with `invoke_agent` span lifecycle diff --git a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/stream.py b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/stream.py new file mode 100644 index 0000000000..4a754a97e6 --- /dev/null +++ b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/stream.py @@ -0,0 +1,272 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Any, Generic, Literal, TypeVar + +ChunkT = TypeVar("ChunkT") +_logger = logging.getLogger(__name__) + + +class SyncStreamWrapper(ABC, Generic[ChunkT]): + """Base class for synchronous instrumented stream wrappers. + + Subclass this when wrapping a provider SDK stream that is consumed with + normal iteration. The subclass should pass the SDK stream to + ``super().__init__(stream)`` and implement the three telemetry hooks: + ``_process_chunk`` for per-chunk state, ``_stop_stream`` for successful + finalization, and ``_fail_stream`` for failure finalization. + + Users should consume subclasses as normal streams, for example with + ``for chunk in wrapper`` or ``with wrapper``. 
The hook methods are called + internally by the wrapper lifecycle and are not part of the public API. + """ + + def __init__(self, stream: Any): + self.stream = stream + self._iterator = iter(stream) + self._finalized = False + + def __enter__(self): + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> Literal[False]: + if exc_type is not None: + self._safe_finalize_failure(exc_val or Exception()) + try: + self.stream.close() + except Exception: # pylint: disable=broad-exception-caught + _logger.debug( + "GenAI stream close error after user exception", + exc_info=True, + ) + return False + + self.close() + return False + + def close(self) -> None: + try: + self.stream.close() + except Exception as error: + self._safe_finalize_failure(error) + raise + self._safe_finalize_success() + + def __iter__(self): + return self + + def __next__(self) -> ChunkT: + try: + chunk = next(self._iterator) + except StopIteration: + self._safe_finalize_success() + raise + except Exception as error: + self._safe_finalize_failure(error) + raise + try: + self._process_chunk(chunk) + except Exception as error: # pylint: disable=broad-exception-caught + self._handle_process_chunk_error(error) + return chunk + + def __getattr__(self, name: str) -> Any: + return getattr(self.stream, name) + + def _finalize_success(self) -> None: + if self._finalized: + return + self._finalized = True + self._stop_stream() + + def _finalize_failure(self, error: BaseException) -> None: + if self._finalized: + return + self._finalized = True + self._fail_stream(error) + + def _safe_finalize_success(self) -> None: + try: + self._finalize_success() + except Exception: # pylint: disable=broad-exception-caught + _logger.debug( + "GenAI stream instrumentation error during finalization", + exc_info=True, + ) + + def _safe_finalize_failure(self, error: BaseException) -> None: + try: + self._finalize_failure(error) + 
except Exception: # pylint: disable=broad-exception-caught + _logger.debug( + "GenAI stream instrumentation error during failure finalization", + exc_info=True, + ) + + @abstractmethod + def _process_chunk(self, chunk: ChunkT) -> None: + """Process one stream chunk for telemetry.""" + + @abstractmethod + def _stop_stream(self) -> None: + """Finalize the stream successfully.""" + + @abstractmethod + def _fail_stream(self, error: BaseException) -> None: + """Finalize the stream with failure.""" + + @staticmethod + def _handle_process_chunk_error(_error: Exception) -> None: + _logger.debug( + "GenAI stream instrumentation error during chunk processing", + exc_info=True, + ) + + +class AsyncStreamWrapper(ABC, Generic[ChunkT]): + """Base class for asynchronous instrumented stream wrappers. + + Subclass this when wrapping a provider SDK stream that is consumed with + async iteration. The subclass should pass the SDK stream to + ``super().__init__(stream)`` and implement the three telemetry hooks: + ``_process_chunk`` for per-chunk state, ``_stop_stream`` for successful + finalization, and ``_fail_stream`` for failure finalization. + + Users should consume subclasses as normal async streams, for example with + ``async for chunk in wrapper`` or ``async with wrapper``. The hook methods + remain synchronous telemetry hooks; async stream reads and close handling + are owned by this base class. 
+ """ + + def __init__(self, stream: Any): + self.stream = stream + self._aiter = aiter(stream) + self._finalized = False + + async def __aenter__(self): + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> Literal[False]: + if exc_type is not None: + self._safe_finalize_failure(exc_val or Exception()) + try: + await self.stream.close() + except Exception: # pylint: disable=broad-exception-caught + _logger.debug( + "GenAI stream close error after user exception", + exc_info=True, + ) + return False + + await self.close() + return False + + async def close(self) -> None: + try: + await self.stream.close() + except Exception as error: + self._safe_finalize_failure(error) + raise + self._safe_finalize_success() + + def __aiter__(self): + return self + + async def __anext__(self) -> ChunkT: + try: + chunk = await anext(self._aiter) + except StopAsyncIteration: + self._safe_finalize_success() + raise + except Exception as error: + self._safe_finalize_failure(error) + raise + try: + self._process_chunk(chunk) + except Exception as error: # pylint: disable=broad-exception-caught + self._handle_process_chunk_error(error) + return chunk + + def __getattr__(self, name: str) -> Any: + return getattr(self.stream, name) + + def _finalize_success(self) -> None: + if self._finalized: + return + self._finalized = True + self._stop_stream() + + def _finalize_failure(self, error: BaseException) -> None: + if self._finalized: + return + self._finalized = True + self._fail_stream(error) + + def _safe_finalize_success(self) -> None: + try: + self._finalize_success() + except Exception: # pylint: disable=broad-exception-caught + _logger.debug( + "GenAI stream instrumentation error during finalization", + exc_info=True, + ) + + def _safe_finalize_failure(self, error: BaseException) -> None: + try: + self._finalize_failure(error) + except Exception: # pylint: 
disable=broad-exception-caught + _logger.debug( + "GenAI stream instrumentation error during failure finalization", + exc_info=True, + ) + + @abstractmethod + def _process_chunk(self, chunk: ChunkT) -> None: + """Process one stream chunk for telemetry.""" + + @abstractmethod + def _stop_stream(self) -> None: + """Finalize the stream successfully.""" + + @abstractmethod + def _fail_stream(self, error: BaseException) -> None: + """Finalize the stream with failure.""" + + @staticmethod + def _handle_process_chunk_error(_error: Exception) -> None: + _logger.debug( + "GenAI stream instrumentation error during chunk processing", + exc_info=True, + ) + + +__all__ = [ + "AsyncStreamWrapper", + "SyncStreamWrapper", +] diff --git a/util/opentelemetry-util-genai/tests/test_stream.py b/util/opentelemetry-util-genai/tests/test_stream.py new file mode 100644 index 0000000000..b34012cff7 --- /dev/null +++ b/util/opentelemetry-util-genai/tests/test_stream.py @@ -0,0 +1,519 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Tests for the shared GenAI stream lifecycle wrappers.

Covers, for both ``SyncStreamWrapper`` and ``AsyncStreamWrapper``:
chunk processing, success/failure finalization, close semantics,
context-manager exit behavior, attribute passthrough, and that
instrumentation errors are swallowed without affecting user iteration.
"""

import asyncio
import inspect

import pytest

from opentelemetry.util.genai.stream import (
    AsyncStreamWrapper,
    SyncStreamWrapper,
)


def test_stream_wrapper_abstract_method_signatures_match():
    # The sync and async wrappers must expose identical subclass hooks so
    # provider mixins can serve both.
    method_names = (
        "_process_chunk",
        "_stop_stream",
        "_fail_stream",
        "_handle_process_chunk_error",
    )

    for method_name in method_names:
        assert inspect.signature(
            getattr(SyncStreamWrapper, method_name)
        ) == inspect.signature(getattr(AsyncStreamWrapper, method_name))


class _FakeSyncStream:
    # Minimal sync stream: yields `chunks`, then raises `error` if given,
    # otherwise StopIteration; `close_error` makes close() raise.
    def __init__(self, chunks=None, error=None, close_error=None):
        self._chunks = list(chunks or [])
        self._error = error
        self._close_error = close_error
        self.close_count = 0
        self.extra_attribute = "passthrough"

    def __iter__(self):
        return self

    def __next__(self):
        if self._chunks:
            return self._chunks.pop(0)
        if self._error:
            raise self._error
        raise StopIteration

    def close(self):
        self.close_count += 1
        if self._close_error:
            raise self._close_error


class _FakeSyncIterable:
    # An iterable (not iterator) stream: __iter__ returns a separate
    # iterator object, exercising the wrapper's iter() handling.
    def __init__(self, chunks=None):
        self.iterator = iter(chunks or [])
        self.close_count = 0

    def __iter__(self):
        return self.iterator

    def close(self):
        self.close_count += 1


class _TestSyncStreamWrapper(SyncStreamWrapper):
    # Records hook invocations so tests can assert finalization behavior.
    def __init__(self, stream):
        super().__init__(stream)
        self.processed = []
        self.stop_count = 0
        self.failures = []

    def _process_chunk(self, chunk):
        self.processed.append(chunk)

    def _stop_stream(self):
        self.stop_count += 1

    def _fail_stream(self, error):
        self.failures.append(error)


class _FailingSyncProcessStreamWrapper(_TestSyncStreamWrapper):
    # Simulates an instrumentation bug in chunk processing.
    def _process_chunk(self, chunk):
        raise ValueError("instrumentation failed")


class _FailingSyncStopStreamWrapper(_TestSyncStreamWrapper):
    # Simulates an instrumentation bug in success finalization.
    def _stop_stream(self):
        self.stop_count += 1
        raise ValueError("instrumentation failed")


class _FailingSyncFailStreamWrapper(_TestSyncStreamWrapper):
    # Simulates an instrumentation bug in failure finalization.
    def _fail_stream(self, error):
        self.failures.append(error)
        raise ValueError("instrumentation failed")


def test_sync_stream_wrapper_processes_chunks_and_stops():
    stream = _FakeSyncStream(chunks=["chunk"])
    wrapper = _TestSyncStreamWrapper(stream)

    assert next(wrapper) == "chunk"
    assert wrapper.processed == ["chunk"]

    try:
        next(wrapper)
    except StopIteration:
        pass

    assert wrapper.stop_count == 1


def test_sync_stream_wrapper_processes_iterables():
    stream = _FakeSyncIterable(chunks=["chunk"])
    wrapper = _TestSyncStreamWrapper(stream)

    assert next(wrapper) == "chunk"
    assert wrapper.processed == ["chunk"]

    with pytest.raises(StopIteration):
        next(wrapper)

    assert wrapper.stop_count == 1


def test_sync_stream_wrapper_fails_stream_errors():
    error = ValueError("boom")
    wrapper = _TestSyncStreamWrapper(_FakeSyncStream(error=error))

    try:
        next(wrapper)
    except ValueError:
        pass

    # The original stream error object is handed to _fail_stream.
    assert wrapper.failures == [error]


def test_sync_stream_wrapper_close_stops_once():
    stream = _FakeSyncStream(chunks=["chunk"])
    wrapper = _TestSyncStreamWrapper(stream)

    wrapper.close()
    wrapper.close()

    # close() always forwards to the stream, but telemetry finalizes once.
    assert stream.close_count == 2
    assert wrapper.stop_count == 1
    assert not wrapper.failures


def test_sync_stream_wrapper_close_fails_with_close_error():
    error = RuntimeError("close failure")
    wrapper = _TestSyncStreamWrapper(
        _FakeSyncStream(chunks=["chunk"], close_error=error)
    )

    with pytest.raises(RuntimeError, match="close failure"):
        wrapper.close()

    # A failing close finalizes as failure, never as success.
    assert wrapper.failures == [error]
    assert wrapper.stop_count == 0


def test_sync_stream_wrapper_exit_closes_and_propagates_user_errors():
    stream = _FakeSyncStream(chunks=["chunk"])
    wrapper = _TestSyncStreamWrapper(stream)
    error = RuntimeError("user failure")

    # Returning False from __exit__ lets the user exception propagate.
    assert wrapper.__exit__(RuntimeError, error, None) is False

    assert stream.close_count == 1
    assert wrapper.stop_count == 0
    assert wrapper.failures == [error]


def test_sync_stream_wrapper_exit_keeps_user_error_when_close_fails():
    close_error = RuntimeError("close failure")
    stream = _FakeSyncStream(chunks=["chunk"], close_error=close_error)
    wrapper = _TestSyncStreamWrapper(stream)
    error = RuntimeError("user failure")

    assert wrapper.__exit__(RuntimeError, error, None) is False

    # The user's exception is recorded; the close error is swallowed.
    assert stream.close_count == 1
    assert wrapper.failures == [error]
    assert wrapper.stop_count == 0


def test_sync_stream_wrapper_swallows_finalize_errors():
    wrapper = _FailingSyncStopStreamWrapper(_FakeSyncStream())

    wrapper.close()
    wrapper.close()

    # The _stop_stream exception is logged, not raised, and not retried.
    assert wrapper.stop_count == 1


def test_sync_stream_wrapper_swallows_failure_finalize_errors():
    close_error = RuntimeError("close failure")
    stream = _FakeSyncStream(close_error=close_error)
    wrapper = _FailingSyncFailStreamWrapper(stream)

    with pytest.raises(RuntimeError, match="close failure"):
        wrapper.close()
    # Clear the injected close error so the second close() succeeds; the
    # wrapper must not finalize again.
    stream._close_error = None
    wrapper.close()

    assert wrapper.failures == [close_error]


def test_sync_stream_wrapper_swallows_stop_iteration_finalize_errors():
    wrapper = _FailingSyncStopStreamWrapper(_FakeSyncStream())

    # A _stop_stream failure must not replace the StopIteration signal.
    with pytest.raises(StopIteration):
        next(wrapper)


def test_sync_stream_wrapper_preserves_stream_error_when_finalize_fails():
    error = RuntimeError("stream failure")
    wrapper = _FailingSyncFailStreamWrapper(_FakeSyncStream(error=error))

    # The stream's own error propagates even when _fail_stream raises.
    with pytest.raises(RuntimeError, match="stream failure"):
        next(wrapper)


def test_sync_stream_wrapper_getattr_passthrough():
    wrapper = _TestSyncStreamWrapper(_FakeSyncStream())

    assert wrapper.extra_attribute == "passthrough"


def test_sync_stream_wrapper_stop_iteration_does_not_double_finalize():
    wrapper = _TestSyncStreamWrapper(_FakeSyncStream())

    with pytest.raises(StopIteration):
        next(wrapper)
    wrapper.close()

    # Exhaustion already finalized; close() must not finalize again.
    assert wrapper.stop_count == 1
    assert not wrapper.failures


def test_sync_stream_wrapper_swallows_process_chunk_errors():
    wrapper = _FailingSyncProcessStreamWrapper(
        _FakeSyncStream(chunks=["chunk"])
    )

    # The chunk is still delivered; the instrumentation bug is swallowed.
    assert next(wrapper) == "chunk"
    assert not wrapper.failures


class _FakeAsyncStream:
    # Async mirror of _FakeSyncStream.
    def __init__(self, chunks=None, error=None, close_error=None):
        self._chunks = list(chunks or [])
        self._error = error
        self._close_error = close_error
        self.close_count = 0
        self.extra_attribute = "passthrough"

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._chunks:
            return self._chunks.pop(0)
        if self._error:
            raise self._error
        raise StopAsyncIteration

    async def close(self):
        self.close_count += 1
        if self._close_error:
            raise self._close_error


class _FakeAsyncIterable:
    # Async iterable whose __aiter__ returns a distinct iterator object.
    def __init__(self, chunks=None):
        self.iterator = _FakeAsyncStream(chunks=chunks)
        self.close_count = 0

    def __aiter__(self):
        return self.iterator

    async def close(self):
        self.close_count += 1


class _TestAsyncStreamWrapper(AsyncStreamWrapper):
    # Records hook invocations so tests can assert finalization behavior.
    def __init__(self, stream):
        super().__init__(stream)
        self.processed = []
        self.stop_count = 0
        self.failures = []

    def _process_chunk(self, chunk):
        self.processed.append(chunk)

    def _stop_stream(self):
        self.stop_count += 1

    def _fail_stream(self, error):
        self.failures.append(error)


class _FailingAsyncProcessStreamWrapper(_TestAsyncStreamWrapper):
    # Simulates an instrumentation bug in chunk processing.
    def _process_chunk(self, chunk):
        raise ValueError("instrumentation failed")


class _FailingAsyncStopStreamWrapper(_TestAsyncStreamWrapper):
    # Simulates an instrumentation bug in success finalization.
    def _stop_stream(self):
        self.stop_count += 1
        raise ValueError("instrumentation failed")


class _FailingAsyncFailStreamWrapper(_TestAsyncStreamWrapper):
    # Simulates an instrumentation bug in failure finalization.
    def _fail_stream(self, error):
        self.failures.append(error)
        raise ValueError("instrumentation failed")


def test_async_stream_wrapper_processes_chunks_and_stops():
    async def exercise():
        wrapper = _TestAsyncStreamWrapper(_FakeAsyncStream(chunks=["chunk"]))

        assert await anext(wrapper) == "chunk"
        assert wrapper.processed == ["chunk"]

        try:
            await anext(wrapper)
        except StopAsyncIteration:
            pass

        assert wrapper.stop_count == 1

    asyncio.run(exercise())


def test_async_stream_wrapper_processes_async_iterables():
    async def exercise():
        stream = _FakeAsyncIterable(chunks=["chunk"])
        wrapper = _TestAsyncStreamWrapper(stream)

        assert await anext(wrapper) == "chunk"
        assert wrapper.processed == ["chunk"]

        with pytest.raises(StopAsyncIteration):
            await anext(wrapper)

        assert wrapper.stop_count == 1

    asyncio.run(exercise())


def test_async_stream_wrapper_fails_stream_errors():
    async def exercise():
        error = ValueError("boom")
        wrapper = _TestAsyncStreamWrapper(_FakeAsyncStream(error=error))

        with pytest.raises(ValueError):
            await anext(wrapper)

        # The original stream error object is handed to _fail_stream.
        assert wrapper.failures == [error]

    asyncio.run(exercise())


def test_async_stream_wrapper_close_stops_once():
    async def exercise():
        stream = _FakeAsyncStream(chunks=["chunk"])
        wrapper = _TestAsyncStreamWrapper(stream)

        await wrapper.close()
        await wrapper.close()

        # close() always forwards; telemetry finalizes exactly once.
        assert stream.close_count == 2
        assert wrapper.stop_count == 1
        assert not wrapper.failures

    asyncio.run(exercise())


def test_async_stream_wrapper_close_fails_with_close_error():
    async def exercise():
        error = RuntimeError("close failure")
        wrapper = _TestAsyncStreamWrapper(
            _FakeAsyncStream(chunks=["chunk"], close_error=error)
        )

        with pytest.raises(RuntimeError, match="close failure"):
            await wrapper.close()

        # A failing close finalizes as failure, never as success.
        assert wrapper.failures == [error]
        assert wrapper.stop_count == 0

    asyncio.run(exercise())


def test_async_stream_wrapper_exit_closes_and_propagates_user_errors():
    async def exercise():
        stream = _FakeAsyncStream(chunks=["chunk"])
        wrapper = _TestAsyncStreamWrapper(stream)
        error = RuntimeError("user failure")

        # Returning False from __aexit__ lets the user exception propagate.
        assert await wrapper.__aexit__(RuntimeError, error, None) is False

        assert stream.close_count == 1
        assert wrapper.stop_count == 0
        assert wrapper.failures == [error]

    asyncio.run(exercise())


def test_async_stream_wrapper_exit_keeps_user_error_when_close_fails():
    async def exercise():
        close_error = RuntimeError("close failure")
        stream = _FakeAsyncStream(chunks=["chunk"], close_error=close_error)
        wrapper = _TestAsyncStreamWrapper(stream)
        error = RuntimeError("user failure")

        assert await wrapper.__aexit__(RuntimeError, error, None) is False

        # The user's exception is recorded; the close error is swallowed.
        assert stream.close_count == 1
        assert wrapper.failures == [error]
        assert wrapper.stop_count == 0

    asyncio.run(exercise())


def test_async_stream_wrapper_swallows_finalize_errors():
    async def exercise():
        wrapper = _FailingAsyncStopStreamWrapper(_FakeAsyncStream())

        await wrapper.close()
        await wrapper.close()

        # The _stop_stream exception is logged, not raised, not retried.
        assert wrapper.stop_count == 1

    asyncio.run(exercise())


def test_async_stream_wrapper_swallows_failure_finalize_errors():
    async def exercise():
        close_error = RuntimeError("close failure")
        stream = _FakeAsyncStream(close_error=close_error)
        wrapper = _FailingAsyncFailStreamWrapper(stream)

        with pytest.raises(RuntimeError, match="close failure"):
            await wrapper.close()
        # Clear the injected close error so the second close() succeeds;
        # the wrapper must not finalize again.
        stream._close_error = None
        await wrapper.close()

        assert wrapper.failures == [close_error]

    asyncio.run(exercise())


def test_async_stream_wrapper_swallows_stop_iteration_finalize_errors():
    async def exercise():
        wrapper = _FailingAsyncStopStreamWrapper(_FakeAsyncStream())

        # A _stop_stream failure must not replace StopAsyncIteration.
        with pytest.raises(StopAsyncIteration):
            await anext(wrapper)

    asyncio.run(exercise())


def test_async_stream_wrapper_preserves_stream_error_when_finalize_fails():
    async def exercise():
        error = RuntimeError("stream failure")
        wrapper = _FailingAsyncFailStreamWrapper(_FakeAsyncStream(error=error))

        # The stream's own error propagates even when _fail_stream raises.
        with pytest.raises(RuntimeError, match="stream failure"):
            await anext(wrapper)

    asyncio.run(exercise())


def test_async_stream_wrapper_getattr_passthrough():
    wrapper = _TestAsyncStreamWrapper(_FakeAsyncStream())

    assert wrapper.extra_attribute == "passthrough"


def test_async_stream_wrapper_stop_iteration_does_not_double_finalize():
    async def exercise():
        wrapper = _TestAsyncStreamWrapper(_FakeAsyncStream())

        with pytest.raises(StopAsyncIteration):
            await anext(wrapper)
        await wrapper.close()

        # Exhaustion already finalized; close() must not finalize again.
        assert wrapper.stop_count == 1
        assert not wrapper.failures

    asyncio.run(exercise())


def test_async_stream_wrapper_swallows_process_chunk_errors():
    async def exercise():
        wrapper = _FailingAsyncProcessStreamWrapper(
            _FakeAsyncStream(chunks=["chunk"])
        )

        # The chunk is still delivered; the instrumentation bug is swallowed.
        assert await anext(wrapper) == "chunk"
        assert not wrapper.failures

    asyncio.run(exercise())