Commit af25037

fix: do not post_process before finally in ModelOutputThunk.astream (#580)
* fix: do not post_process before finally in ModelOutputThunk.astream
* add test
* cleanup
* handle spans

Signed-off-by: Paul S. Schweigert <paul@paulschweigert.com>
1 parent dfc8942 commit af25037
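
The failure mode fixed here is plain Python semantics rather than anything mellea-specific: an exception raised while a `finally` block executes replaces the exception already in flight, so the caller sees the cleanup error instead of the original one. A minimal, standalone sketch of the old behavior (illustrative only, not code from this repo):

    async def broken_astream():
        try:
            raise ConnectionError("generation failed")  # the real backend error
        finally:
            # Before this fix, post_process ran here even after a failure.
            # It assumes success invariants, so it raises, and the caller now
            # catches RuntimeError instead of the original ConnectionError.
            raise RuntimeError("post_process failed")

The commit therefore removes the try/finally: the error path closes its telemetry span and re-raises the backend exception directly, and post_process runs only on the success path.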

2 files changed

Lines changed: 182 additions & 80 deletions

File tree

mellea/core/base.py

Lines changed: 76 additions & 80 deletions
@@ -282,88 +282,84 @@ async def astream(self) -> str:
             0 if self._underlying_value is None else len(str(self._underlying_value))
         )  # type: ignore
 
-        exception_to_raise = None
-        try:
-            # Type of the chunk depends on the backend.
-            chunks: list[Any | None] = []
-            while True:
-                try:
-                    item = self._async_queue.get_nowait()
-                    chunks.append(item)
-                except asyncio.QueueEmpty:
-                    # We've exhausted the current items in the queue.
-                    break
-
-            # Make sure we always get the minimum chunk size.
-            while len(chunks) <= self._chunk_size:
-                if len(chunks) > 0:
-                    if chunks[-1] is None or isinstance(chunks[-1], Exception):
-                        break  # Hit sentinel value or an error.
-                    # We could switch to relying on the `done` / `finish_reason` field of chunks,
-                    # but that forces us to know about the chunk type here. Prefer sentinel values
-                    # for now.
-
-                item = await self._async_queue.get()
+        # Type of the chunk depends on the backend.
+        chunks: list[Any | None] = []
+        while True:
+            try:
+                item = self._async_queue.get_nowait()
                 chunks.append(item)
-
-            # Process the sentinel value if it's there.
-            if chunks[-1] is None:
-                chunks.pop()  # Remove the sentinel value.
-                do_set_computed = True
-
-                # Shouldn't be needed, but cancel the Tasks this ModelOutputThunk relied on.
-                if self._generate is not None:
-                    self._generate.cancel()
-                if self._generate_extra is not None:
-                    # Covers an hf edge case. The task is done generating anything useful but isn't `done` yet.
-                    await self._generate_extra
-                    self._generate_extra.cancel()
-
-                # If ModelOutputThunks get too bulky, we can do additional cleanup here
-                # and set fields to None.
-
-            elif isinstance(chunks[-1], Exception):
-                # Mark as computed so post_process runs in finally block
-                self._computed = True
-                # Store exception to re-raise after cleanup
-                exception_to_raise = chunks[-1]
-
-            for chunk in chunks:
-                assert self._process is not None
-                await self._process(self, chunk)
-
-            if do_set_computed:
-                assert self._underlying_value is not None
-                self._computed = True
-        finally:
-            # Always call post_process if computed, even on exception
-            # This ensures telemetry spans are properly closed
-            if self._computed:
-                assert self._post_process is not None
-                await self._post_process(self)
-
-        # Only parse if no exception occurred
-        if exception_to_raise is None:
-            match self._action:
-                case Component():
-                    self.parsed_repr = self._action._parse(self)
-                case CBlock():
-                    assert self.value is not None, (
-                        "value must be non-None since this thunk is computed"
-                    )
-                    self.parsed_repr = self.value  # type: ignore
-                case _:
-                    raise ValueError(
-                        "attempted to astream from a model output thunk with no ._action set"
-                    )
-            assert self.parsed_repr is not None, (
-                "enforce constraint that a computed ModelOutputThunk has a non-None parsed_repr"
+            except asyncio.QueueEmpty:
+                # We've exhausted the current items in the queue.
+                break
+
+        # Make sure we always get the minimum chunk size.
+        while len(chunks) <= self._chunk_size:
+            if len(chunks) > 0:
+                if chunks[-1] is None or isinstance(chunks[-1], Exception):
+                    break  # Hit sentinel value or an error.
+                # We could switch to relying on the `done` / `finish_reason` field of chunks,
+                # but that forces us to know about the chunk type here. Prefer sentinel values
+                # for now.
+
+            item = await self._async_queue.get()
+            chunks.append(item)
+
+        # Process the sentinel value if it's there.
+        if chunks[-1] is None:
+            chunks.pop()  # Remove the sentinel value.
+            do_set_computed = True
+
+            # Shouldn't be needed, but cancel the Tasks this ModelOutputThunk relied on.
+            if self._generate is not None:
+                self._generate.cancel()
+            if self._generate_extra is not None:
+                # Covers an hf edge case. The task is done generating anything useful but isn't `done` yet.
+                await self._generate_extra
+                self._generate_extra.cancel()
+
+            # If ModelOutputThunks get too bulky, we can do additional cleanup here
+            # and set fields to None.
+
+        elif isinstance(chunks[-1], Exception):
+            # Close any open telemetry span before propagating the error.
+            # We can't call full post_process here (it assumes success invariants),
+            # but we must not leak the span.
+            span = self._meta.get("_telemetry_span")
+            if span is not None:
+                from ..telemetry import end_backend_span, set_span_error
+
+                set_span_error(span, chunks[-1])
+                end_backend_span(span)
+                del self._meta["_telemetry_span"]
+            raise chunks[-1]
+
+        for chunk in chunks:
+            assert self._process is not None
+            await self._process(self, chunk)
+
+        if do_set_computed:
+            assert self._underlying_value is not None
+            self._computed = True
+
+            assert self._post_process is not None
+            await self._post_process(self)
+
+            match self._action:
+                case Component():
+                    self.parsed_repr = self._action._parse(self)
+                case CBlock():
+                    assert self.value is not None, (
+                        "value must be non-None since this thunk is computed"
                     )
-            return self._underlying_value  # type: ignore
-
-        # Re-raise exception after cleanup if one occurred
-        if exception_to_raise is not None:
-            raise exception_to_raise
+                    self.parsed_repr = self.value  # type: ignore
+                case _:
+                    raise ValueError(
+                        "attempted to astream from a model output thunk with no ._action set"
+                    )
+            assert self.parsed_repr is not None, (
+                "enforce constraint that a computed ModelOutputThunk has a non-None parsed_repr"
+            )
+            return self._underlying_value  # type: ignore
 
         return (
            self._underlying_value
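
The new error branch calls `set_span_error` and `end_backend_span` from mellea's internal telemetry module; their implementations are not part of this diff. Judging from the mock assertions in the new test file below, they plausibly wrap the standard OpenTelemetry span API, roughly as in this sketch (an assumption, not the repo's actual code):

    from opentelemetry.trace import Span, Status, StatusCode

    def set_span_error(span: Span, exc: BaseException) -> None:
        # Attach the exception as a span event and mark the span as failed.
        span.record_exception(exc)
        span.set_status(Status(StatusCode.ERROR, str(exc)))

    def end_backend_span(span: Span) -> None:
        # End the span so it can be exported; a span left open leaks telemetry.
        span.end()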
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+"""Tests that exceptions during generation propagate correctly through ModelOutputThunk.astream().
+
+Regression test for issue #577: post_process in a finally block was swallowing
+the original generation exception by raising a secondary error from post_process
+(which assumes system invariants that don't hold during failures).
+"""
+
+import pytest
+
+from mellea.core.base import CBlock, GenerateType, ModelOutputThunk
+
+
+async def _noop_process(mot, chunk):
+    if mot._underlying_value is None:
+        mot._underlying_value = ""
+    mot._underlying_value += str(chunk)
+
+
+async def _failing_post_process(mot):
+    raise RuntimeError("post_process failed due to broken invariants")
+
+
+def _make_thunk(post_process=_failing_post_process):
+    mot = ModelOutputThunk(value=None)
+    mot._generate_type = GenerateType.ASYNC
+    mot._process = _noop_process
+    mot._post_process = post_process
+    mot._action = CBlock("test")
+    mot._chunk_size = 0
+    return mot
+
+
+@pytest.mark.parametrize(
+    "error",
+    [ValueError("connection reset by peer"), ConnectionError("server unavailable")],
+)
+async def test_astream_propagates_generation_exception(error):
+    """The original generation error must propagate, not a secondary error from post_process."""
+    mot = _make_thunk()
+    await mot._async_queue.put(error)
+
+    with pytest.raises(type(error), match=str(error)):
+        await mot.astream()
+
+
+async def test_astream_post_process_only_called_on_success():
+    """post_process must be called on success but not on error."""
+    post_process_called = False
+
+    async def _tracking_post_process(mot):
+        nonlocal post_process_called
+        post_process_called = True
+
+    # Error path: post_process should NOT be called
+    mot = _make_thunk(post_process=_tracking_post_process)
+    await mot._async_queue.put(RuntimeError("generation failed"))
+
+    with pytest.raises(RuntimeError, match="generation failed"):
+        await mot.astream()
+
+    assert not post_process_called, (
+        "post_process should not be called when generation fails"
+    )
+
+    # Success path: post_process SHOULD be called
+    post_process_called = False
+    mot = _make_thunk(post_process=_tracking_post_process)
+    await mot._async_queue.put("hello")
+    await mot._async_queue.put(None)  # sentinel for completion
+
+    await mot.astream()
+
+    assert post_process_called, "post_process should be called on successful completion"
+
+
+async def test_astream_closes_telemetry_span_on_error():
+    """Telemetry span must be ended and error recorded when generation fails."""
+    from unittest.mock import MagicMock
+
+    mock_span = MagicMock()
+    mot = _make_thunk()
+    mot._meta["_telemetry_span"] = mock_span
+
+    error = ConnectionError("server unavailable")
+    await mot._async_queue.put(error)
+
+    with pytest.raises(ConnectionError, match="server unavailable"):
+        await mot.astream()
+
+    # Span should have been ended and cleaned up
+    mock_span.record_exception.assert_called_once_with(error)
+    mock_span.set_status.assert_called_once()
+    mock_span.end.assert_called_once()
+    assert "_telemetry_span" not in mot._meta
+
+
+async def test_astream_no_span_leak_when_no_telemetry():
+    """When no telemetry span is present, error propagation still works."""
+    mot = _make_thunk()
+    assert "_telemetry_span" not in mot._meta
+
+    error = ValueError("test error")
+    await mot._async_queue.put(error)
+
+    with pytest.raises(ValueError, match="test error"):
+        await mot.astream()
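
These tests drive `astream()` through the same queue contract the streaming backends use: chunks are put on `_async_queue` as they arrive, a `None` sentinel marks successful completion, and putting the exception object itself signals failure. A compact illustration of the producer side (hypothetical backend, names invented for the example):

    import asyncio

    async def fake_backend(queue: asyncio.Queue) -> None:
        try:
            for chunk in ("hel", "lo"):  # pretend these stream from a model
                await queue.put(chunk)
            await queue.put(None)  # sentinel: generation finished cleanly
        except Exception as exc:
            await queue.put(exc)  # astream() re-raises this in the consumer

Note that the tests are bare `async def` functions, which assumes the repository configures an auto async mode for pytest (e.g. pytest-asyncio's `asyncio_mode = "auto"`).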
