diff --git a/langfuse/_client/observe.py b/langfuse/_client/observe.py index 64882a20f..cd541a1aa 100644 --- a/langfuse/_client/observe.py +++ b/langfuse/_client/observe.py @@ -632,7 +632,17 @@ def __next__(self) -> Any: class _ContextPreservedAsyncGeneratorWrapper: - """Async generator wrapper that ensures each iteration runs in preserved context.""" + """Async generator wrapper that ensures each iteration runs in preserved context. + + .. note:: + The wrapper snapshots the caller's contextvars once at construction time + and re-applies those values before each ``__anext__`` call. This means + mutations issued by the generator body *across* ``yield`` points are + discarded between iterations. For Langfuse's own tracing this is fine + because child spans are opened and closed within a single ``__anext__`` + call, but user generators that rely on context-var state persisting + across yields will not see those changes. + """ def __init__( self, @@ -694,19 +704,24 @@ async def aclose(self) -> None: if self._span_ended: return + # Apply preserved context to current task without creating a new + # task, so that asyncio.timeout / asyncio.current_task() bindings + # remain intact (see langfuse/langfuse#13349). + tokens: list[tuple[contextvars.ContextVar[Any], contextvars.Token[Any]]] = [] + for var in list(self.context.keys()): + val = self.context.get(var) + tokens.append((var, var.set(val))) try: try: - await asyncio.create_task( - self.generator.aclose(), - context=self.context, - ) # type: ignore - except TypeError: - await self.context.run(asyncio.create_task, self.generator.aclose()) - except (Exception, asyncio.CancelledError) as error: - self._finalize_with_error(error) - raise - else: - self._finalize() + await self.generator.aclose() + except (Exception, asyncio.CancelledError) as error: + self._finalize_with_error(error) + raise + else: + self._finalize() + finally: + for var, token in reversed(tokens): + var.reset(token) async def close(self) -> None: await self.aclose() @@ -716,19 +731,18 @@ def __del__(self) -> None: async def __anext__(self) -> Any: try: - # Run the generator's __anext__ in the preserved context + # Apply preserved context to current task without creating a new + # task, so that asyncio.timeout / asyncio.current_task() bindings + # remain intact (see langfuse/langfuse#13349). + tokens: list[tuple[contextvars.ContextVar[Any], contextvars.Token[Any]]] = [] + for var in list(self.context.keys()): + val = self.context.get(var) + tokens.append((var, var.set(val))) try: - # Python 3.11+ approach with explicit task context - item = await asyncio.create_task( - self.generator.__anext__(), # type: ignore - context=self.context, - ) # type: ignore - except TypeError: - # Python 3.10 fallback - create the task inside the preserved context. - item = await self.context.run( - asyncio.create_task, - self.generator.__anext__(), # type: ignore - ) + item = await self.generator.__anext__() # type: ignore + finally: + for var, token in reversed(tokens): + var.reset(token) if self.capture_output: self.items.append(item) diff --git a/tests/unit/test_observe.py b/tests/unit/test_observe.py index 5527be9b9..803f91f7b 100644 --- a/tests/unit/test_observe.py +++ b/tests/unit/test_observe.py @@ -271,27 +271,53 @@ async def generator() -> AsyncGenerator[str, None]: @pytest.mark.asyncio -async def test_async_generator_wrapper_fallback_preserves_context( - monkeypatch: pytest.MonkeyPatch, -) -> None: +async def test_async_generator_wrapper_preserves_asyncio_timeout() -> None: + """__anext__ must not create new tasks, otherwise asyncio.timeout + loses track of the task it should cancel (langfuse/langfuse#13349).""" marker = contextvars.ContextVar("marker", default="ambient") seen: list[str] = [] - original_create_task = asyncio.create_task - def create_task_with_type_error(*args: Any, **kwargs: Any) -> asyncio.Task[Any]: - if "context" in kwargs: - raise TypeError("context argument unsupported") + async def generator() -> AsyncGenerator[str, None]: + for i in range(3): + seen.append(marker.get()) + yield f"item_{i}" + await asyncio.sleep(0) - return original_create_task(*args, **kwargs) + span = SpanRecorder() + context = contextvars.copy_context() + context.run(marker.set, "preserved") + wrapper = _ContextPreservedAsyncGeneratorWrapper( + generator(), + context, + cast(Any, span), + False, + None, + ) - monkeypatch.setattr(asyncio, "create_task", create_task_with_type_error) + # asyncio.timeout should work — the generator completes before the deadline + async with asyncio.timeout(1.0): + items = [] + async for item in wrapper: + items.append(item) + + assert items == ["item_0", "item_1", "item_2"] + assert seen == ["preserved", "preserved", "preserved"] + assert span.ended == 1 + + +@pytest.mark.asyncio +async def test_async_generator_wrapper_respects_asyncio_timeout() -> None: + """asyncio.timeout must be able to cancel a hanging generator decorated + with @observe (langfuse/langfuse#13349). Without the fix, the timeout + becomes a no-op because each __anext__ creates a fresh task.""" + marker = contextvars.ContextVar("marker", default="ambient") + seen: list[str] = [] async def generator() -> AsyncGenerator[str, None]: - try: - yield marker.get() - yield "item_1" - finally: + for i in range(10): seen.append(marker.get()) + yield f"item_{i}" + await asyncio.sleep(0.1) span = SpanRecorder() context = contextvars.copy_context() @@ -304,12 +330,14 @@ async def generator() -> AsyncGenerator[str, None]: None, ) - assert await wrapper.__anext__() == "preserved" - marker.set("ambient-now") - - await wrapper.aclose() + with pytest.raises(asyncio.TimeoutError): + async with asyncio.timeout(0.05): + async for _item in wrapper: + pass - assert seen == ["preserved"] + # Should have yielded at least once before the timeout fired + assert len(seen) >= 1 + assert all(v == "preserved" for v in seen) assert span.ended == 1