Skip to content

Commit 151b892

Browse files
committed
fix: prevent exception chunk from being passed to _process in astream
When a backend error (e.g. ollama.ResponseError) is propagated through the async queue, astream() stored the exception but left it in the chunks list. The subsequent _process() loop then tried to process the exception object as a real chunk, causing an AttributeError that masked the original backend error. Two fixes: 1. base.py: pop() the exception from chunks (like we already do for the None sentinel) so _process never receives it. 2. ollama.py: use .get() instead of [] for chat_response in post_processing, since the key may not exist if no valid chunks were processed before the error. Reproduces as: KeyError: 'chat_response' when an Ollama model returns a ResponseError (e.g. timeout on a large model). The post_processing runs in the finally block after the AttributeError, finds no chat_response key, and raises KeyError — masking the real error. Signed-off-by: 0xCUB3 <skula@mit.edu>
1 parent 6f3e131 commit 151b892

3 files changed

Lines changed: 142 additions & 3 deletions

File tree

mellea/backends/ollama.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ async def post_processing(
596596
generate_log.backend = f"ollama::{self._get_ollama_model_id()}"
597597
generate_log.model_options = mot._model_options
598598
generate_log.date = datetime.datetime.now()
599-
generate_log.model_output = mot._meta["chat_response"]
599+
generate_log.model_output = mot._meta.get("chat_response")
600600
generate_log.extra = {
601601
"format": _format,
602602
"thinking": mot._model_options.get(ModelOption.THINKING, None),

mellea/core/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,8 @@ async def astream(self) -> str:
325325
elif isinstance(chunks[-1], Exception):
326326
# Mark as computed so post_process runs in finally block
327327
self._computed = True
328-
# Store exception to re-raise after cleanup
329-
exception_to_raise = chunks[-1]
328+
# Remove the exception from chunks so _process doesn't receive it
329+
exception_to_raise = chunks.pop()
330330

331331
for chunk in chunks:
332332
assert self._process is not None
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Tests for ModelOutputThunk.astream() exception handling.
2+
3+
Verifies that backend exceptions propagated through the async queue are
4+
handled correctly: the exception should not be fed into _process as a chunk,
5+
and post_processing should not crash if it runs during exception cleanup.
6+
"""
7+
8+
import asyncio
9+
import functools
10+
11+
import pytest
12+
13+
from mellea.core.base import GenerateType, ModelOutputThunk
14+
15+
16+
async def _noop_post_process(mot: ModelOutputThunk) -> None:
    """Default post_process hook: deliberately leaves the thunk (and _meta) untouched."""
19+
20+
async def _tracking_process(
    mot: ModelOutputThunk, chunk: object, *, calls: list
) -> None:
    """Record *chunk* into *calls* and fold its string form into the thunk value."""
    calls.append(chunk)
    # Lazily start from the empty string, then concatenate — same effect as
    # initializing on first chunk and using `+=`.
    prior = mot._underlying_value
    mot._underlying_value = ("" if prior is None else prior) + str(chunk)
28+
29+
30+
def _make_mot(
    *, process_calls: list | None = None, post_process=None
) -> ModelOutputThunk:
    """Build a ModelOutputThunk configured for async queue-based streaming.

    Args:
        process_calls: When given, the thunk's ``_process`` callback records
            every chunk it receives into this list (while also accumulating
            string content).
        post_process: Optional ``_post_process`` callback; a no-op is used
            when omitted.

    Returns:
        A thunk with ``_generate_type`` set to ASYNC and both streaming hooks
        wired up, ready to consume from its async queue.
    """
    thunk = ModelOutputThunk(None)
    thunk._generate_type = GenerateType.ASYNC

    if process_calls is None:
        # Plain accumulator: concatenate the string form of each chunk.
        async def _accumulate(m: ModelOutputThunk, chunk: object) -> None:
            prior = m._underlying_value
            m._underlying_value = ("" if prior is None else prior) + str(chunk)

        thunk._process = _accumulate
    else:
        thunk._process = functools.partial(_tracking_process, calls=process_calls)

    thunk._post_process = post_process or _noop_post_process
    return thunk
59+
60+
61+
async def test_exception_not_passed_to_process():
    """An exception pulled off the queue must never reach _process as a chunk."""
    calls: list = []
    mot = _make_mot(process_calls=calls)

    await mot._async_queue.put(RuntimeError("backend failed"))

    with pytest.raises(RuntimeError, match="backend failed"):
        await mot.astream()

    # astream() pops the exception before the processing loop runs, so the
    # tracking callback never fires.
    assert not calls, (
        f"_process should not have been called with an exception, got {calls}"
    )
75+
76+
77+
async def test_valid_chunks_before_exception_are_processed():
    """Chunks queued ahead of the exception must still be fed to _process."""
    calls: list = []
    mot = _make_mot(process_calls=calls)

    # Queue order: two good chunks, then the failure.
    for item in ("chunk1", "chunk2", ValueError("oops")):
        await mot._async_queue.put(item)

    with pytest.raises(ValueError, match="oops"):
        await mot.astream()

    assert calls == ["chunk1", "chunk2"]
91+
92+
93+
async def test_post_process_runs_on_exception():
    """Exception cleanup must still invoke post_process (telemetry depends on it)."""
    # List-as-sentinel instead of a nonlocal flag: truthy iff the hook ran.
    invocations: list = []

    async def _tracking_post(mot: ModelOutputThunk) -> None:
        invocations.append(True)

    mot = _make_mot(post_process=_tracking_post)
    await mot._async_queue.put(RuntimeError("backend failed"))

    with pytest.raises(RuntimeError, match="backend failed"):
        await mot.astream()

    assert invocations, "post_process should run even when an exception occurs"
108+
109+
110+
async def test_post_process_with_missing_meta_key():
    """post_process accessing a missing _meta key should not mask the real exception.

    Reproduces the KeyError: 'chat_response' scenario: post_processing reads a
    key that was never populated because no valid chunk was processed before
    the backend error arrived.
    """

    async def _fragile_post(mot: ModelOutputThunk) -> None:
        # Mirrors the fixed ollama post_processing: .get() instead of [],
        # so the absent key yields None rather than raising.
        _ = mot._meta.get("chat_response")

    mot = _make_mot(post_process=_fragile_post)
    await mot._async_queue.put(RuntimeError("backend failed"))

    # The backend error — not a KeyError — is what must propagate.
    with pytest.raises(RuntimeError, match="backend failed"):
        await mot.astream()
129+
130+
131+
async def test_exception_only_queue_marks_computed():
    """Even an exception-only stream must leave the thunk flagged as computed."""
    mot = _make_mot()
    await mot._async_queue.put(RuntimeError("backend failed"))

    with pytest.raises(RuntimeError, match="backend failed"):
        await mot.astream()

    assert mot._computed, "Thunk should be marked computed after exception"

0 commit comments

Comments
 (0)