Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 38 additions & 24 deletions langfuse/_client/observe.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,17 @@ def __next__(self) -> Any:


class _ContextPreservedAsyncGeneratorWrapper:
"""Async generator wrapper that ensures each iteration runs in preserved context."""
"""Async generator wrapper that ensures each iteration runs in preserved context.

.. note::
The wrapper snapshots the caller's contextvars once at construction time
and re-applies those values before each ``__anext__`` call. This means
mutations issued by the generator body *across* ``yield`` points are
discarded between iterations. For Langfuse's own tracing this is fine
because child spans are opened and closed within a single ``__anext__``
call, but user generators that rely on context-var state persisting
across yields will not see those changes.
"""

def __init__(
self,
Expand Down Expand Up @@ -694,19 +704,24 @@ async def aclose(self) -> None:
if self._span_ended:
return

# Apply preserved context to current task without creating a new
# task, so that asyncio.timeout / asyncio.current_task() bindings
# remain intact (see langfuse/langfuse#13349).
tokens: list[tuple[contextvars.ContextVar[Any], contextvars.Token[Any]]] = []
for var in list(self.context.keys()):
val = self.context.get(var)
tokens.append((var, var.set(val)))
try:
try:
await asyncio.create_task(
self.generator.aclose(),
context=self.context,
) # type: ignore
except TypeError:
await self.context.run(asyncio.create_task, self.generator.aclose())
except (Exception, asyncio.CancelledError) as error:
self._finalize_with_error(error)
raise
else:
self._finalize()
await self.generator.aclose()
except (Exception, asyncio.CancelledError) as error:
self._finalize_with_error(error)
raise
else:
self._finalize()
finally:
for var, token in reversed(tokens):
var.reset(token)

async def close(self) -> None:
await self.aclose()
Expand All @@ -716,19 +731,18 @@ def __del__(self) -> None:

async def __anext__(self) -> Any:
try:
# Run the generator's __anext__ in the preserved context
# Apply preserved context to current task without creating a new
# task, so that asyncio.timeout / asyncio.current_task() bindings
# remain intact (see langfuse/langfuse#13349).
tokens: list[tuple[contextvars.ContextVar[Any], contextvars.Token[Any]]] = []
for var in list(self.context.keys()):
val = self.context.get(var)
tokens.append((var, var.set(val)))
try:
# Python 3.11+ approach with explicit task context
item = await asyncio.create_task(
self.generator.__anext__(), # type: ignore
context=self.context,
) # type: ignore
except TypeError:
# Python 3.10 fallback - create the task inside the preserved context.
item = await self.context.run(
asyncio.create_task,
self.generator.__anext__(), # type: ignore
)
item = await self.generator.__anext__() # type: ignore
finally:
for var, token in reversed(tokens):
var.reset(token)
Comment thread
goingforstudying-ctrl marked this conversation as resolved.

if self.capture_output:
self.items.append(item)
Expand Down
64 changes: 46 additions & 18 deletions tests/unit/test_observe.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,27 +271,53 @@ async def generator() -> AsyncGenerator[str, None]:


@pytest.mark.asyncio
async def test_async_generator_wrapper_fallback_preserves_context(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def test_async_generator_wrapper_preserves_asyncio_timeout() -> None:
"""__anext__ must not create new tasks, otherwise asyncio.timeout
loses track of the task it should cancel (langfuse/langfuse#13349)."""
marker = contextvars.ContextVar("marker", default="ambient")
seen: list[str] = []
original_create_task = asyncio.create_task

def create_task_with_type_error(*args: Any, **kwargs: Any) -> asyncio.Task[Any]:
if "context" in kwargs:
raise TypeError("context argument unsupported")
async def generator() -> AsyncGenerator[str, None]:
for i in range(3):
seen.append(marker.get())
yield f"item_{i}"
await asyncio.sleep(0)

return original_create_task(*args, **kwargs)
span = SpanRecorder()
context = contextvars.copy_context()
context.run(marker.set, "preserved")
wrapper = _ContextPreservedAsyncGeneratorWrapper(
generator(),
context,
cast(Any, span),
False,
None,
)

monkeypatch.setattr(asyncio, "create_task", create_task_with_type_error)
# asyncio.timeout should work — the generator completes before the deadline
async with asyncio.timeout(1.0):
items = []
async for item in wrapper:
items.append(item)

assert items == ["item_0", "item_1", "item_2"]
assert seen == ["preserved", "preserved", "preserved"]
assert span.ended == 1


@pytest.mark.asyncio
async def test_async_generator_wrapper_respects_asyncio_timeout() -> None:
"""asyncio.timeout must be able to cancel a hanging generator decorated
with @observe (langfuse/langfuse#13349). Without the fix, the timeout
becomes a no-op because each __anext__ creates a fresh task."""
marker = contextvars.ContextVar("marker", default="ambient")
seen: list[str] = []

async def generator() -> AsyncGenerator[str, None]:
try:
yield marker.get()
yield "item_1"
finally:
for i in range(10):
seen.append(marker.get())
yield f"item_{i}"
await asyncio.sleep(0.1)

span = SpanRecorder()
context = contextvars.copy_context()
Expand All @@ -304,12 +330,14 @@ async def generator() -> AsyncGenerator[str, None]:
None,
)

assert await wrapper.__anext__() == "preserved"
marker.set("ambient-now")

await wrapper.aclose()
with pytest.raises(asyncio.TimeoutError):
async with asyncio.timeout(0.05):
async for _item in wrapper:
pass

assert seen == ["preserved"]
# Should have yielded at least once before the timeout fired
assert len(seen) >= 1
assert all(v == "preserved" for v in seen)
Comment thread
goingforstudying-ctrl marked this conversation as resolved.
assert span.ended == 1


Expand Down