"""Tests for ModelOutputThunk.astream() exception handling.

Verifies that backend exceptions propagated through the async queue are
handled correctly: the exception should not be fed into _process as a chunk,
and post_processing should not crash if it runs during exception cleanup.
"""

import asyncio
import functools

import pytest

from mellea.core.base import GenerateType, ModelOutputThunk
| 15 | + |
async def _noop_post_process(mot: ModelOutputThunk) -> None:
    """Do nothing; stands in for a post_process that never touches _meta."""
    return None
| 18 | + |
| 19 | + |
async def _tracking_process(
    mot: ModelOutputThunk, chunk: object, *, calls: list
) -> None:
    """Record each received chunk in *calls* and fold its text into the thunk."""
    calls.append(chunk)
    current = mot._underlying_value
    # Treat an unset value as the empty string before concatenating.
    mot._underlying_value = ("" if current is None else current) + str(chunk)
| 28 | + |
| 29 | + |
def _make_mot(
    *, process_calls: list | None = None, post_process=None
) -> ModelOutputThunk:
    """Build a ModelOutputThunk wired for async queue consumption.

    Args:
        process_calls: If provided, attaches a tracking _process that appends
            each chunk to this list.
        post_process: Custom post_process callback. Defaults to a noop.

    Returns:
        A ModelOutputThunk ready for queue-based streaming.
    """

    async def _accumulate(m: ModelOutputThunk, chunk: object) -> None:
        # Fallback process: concatenate the string form of every chunk.
        existing = m._underlying_value
        m._underlying_value = ("" if existing is None else existing) + str(chunk)

    mot = ModelOutputThunk(None)
    mot._generate_type = GenerateType.ASYNC
    mot._process = (
        functools.partial(_tracking_process, calls=process_calls)
        if process_calls is not None
        else _accumulate
    )
    mot._post_process = _noop_post_process if post_process is None else post_process
    return mot
| 59 | + |
| 60 | + |
async def test_exception_not_passed_to_process():
    """An exception pulled from the queue must never reach _process as a chunk."""
    calls: list = []
    mot = _make_mot(process_calls=calls)

    await mot._async_queue.put(RuntimeError("backend failed"))

    with pytest.raises(RuntimeError, match="backend failed"):
        await mot.astream()

    # The exception was consumed from the queue; _process saw no chunks.
    assert calls == [], (
        f"_process should not have been called with an exception, got {calls}"
    )
| 75 | + |
| 76 | + |
async def test_valid_chunks_before_exception_are_processed():
    """Valid chunks before the exception should still be processed."""
    calls: list = []
    mot = _make_mot(process_calls=calls)

    # Enqueue two good chunks followed by a failure.
    for item in ("chunk1", "chunk2", ValueError("oops")):
        await mot._async_queue.put(item)

    with pytest.raises(ValueError, match="oops"):
        await mot.astream()

    assert calls == ["chunk1", "chunk2"]
| 91 | + |
| 92 | + |
async def test_post_process_runs_on_exception():
    """post_process should still run (for telemetry cleanup) when an exception occurs."""
    ran = {"post_process": False}

    async def _tracking_post(mot: ModelOutputThunk) -> None:
        # Flag stored in a dict so the closure can mutate it without nonlocal.
        ran["post_process"] = True

    mot = _make_mot(post_process=_tracking_post)
    await mot._async_queue.put(RuntimeError("backend failed"))

    with pytest.raises(RuntimeError, match="backend failed"):
        await mot.astream()

    assert ran["post_process"], "post_process should run even when an exception occurs"
| 108 | + |
| 109 | + |
async def test_post_process_with_missing_meta_key():
    """post_process accessing a missing _meta key should not mask the real exception.

    This reproduces the KeyError: 'chat_response' scenario where
    post_processing tries to read a key that was never set because
    no valid chunks were processed.
    """

    async def _fragile_post(mot: ModelOutputThunk) -> None:
        # Mirrors the ollama post_processing fix: .get() rather than []
        # indexing, so a key that was never populated does not raise.
        _ = mot._meta.get("chat_response")

    mot = _make_mot(post_process=_fragile_post)
    await mot._async_queue.put(RuntimeError("backend failed"))

    # The original backend error should propagate, not a KeyError.
    with pytest.raises(RuntimeError, match="backend failed"):
        await mot.astream()
| 129 | + |
| 130 | + |
async def test_exception_only_queue_marks_computed():
    """A queue holding nothing but an exception must still mark the thunk computed."""
    mot = _make_mot()
    await mot._async_queue.put(RuntimeError("backend failed"))

    with pytest.raises(RuntimeError, match="backend failed"):
        await mot.astream()

    assert mot._computed, "Thunk should be marked computed after exception"