From 8f5ca4ee12426acb40b2d9778c8d1305e72c06a7 Mon Sep 17 00:00:00 2001 From: Anthony Casagrande Date: Tue, 30 Jun 2026 10:29:37 -0700 Subject: [PATCH 1/2] feat(weka): guarantee trailing-user turns via two-pass block planning Reconstructed weka turns could end with a trailing assistant segment when a block-aligned context pull-back truncated onto an assistant block (376/98,827 turns on the 062126 corpus). A chat request must end with a user message. Add compute_asst_block_caps: a role-independent forward pass that finds every degenerate pull-back's truncation target and caps the assistant block count of the turn that created that block, so the boundary lands on a user block. The boundary is fixed at creation and never relabeled, preserving cross-turn KV-cache reuse. Wire the cap into all five reconstruction loops (serial parent/child/flat-chain + parallel parent/child). Extract compute_turn_block_geometry as the single source of truth advance_turn and the planner share; the planner mutates its block tile in place to stay O(new) per turn. Drives trailing-assistant turns to 0 across the full 393-trace corpus with no crashes; byte-exact sum(seg.tokens)==in_tokens preserved. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Anthony Casagrande --- .../dataset/loader/weka_parallel_convert.py | 11 + src/aiperf/dataset/loader/weka_synth_buf.py | 237 +++++++++++- src/aiperf/dataset/loader/weka_trace.py | 17 + .../dataset/loader/test_weka_synth_buf.py | 366 ++++++++++++++++-- .../loader/test_weka_synth_buf_turn_delta.py | 48 ++- 5 files changed, 620 insertions(+), 59 deletions(-) diff --git a/src/aiperf/dataset/loader/weka_parallel_convert.py b/src/aiperf/dataset/loader/weka_parallel_convert.py index 569237ceb..95bda3257 100644 --- a/src/aiperf/dataset/loader/weka_parallel_convert.py +++ b/src/aiperf/dataset/loader/weka_parallel_convert.py @@ -357,6 +357,7 @@ def _process_task(task: _WekaTraceTask) -> _WekaProcessTaskResult: assert _worker_state is not None from aiperf.dataset.loader.weka_synth_buf import ( ConversationReconstructor, + compute_asst_block_caps, ) state = _worker_state @@ -382,6 +383,10 @@ def _process_task(task: _WekaTraceTask) -> _WekaProcessTaskResult: parent_turns: list[_WekaParentTurnDict] = [] outer_to_turn_pos: dict[int, int] = {} normals: list[tuple[int, _WekaNormalRequestPayload]] = parent["normals"] + parent_asst_block_caps = compute_asst_block_caps( + [(r["hash_ids"], r["input_length"]) for _, r in normals], + bs, + ) for k, (outer_idx, req) in enumerate(normals): seed = f"{task.trace_id}:turn_{k}:partial_tail" is_tool_result = req.get("input_kind") == "tool_result" @@ -404,6 +409,7 @@ def _process_task(task: _WekaTraceTask) -> _WekaProcessTaskResult: curr_in_tokens=req["input_length"], seed=seed, is_tool_result=is_tool_result, + max_asst_blocks=parent_asst_block_caps[k], ) if "effective_t" in req: @@ -617,6 +623,10 @@ def _process_task(task: _WekaTraceTask) -> _WekaProcessTaskResult: child_turns: list[_WekaParentTurnDict] = [] creqs: list[_WekaNormalRequestPayload] = cp["requests"] + child_asst_block_caps = compute_asst_block_caps( + [(r["hash_ids"], r["input_length"]) for r in creqs], + bs, + ) for k, creq in enumerate(creqs): seed = f"{cp['session_id']}:turn_{k}:partial_tail" is_tool_result = creq.get("input_kind") == "tool_result" @@ -639,6 +649,7 @@ def _process_task(task: _WekaTraceTask) -> _WekaProcessTaskResult: curr_in_tokens=creq["input_length"], seed=seed, is_tool_result=is_tool_result, + max_asst_blocks=child_asst_block_caps[k], ) if "effective_t" in creq: t_ms = creq["effective_t"] * 1000.0 diff --git a/src/aiperf/dataset/loader/weka_synth_buf.py b/src/aiperf/dataset/loader/weka_synth_buf.py index ebe6a97b5..f9afe0cc7 100644 --- a/src/aiperf/dataset/loader/weka_synth_buf.py +++ b/src/aiperf/dataset/loader/weka_synth_buf.py @@ -14,6 +14,8 @@ from dataclasses import dataclass, field from typing import Literal +from aiperf.common.aiperf_logger import AIPerfLogger + # Re-exported for backwards compatibility; the composer lives in its own # module since it is independent of the synthesis-buffer state machine. from aiperf.dataset.loader.weka_prompt_compose import ( # noqa: E402 @@ -24,6 +26,8 @@ tool_shape_segment_messages, ) +_logger = AIPerfLogger(__name__) + @dataclass class TurnDelta: @@ -116,6 +120,17 @@ class ConversationReconstructor: _emitted_segment_count: int = 0 _last_disturbance_at: int | None = None _turn_index: int = 0 + _trailing_non_user_turns: list[int] = field(default_factory=list) + """Turn indices whose segment list could NOT be made to end with a user + block. The wire invariant (every turn ends with a user message) is + enforced for every turn that contributes any new content, but a turn that + appends zero new tokens (a fully block-aligned pull-back/compaction: + ``new_blocks_count == 0`` with no partial/missing tail) has nothing to + relabel as user, and a turn-0 prompt entirely consumed by the tool/system + prefix has only the cached system segment. Those degenerate shapes leave a + trailing non-user segment; they are recorded here (and warned once) rather + than faked, since synthesizing tokens would break the byte-exact + ``sum(seg.tokens) == in_tokens`` invariant.""" def init_turn_0( self, @@ -241,6 +256,7 @@ def init_turn_0( self._segments = segs self._emitted_segment_count = 0 self._last_disturbance_at = None + self._assert_trailing_user() def advance_turn( self, @@ -251,6 +267,8 @@ def advance_turn( curr_in_tokens: int, seed: str, is_tool_result: bool = False, + *, + max_asst_blocks: int | None = None, ) -> None: """Advance synth_buf to turn k via LCP-driven symmetric attribution. @@ -276,25 +294,41 @@ def advance_turn( hash-content invariant — every cached block emits its full content, unmodified by any terminator stamp. + Wire invariant: the segment list always ends with a user segment. A + chat request cannot ask the model to continue from a trailing + assistant message, so when the assistant target would consume the + entire new region and no partial/missing tail remains to seed a + trailing user segment, the final new block is handed back to the user + (relabeled, not resized — byte-exactness holds). The only turns that + cannot satisfy this are recorded on ``_trailing_non_user_turns`` (see + :meth:`_assert_trailing_user`): a turn that appends zero new tokens has + nothing to relabel. + ``prev_in_tokens`` is retained for the recorded-request contract (mirroring ``prev_hash_ids`` / ``prev_out_tokens``) but no longer feeds the truncation: the buffer self-describes its trailing overhang, which is also correct when the boundary segment is not the trailing one. + + ``max_asst_blocks`` is the optional two-pass planning cap from + :func:`compute_asst_block_caps`: an upper bound on this turn's assistant + block count chosen so a future block-aligned pull-back truncates onto a + user block instead of an assistant one. ``None`` means no cap. It only + ever lowers ``asst_blocks`` (an extra ``min``), so it composes with the + context-loss rule and the trailing-user reservation and never breaks the + byte-exact ``sum(seg.tokens) == curr_in_tokens`` invariant. """ bs = self.block_size - m_curr = len(curr_hash_ids) - m_curr_full = curr_in_tokens // bs - # Mirror init_turn_0's ``covered_blocks = min(m_full, len(hash_ids))``: - # a hashed-but-partial last block (len(curr_hash_ids) > curr_in_tokens - # // bs, e.g. in=250 with hash_ids=[..,4] at bs=64) covers only its - # ``curr_in_tokens % bs`` partial tail, not a full ``bs``-token block. - # Decoding it as a full block AND appending the partial tail below - # double-counts ~bs tokens and breaks the sum(seg.tokens) == - # curr_in_tokens invariant the byte-exact contract depends on. - m_curr_covered = min(m_curr, m_curr_full) - missing_block_tokens = max(0, (m_curr_full - m_curr) * bs) - lcp = longest_common_prefix(prev_hash_ids, curr_hash_ids) + # Single source of truth for this turn's block accounting (covered + # blocks, LCP, new-region size, synth tail), shared with the Pass-1 + # planner ``compute_asst_block_caps`` so the two never drift. + geo = compute_turn_block_geometry( + prev_hash_ids, curr_hash_ids, curr_in_tokens, bs + ) + lcp = geo.lcp + m_curr_covered = geo.m_curr_covered + synth_tail_n = geo.synth_tail_n + new_blocks_count = geo.new_blocks_count truncate_disturbance = truncate_synth_buf_at_block( self._segments, @@ -305,14 +339,11 @@ def advance_turn( self._last_disturbance_at = truncate_disturbance new_blocks = curr_hash_ids[lcp:m_curr_covered] - new_partial_tail_n = curr_in_tokens % bs new_region_tokens = self.decode_block_tokens(new_blocks) - synth_tail_n = missing_block_tokens + new_partial_tail_n if synth_tail_n > 0: new_region_tokens.extend( self.sample_partial_tail_tokens(synth_tail_n, seed) ) - new_blocks_count = max(0, m_curr_covered - lcp) self._turn_index += 1 asst_blocks_target = ( @@ -325,6 +356,25 @@ def advance_turn( # any user input. The whole new region becomes user content. asst_blocks_target = 0 asst_blocks = min(asst_blocks_target, new_blocks_count) + if max_asst_blocks is not None: + # Two-pass planning cap (see compute_asst_block_caps): a future + # block-aligned pull-back will truncate exactly to a boundary inside + # this turn's region. Shrinking the assistant here so that boundary + # block lands in THIS turn's user segment makes that future turn end + # with a user block — decided up front and stable across every turn, + # so the rendered prefix never diverges and KV-cache reuse holds. + asst_blocks = min(asst_blocks, max_asst_blocks) + if synth_tail_n == 0 and asst_blocks == new_blocks_count and asst_blocks > 0: + # Wire invariant: every turn must end with a user message — a chat + # request cannot ask the model to continue from a trailing + # assistant message. When the assistant would consume the entire + # new region and no partial/missing tail remains to seed a trailing + # user segment, hand the final new block back to the user. This + # relabels block content (role only), so the byte-exact + # sum(seg.tokens) == curr_in_tokens invariant is preserved. When + # the region is a single block the assistant segment vanishes and + # the turn becomes user-only. + asst_blocks -= 1 asst_emit_size = asst_blocks * bs cursor = lcp @@ -361,6 +411,34 @@ def advance_turn( ) ) + self._assert_trailing_user() + + def _assert_trailing_user(self) -> None: + """Record (and warn once) when the segment list does not end with a user. + + The wire invariant — every turn ends with a user message — is enforced + for any turn that contributes new content. The only shapes that cannot + satisfy it without faking tokens are recorded on + ``_trailing_non_user_turns``: a turn appending zero new tokens (a fully + block-aligned pull-back with ``new_blocks_count == 0`` and no tail) has + nothing to relabel as user, and a turn-0 prompt entirely consumed by + the cached tool/system prefix has only a system segment. Synthesizing a + user block in either case would violate the byte-exact + ``sum(seg.tokens) == in_tokens`` invariant, so the caveat is surfaced + rather than hidden. + """ + if not self._segments or self._segments[-1].role == "user": + return + self._trailing_non_user_turns.append(self._turn_index) + trailing_role = self._segments[-1].role + _logger.warning( + f"weka reconstructor: turn {self._turn_index} ends with a " + f"'{trailing_role}' segment, not 'user'. The turn added no new " + f"content to relabel as a trailing user block (block-aligned " + f"pull-back or system-only turn-0); synthesizing one would break " + f"the byte-exact ISL invariant, so it is left as-is." + ) + def turn_delta(self) -> TurnDelta: """Compute the raw_messages to emit for the just-completed turn. @@ -435,6 +513,135 @@ def longest_common_prefix(prev_hash_ids: list[int], curr_hash_ids: list[int]) -> return n +@dataclass(frozen=True) +class TurnBlockGeometry: + """Role-independent block accounting for one turn transition. + + Single source of truth shared by + :meth:`ConversationReconstructor.advance_turn` (which truncates to ``lcp`` + then appends the new region) and :func:`compute_asst_block_caps` (which + replays the same tile to plan assistant caps). Depends only on hash ids and + token counts — never on role labeling — so the two callers cannot drift. + """ + + lcp: int + """Longest common prefix (in blocks) of prev/curr hash ids — the block the + buffer truncates back to before appending this turn's new region.""" + m_curr_covered: int + """Covered block count ``min(len(curr_hash_ids), curr_in_tokens // bs)``. A + partial last hashed block (``len > in // bs``) contributes only its tail, + not a full block.""" + new_blocks_count: int + """Full blocks appended after truncating to ``lcp``: ``max(0, m_curr_covered - lcp)``.""" + synth_tail_n: int + """Synthesized trailing tokens: the partial tail (``in % bs``) plus any + missing-block region when the recording stored fewer hash blocks than + ``in // bs``.""" + + +def compute_turn_block_geometry( + prev_hash_ids: list[int], + curr_hash_ids: list[int], + curr_in_tokens: int, + block_size: int, +) -> TurnBlockGeometry: + """Compute the role-independent block geometry for one turn transition. + + A hashed-but-partial last block (``len(curr_hash_ids) > curr_in_tokens // + bs``, e.g. ``in=250`` with ``hash_ids=[..,4]`` at ``bs=64``) covers only its + ``curr_in_tokens % bs`` partial tail, not a full ``bs``-token block — + counting it as a full block AND appending the tail would double-count ~``bs`` + tokens and break the ``sum(seg.tokens) == curr_in_tokens`` invariant. + """ + bs = block_size + m_curr = len(curr_hash_ids) + m_curr_full = curr_in_tokens // bs + m_curr_covered = min(m_curr, m_curr_full) + missing_block_tokens = max(0, (m_curr_full - m_curr) * bs) + lcp = longest_common_prefix(prev_hash_ids, curr_hash_ids) + return TurnBlockGeometry( + lcp=lcp, + m_curr_covered=m_curr_covered, + new_blocks_count=max(0, m_curr_covered - lcp), + synth_tail_n=curr_in_tokens % bs + missing_block_tokens, + ) + + +def compute_asst_block_caps( + turns: list[tuple[list[int], int]], + block_size: int, +) -> list[int | None]: + """Pass-1 planner: per-turn upper bounds on assistant block count. + + ``turns`` is one ``(hash_ids, in_tokens)`` pair per turn, in order. + + Returns ``caps`` where ``caps[k]`` bounds the assistant blocks + :meth:`ConversationReconstructor.advance_turn` may attribute at turn ``k`` + (``None`` = no constraint, including turn 0). The bounds make every future + block-aligned pull-back truncate onto a user block rather than an assistant + one, so each such turn ends with a user message. + + Pure and role-independent: LCP, the per-turn covered-block region, and the + block tile depend only on ``hash_ids`` / ``in_tokens``, never on how blocks + are later labeled. So the caps can be computed once up front and applied at + the turn that *creates* each block — the boundary is fixed on first emission + and never relabeled, preserving cross-turn KV-cache reuse. + + Mirrors ``advance_turn`` + ``truncate_synth_buf_at_block``: replays the + block tile (owning turn per block) via ``compute_turn_block_geometry``, + clamping ``eff_lcp = min(lcp, len(tile))`` since truncate can't grow the + buffer. A turn that appends nothing (``new_blocks_count == 0`` and + ``synth_tail_n == 0``) truncates onto block ``T-1`` (``T = eff_lcp``); its + owning turn ``j = tile[T-1]`` must keep that block in its user segment, so + ``cap_j = min((T-1) - eff_lcp_j)``. Turn-0 owners are skipped (no assistant). + """ + n_turns = len(turns) + caps: list[int | None] = [None] * n_turns + if n_turns == 0: + return caps + + bs = block_size + tile: list[int] = [] + eff_lcp_per_turn: list[int] = [0] * n_turns + + for k in range(n_turns): + hash_ids, in_tokens = turns[k] + prev_hash_ids = turns[k - 1][0] if k else [] + # Same geometry function advance_turn uses, so the two never drift. + geo = compute_turn_block_geometry(prev_hash_ids, hash_ids, in_tokens, bs) + if k == 0: + eff_lcp_per_turn[0] = 0 + tile = [0] * geo.m_curr_covered + continue + + # eff_lcp clamps the raw LCP to the prior buffer's real block count: + # truncate is a no-op past the buffer end and cannot grow the tile. + eff_lcp = min(geo.lcp, len(tile)) + eff_lcp_per_turn[k] = eff_lcp + new_blocks_count = max(0, geo.m_curr_covered - eff_lcp) + synth_tail_n = geo.synth_tail_n + + if new_blocks_count == 0 and synth_tail_n == 0: + # Degenerate pull-back: truncate to T=eff_lcp exposes block T-1. + # target == eff_lcp <= len(tile) by construction, so the only guard + # needed is target >= 1 (a turn-0-boundary pull-back has nothing to + # cap). + target = eff_lcp + if target >= 1: + owner = tile[target - 1] + if owner != 0: + bound = (target - 1) - eff_lcp_per_turn[owner] + if bound >= 0: + prev = caps[owner] + caps[owner] = bound if prev is None else min(prev, bound) + + # Mutate the tile suffix in place (O(new)) instead of rebuilding it. + del tile[eff_lcp:] + tile.extend([k] * new_blocks_count) + + return caps + + def truncate_synth_buf_at_block( segments: list[RoleSegment], target_blocks: int, diff --git a/src/aiperf/dataset/loader/weka_trace.py b/src/aiperf/dataset/loader/weka_trace.py index aaa5d4dbf..76546c0df 100644 --- a/src/aiperf/dataset/loader/weka_trace.py +++ b/src/aiperf/dataset/loader/weka_trace.py @@ -1709,6 +1709,7 @@ def _reconstruct_serial( ) from aiperf.dataset.loader.weka_synth_buf import ( ConversationReconstructor, + compute_asst_block_caps, ) flat_plans_by_trace: dict[str, list[_FlatChainPlan]] = defaultdict(list) @@ -1756,6 +1757,10 @@ def _reconstruct_serial( # First pass: emit turns from normal requests; track outer-index → turn-pos. outer_to_turn_pos: dict[int, int] = {} trace_metric_values = metric_values_by_trace[plan.trace_id] + asst_block_caps = compute_asst_block_caps( + [(r.hash_ids, r.input_length) for _, r in plan.normals], + plan.block_size, + ) for k, (outer_idx, req) in enumerate(plan.normals): seed = f"{plan.trace_id}:turn_{k}:partial_tail" input_kind = _classify_turn_input( @@ -1786,6 +1791,7 @@ def _reconstruct_serial( curr_in_tokens=req.input_length, seed=seed, is_tool_result=is_tool_result, + max_asst_blocks=asst_block_caps[k], ) # Turn.timestamp/delay are in milliseconds; weka traces record seconds. @@ -2094,6 +2100,10 @@ def _reconstruct_serial( ), ) child_metric_values = metric_values_by_trace[cp.parent_trace_id] + child_asst_block_caps = compute_asst_block_caps( + [(r.hash_ids, r.input_length) for r in cp.requests], + cp.block_size, + ) for k, creq in enumerate(cp.requests): seed = f"{cp.session_id}:turn_{k}:partial_tail" input_kind = _classify_turn_input( @@ -2119,6 +2129,7 @@ def _reconstruct_serial( curr_in_tokens=creq.input_length, seed=seed, is_tool_result=is_tool_result, + max_asst_blocks=child_asst_block_caps[k], ) trace_idle_timing = trace_idle_timing_by_trace.get(cp.parent_trace_id) if trace_idle_timing is not None: @@ -2197,7 +2208,12 @@ def _emit_flat_chain_conversation( replay_scope_id=fp.parent_trace_id, ) from aiperf.common.models import Turn + from aiperf.dataset.loader.weka_synth_buf import compute_asst_block_caps + asst_block_caps = compute_asst_block_caps( + [(r.hash_ids, r.input_length) for _, r in fp.requests], + fp.block_size, + ) for k, (_outer_idx, req) in enumerate(fp.requests): seed = f"{fp.session_id}:turn_{k}:partial_tail" input_kind = _classify_turn_input(req, fp.requests[k - 1][1] if k else None) @@ -2221,6 +2237,7 @@ def _emit_flat_chain_conversation( curr_in_tokens=req.input_length, seed=seed, is_tool_result=is_tool_result, + max_asst_blocks=asst_block_caps[k], ) t_ms, delay_ms = self._flat_turn_timing( fp=fp, diff --git a/tests/unit/dataset/loader/test_weka_synth_buf.py b/tests/unit/dataset/loader/test_weka_synth_buf.py index f590ffa1f..56fcb6d6f 100644 --- a/tests/unit/dataset/loader/test_weka_synth_buf.py +++ b/tests/unit/dataset/loader/test_weka_synth_buf.py @@ -22,6 +22,7 @@ from aiperf.dataset.loader.weka_synth_buf import ( ConversationReconstructor, RoleSegment, + compute_asst_block_caps, longest_common_prefix, truncate_synth_buf_at_block, ) @@ -513,8 +514,9 @@ def test_advance_pattern_c_pull_back(): # truncate at LCP=3: mid-segment cut on turn-0 user (kept_blocks=3) -> # block_count=3, len(tokens)=192. Trailing partial_tail/asst-overflow gone. # new_region = 2*64 + (320 mod 64) = 128 + 0 = 128 tokens. - # out=80 -> asst_blocks = ceil(80/64) = 2 -> asst_tokens = 128. - # user_blocks = 2 - 2 = 0 -> no user_k. + # out=80 -> asst_blocks_target = ceil(80/64) = 2 == new_blocks_count (2), + # synth_tail = 0 -> the final new block is handed to the user so the turn + # ends with a user segment: asst = 1 block (64), user_k = 1 block (64). r.advance_turn( prev_hash_ids=list(range(1, 11)), prev_in_tokens=620, @@ -524,17 +526,21 @@ def test_advance_pattern_c_pull_back(): seed="s1", ) roles = [s.role for s in r._segments] - assert roles == ["user", "assistant"] + assert roles == ["user", "assistant", "user"] assert r._segments[0].block_count == 3 assert r._segments[0].content_token_count == 192 - assert r._segments[1].content_token_count == 128 - assert r._segments[1].block_count == 2 - # Sum = 192 + 128 = 320 == curr_in_tokens. + assert r._segments[1].content_token_count == 64 + assert r._segments[1].block_count == 1 + assert r._segments[2].content_token_count == 64 + assert r._segments[2].block_count == 1 + # Sum = 192 + 64 + 64 = 320 == curr_in_tokens. assert sum(len(s.tokens) for s in r._segments) == 320 def test_advance_asst_overflow_pattern_a_template_drift(): - """new_region < ceil(out[k-1]/bs)*bs: asst clamped to fit, user empty.""" + """new_region < ceil(out[k-1]/bs)*bs: asst clamps to the region, but the + final block is reserved for the user so the turn ends with a user + segment.""" r = _make_recon() r.init_turn_0( hash_ids=[1, 2], in_tokens=128, tool_tokens=0, system_tokens=0, seed="s0" @@ -548,15 +554,19 @@ def test_advance_asst_overflow_pattern_a_template_drift(): seed="s1", ) # new_region = 2*64 = 128 tokens. asst_blocks_target = ceil(200/64) = 4, - # clamped to new_blocks_count = 2. asst_tokens = 128. user empty. + # clamped to new_blocks_count = 2. The region is tail-free, so the last + # block is handed to the user: asst = 1 block (64), user_k = 1 block (64). roles = [s.role for s in r._segments] - assert roles == ["user", "assistant"] - assert r._segments[1].content_token_count == 128 - assert r._segments[1].block_count == 2 + assert roles == ["user", "assistant", "user"] + assert r._segments[1].content_token_count == 64 + assert r._segments[1].block_count == 1 + assert r._segments[2].content_token_count == 64 + assert r._segments[2].block_count == 1 def test_advance_asst_overflow_pattern_c_deep_compaction(): - """Pattern C with new_region < ceil(out[k-1]/bs)*bs: asst clamped, no user_k.""" + """Pattern C, single-block tail-free region: the lone new block must seed + the trailing user segment, so the assistant segment vanishes entirely.""" r = _make_recon() r.init_turn_0( hash_ids=list(range(1, 11)), @@ -573,11 +583,12 @@ def test_advance_asst_overflow_pattern_c_deep_compaction(): curr_in_tokens=128, seed="s1", ) - # LCP=1, kept=1 block (64 tokens). new_region = 1*64 = 64 tokens. - # asst_blocks_target = ceil(200/64) = 4, clamped to 1. asst_tokens=64. - # user empty. + # LCP=1, kept=1 block (64 tokens). new_region = 1*64 = 64 tokens (tail-free). + # asst_blocks_target = ceil(200/64) = 4, clamped to new_blocks_count = 1, + # then decremented to 0 to reserve the block for the user. No assistant + # segment; the new block becomes the trailing user segment. roles = [s.role for s in r._segments] - assert roles == ["user", "assistant"] + assert roles == ["user", "user"] assert r._segments[1].content_token_count == 64 assert r._segments[1].block_count == 1 @@ -600,8 +611,10 @@ def test_advance_zero_out_skips_assistant_segment(): assert roles == ["user", "user"] -def test_advance_zero_user_skips_user_segment(): - """When asst exactly fills new_region, no user_k segment emitted.""" +def test_advance_asst_exactly_fills_region_yields_trailing_user(): + """When the assistant target exactly equals a tail-free new region, the + final block is still reserved for the user so the turn ends with a user + segment (here a single-block region, so the assistant segment vanishes).""" r = _make_recon() r.init_turn_0( hash_ids=[1, 2], in_tokens=128, tool_tokens=0, system_tokens=0, seed="s0" @@ -614,10 +627,11 @@ def test_advance_zero_user_skips_user_segment(): curr_in_tokens=192, seed="s1", ) - # new_region = 1 block + 0 partial_tail = 64 tokens. - # asst_blocks = ceil(64/64) = 1 -> asst_tokens = 64. user empty. + # new_region = 1 block + 0 partial_tail = 64 tokens. asst_blocks_target = + # ceil(64/64) = 1 == new_blocks_count, decremented to 0 to reserve the + # block for the user. No assistant segment; the block is the user segment. roles = [s.role for s in r._segments] - assert roles == ["user", "assistant"] + assert roles == ["user", "user"] def test_advance_boundary_cut_strips_missing_block_overhang(): @@ -749,8 +763,9 @@ def test_byte_exact_sum_matches_recorded_advance_turn(): curr_in_tokens=320, seed="s1", ) - # lcp=3, kept=3 blocks=192. new_region=2*64+0=128. asst=ceil(80/64)*64=128. user=0. - # sum = 192 + 128 + 0 = 320. + # lcp=3, kept=3 blocks=192. new_region=2*64+0=128. asst_target=ceil(80/64)=2 + # == new_blocks_count, tail-free -> final block reserved for the user: + # asst=64, user=64. sum = 192 + 64 + 64 = 320. assert sum(len(s.tokens) for s in r2._segments) == 320 # Pattern A with non-zero partial tail in the new turn. @@ -1326,3 +1341,308 @@ def test_advance_turn_partial_last_hashed_block_clamps_to_budget(): # not 250 + bs. assert sum(len(s.tokens) for s in r._segments) == 250 assert all(s.block_count >= 0 for s in r._segments) + + +@pytest.mark.parametrize( + ("prev_out_tokens", "curr_hash_ids", "curr_in_tokens"), + [ + # asst target exactly equals a multi-block tail-free region. + (128, [1, 2, 3, 4], 256), + # asst target overflows a multi-block tail-free region (clamped). + (500, [1, 2, 3, 4], 256), + # single-block tail-free region: assistant segment must vanish. + (200, [1, 2, 3], 192), + # tail-free region equal to a 3-block growth. + (300, [1, 2, 3, 4, 5], 320), + ], +) +def test_advance_always_ends_with_user_segment( + prev_out_tokens, curr_hash_ids, curr_in_tokens +): + """Wire invariant: every turn that adds new content ends with a user + segment, even when the assistant target would consume the whole tail-free + new region. The final new block is relabeled to the user, never left as a + trailing assistant.""" + r = _make_recon() + r.init_turn_0( + hash_ids=[1, 2], in_tokens=128, tool_tokens=0, system_tokens=0, seed="s0" + ) + r.advance_turn( + prev_hash_ids=[1, 2], + prev_in_tokens=128, + prev_out_tokens=prev_out_tokens, + curr_hash_ids=curr_hash_ids, + curr_in_tokens=curr_in_tokens, + seed="s1", + ) + assert r._segments[-1].role == "user" + assert not r._trailing_non_user_turns + # The relabel is byte-exact: total tokens still equal the recorded input. + assert sum(len(s.tokens) for s in r._segments) == curr_in_tokens + + +def test_advance_zero_new_region_records_trailing_non_user_caveat(): + """The one shape that cannot end with a user: a fully block-aligned + pull-back that appends zero new tokens after truncation exposes a trailing + assistant. Synthesizing a user block would break the byte-exact ISL + invariant, so the turn is recorded on ``_trailing_non_user_turns`` and + warned rather than faked.""" + r = _make_recon() + # turn 0: 2-block user prompt. + r.init_turn_0( + hash_ids=[1, 2], in_tokens=128, tool_tokens=0, system_tokens=0, seed="s0" + ) + # turn 1: grow by 2 tail-free blocks; large prev_out -> asst gets block 2, + # user gets block 3. Buffer: [user, assistant, user]. + r.advance_turn( + prev_hash_ids=[1, 2], + prev_in_tokens=128, + prev_out_tokens=200, + curr_hash_ids=[1, 2, 3, 4], + curr_in_tokens=256, + seed="s1", + ) + assert [s.role for s in r._segments] == ["user", "assistant", "user"] + assert not r._trailing_non_user_turns + # turn 2: pull back to exactly 3 blocks (block-aligned, no new region). The + # truncation boundary lands at the assistant segment and deletes the + # trailing user; nothing new is appended, so the buffer ends with the + # assistant -- the unavoidable caveat. + r.advance_turn( + prev_hash_ids=[1, 2, 3, 4], + prev_in_tokens=256, + prev_out_tokens=64, + curr_hash_ids=[1, 2, 3], + curr_in_tokens=192, + seed="s2", + ) + assert [s.role for s in r._segments] == ["user", "assistant"] + assert r._trailing_non_user_turns == [2] + # Byte-exact contract still holds for the (degenerate) shape. + assert sum(len(s.tokens) for s in r._segments) == 192 + + +def test_init_turn_0_system_only_prompt_records_caveat(): + """A turn-0 prompt entirely consumed by the cached tool/system prefix has + no user content to make a trailing user block; the system-only shape is + recorded on ``_trailing_non_user_turns`` rather than faked.""" + r = _make_recon() + # in=128 == 2 blocks, all tool/system; no user remainder. + r.init_turn_0( + hash_ids=[1, 2], + in_tokens=128, + tool_tokens=128, + system_tokens=0, + seed="s0", + ) + assert [s.role for s in r._segments] == ["system"] + assert r._trailing_non_user_turns == [0] + + +# --------------------------------------------------------------------------- +# compute_asst_block_caps (Pass-1 planner) + advance_turn(max_asst_blocks=...) +# --------------------------------------------------------------------------- + + +def test_compute_caps_canonical_degenerate(): + """The canonical pull-back: turn 2 truncates onto the assistant block that + turn 1 created, so turn 1's assistant must be capped to 0.""" + caps = compute_asst_block_caps( + [([1, 2], 128), ([1, 2, 3, 4], 256), ([1, 2, 3], 192)], + 64, + ) + assert caps == [None, 0, None] + + +def test_compute_caps_clean_append_no_constraints(): + """A pure-growth conversation has no degenerate pull-backs -> no caps.""" + caps = compute_asst_block_caps( + [([1, 2], 128), ([1, 2, 3, 4, 5], 320)], + 64, + ) + assert caps == [None, None] + + +def test_compute_caps_target_owned_by_turn_0_no_cap(): + """A pull-back landing on a block created by turn 0 needs no cap (turn 0 + has no assistant segment to shrink).""" + caps = compute_asst_block_caps( + [([1, 2], 128), ([1, 2, 3, 4], 256), ([1, 2], 128)], + 64, + ) + assert caps == [None, None, None] + + +def test_compute_caps_two_targets_same_owner_takes_min(): + """Two later degenerate pull-backs landing inside the same turn's assistant + region collapse to the tightest (min) cap.""" + # turn 1 grows by 4 blocks (blocks 2,3,4,5) with a large prev_out. + # turn 2 pulls back to 5 blocks (block 4 boundary), turn 3 to 4 blocks + # (block 3 boundary) -- both inside turn 1's assistant region. + caps = compute_asst_block_caps( + [ + ([1, 2], 128), + ([1, 2, 3, 4, 5, 6], 384), + ([1, 2, 3, 4, 5], 320), + ([1, 2, 3, 4], 256), + ], + 64, + ) + # owner of both targets is turn 1 (lcp_1 = 2). target T=5 -> cap (5-1)-2=2; + # target T=4 -> cap (4-1)-2=1; min = 1. + assert caps[1] == 1 + assert caps[0] is None + + +def test_compute_caps_overcovered_prefix_clamps_no_indexerror(): + """lcp can exceed the current turn's covered-block count when the recorder + stored more hash blocks than curr_in_tokens covers; the effective-lcp clamp + must keep tile indexing in range.""" + # turn 1 covers only 2 blocks (in=128) but shares a 4-long hash prefix. + caps = compute_asst_block_caps( + [([1, 2, 3, 4], 256), ([1, 2, 3, 4], 128), ([1, 2], 128)], + 64, + ) + # No crash; turn 2 pulls back to blocks owned by turn 0 -> no cap. + assert len(caps) == 3 + assert caps[2] is None + + +def test_compute_caps_partial_last_hashed_block_uses_covered_budget(): + """end_k must use min(len(hash_ids), in_tokens // bs): a partial last hashed + block contributes to the tail, not the covered tile.""" + # turn 1: in=250 -> m_full=3, hash has 4 ids (4th is partial) -> end=3. + caps = compute_asst_block_caps( + [([1, 2, 3], 192), ([1, 2, 3, 4], 250)], + 64, + ) + assert caps == [None, None] + + +def _run_canonical_three_turns(caps): + """Drive the canonical degenerate 3-turn sequence, applying per-turn caps.""" + r = _make_recon() + r.init_turn_0( + hash_ids=[1, 2], in_tokens=128, tool_tokens=0, system_tokens=0, seed="s0" + ) + r.advance_turn( + prev_hash_ids=[1, 2], + prev_in_tokens=128, + prev_out_tokens=200, + curr_hash_ids=[1, 2, 3, 4], + curr_in_tokens=256, + seed="s1", + max_asst_blocks=caps[1], + ) + r.advance_turn( + prev_hash_ids=[1, 2, 3, 4], + prev_in_tokens=256, + prev_out_tokens=64, + curr_hash_ids=[1, 2, 3], + curr_in_tokens=192, + seed="s2", + max_asst_blocks=caps[2], + ) + return r + + +def test_advance_with_cap_eliminates_trailing_assistant(): + """Applying the planner cap to turn 1 makes the turn-2 pull-back land on a + user block: no trailing assistant, no flagged caveat, byte-exact preserved.""" + caps = compute_asst_block_caps( + [([1, 2], 128), ([1, 2, 3, 4], 256), ([1, 2, 3], 192)], 64 + ) + r = _run_canonical_three_turns(caps) + assert r._segments[-1].role == "user" + assert r._trailing_non_user_turns == [] + assert sum(len(s.tokens) for s in r._segments) == 192 + + +def test_advance_without_cap_reproduces_trailing_assistant(): + """Regression guard: max_asst_blocks=None reproduces the pre-fix degenerate + trailing-assistant shape and flags it.""" + r = _run_canonical_three_turns([None, None, None]) + assert [s.role for s in r._segments] == ["user", "assistant"] + assert r._trailing_non_user_turns == [2] + + +def test_advance_cap_larger_than_region_is_noop(): + """A cap >= new_blocks_count does not shrink the assistant below what the + target/region already allow.""" + r = _make_recon() + r.init_turn_0( + hash_ids=[1, 2], in_tokens=128, tool_tokens=0, system_tokens=0, seed="s0" + ) + # new region = 3 blocks, asst target ceil(100/64)=2; cap=99 (no effect). + r.advance_turn( + prev_hash_ids=[1, 2], + prev_in_tokens=128, + prev_out_tokens=100, + curr_hash_ids=[1, 2, 3, 4, 5], + curr_in_tokens=320, + seed="s1", + max_asst_blocks=99, + ) + assert [s.role for s in r._segments] == ["user", "assistant", "user"] + assert r._segments[1].block_count == 2 # unchanged by the loose cap + + +def _make_tool_shaped_recon(bs=64): + return ConversationReconstructor( + block_size=bs, + decode_block_tokens=_stub_decode_block_tokens, + sample_partial_tail_tokens=_stub_partial_tail_tokens, + decode_tokens_to_text=_stub_decode_tokens_to_text, + emit_assistant_segments=True, + tool_shaped_messages=True, + ) + + +def test_cap_demotes_unpaired_tool_result_to_plain_user(): + """When a planner cap removes the assistant a tool-result turn would have + paired with, the tool-result user must ship as a PLAIN user message (not a + dangling role:tool without a tool_calls partner), and stay plain across a + reset re-emission. Without the cap the same turn shapes as role:tool.""" + # Uncapped: 2-block tool-result region keeps an assistant -> shapes to tool. + r_uncapped = _make_tool_shaped_recon() + r_uncapped.init_turn_0( + hash_ids=[1, 2], in_tokens=128, tool_tokens=0, system_tokens=0, seed="s0" + ) + r_uncapped.turn_delta() + r_uncapped.advance_turn( + prev_hash_ids=[1, 2], + prev_in_tokens=128, + prev_out_tokens=200, + curr_hash_ids=[1, 2, 3, 4], + curr_in_tokens=256, + seed="s1", + is_tool_result=True, + ) + d_uncapped = r_uncapped.turn_delta() + assert d_uncapped.delta_messages[-1]["role"] == "tool" + + # Capped to 0: no assistant precedes the tool-result user -> demote to plain. + r_capped = _make_tool_shaped_recon() + r_capped.init_turn_0( + hash_ids=[1, 2], in_tokens=128, tool_tokens=0, system_tokens=0, seed="s0" + ) + r_capped.turn_delta() + r_capped.advance_turn( + prev_hash_ids=[1, 2], + prev_in_tokens=128, + prev_out_tokens=200, + curr_hash_ids=[1, 2, 3, 4], + curr_in_tokens=256, + seed="s1", + is_tool_result=True, + max_asst_blocks=0, + ) + d_capped = r_capped.turn_delta() + assert [m["role"] for m in d_capped.delta_messages] == ["user"] + assert all("tool_calls" not in m for m in d_capped.delta_messages) + # Force a reset re-emission and confirm the shape stays plain user. + r_capped._emitted_segment_count = 0 + r_capped._last_disturbance_at = None + d_reset = r_capped.turn_delta() + assert d_reset.delta_messages[-1]["role"] == "user" + assert all(m["role"] != "tool" for m in d_reset.delta_messages) diff --git a/tests/unit/dataset/loader/test_weka_synth_buf_turn_delta.py b/tests/unit/dataset/loader/test_weka_synth_buf_turn_delta.py index bf0155dc3..616cfde38 100644 --- a/tests/unit/dataset/loader/test_weka_synth_buf_turn_delta.py +++ b/tests/unit/dataset/loader/test_weka_synth_buf_turn_delta.py @@ -813,54 +813,60 @@ def test_pure_growth_after_tail_only_segment_keeps_block_alignment(): pure-growth cut lands on the assistant segment boundary with the tail-only segment past it.""" r = _make_recon() - # Turn 0: [user 2b]. + # Turn 0: [user 3b]. r.init_turn_0( - hash_ids=[1, 2], - in_tokens=2 * BLOCK_SIZE, + hash_ids=[1, 2, 3], + in_tokens=3 * BLOCK_SIZE, tool_tokens=0, system_tokens=0, seed="s0", ) r.turn_delta() - # Turn 1: new region exactly covers prev_out -> appends assistant only. + # Turn 1: grow by 2 tail-free blocks with a large prev_out. The assistant + # target would take both, but the final block is reserved for the user so + # the turn ends with a user segment: [user 3b, assistant(hash4), user(hash5)]. r.advance_turn( - prev_hash_ids=[1, 2], - prev_in_tokens=2 * BLOCK_SIZE, - prev_out_tokens=BLOCK_SIZE, - curr_hash_ids=[1, 2, 3], - curr_in_tokens=3 * BLOCK_SIZE, + prev_hash_ids=[1, 2, 3], + prev_in_tokens=3 * BLOCK_SIZE, + prev_out_tokens=200, + curr_hash_ids=[1, 2, 3, 4, 5], + curr_in_tokens=5 * BLOCK_SIZE, seed="s1", ) r.turn_delta() + assert [s.role for s in r._segments] == ["user", "assistant", "user"] # Turn 2: tail-only tool result (+12 tokens, no new hash block). r.advance_turn( - prev_hash_ids=[1, 2, 3], - prev_in_tokens=3 * BLOCK_SIZE, + prev_hash_ids=[1, 2, 3, 4, 5], + prev_in_tokens=5 * BLOCK_SIZE, prev_out_tokens=10, - curr_hash_ids=[1, 2, 3], - curr_in_tokens=3 * BLOCK_SIZE + 12, + curr_hash_ids=[1, 2, 3, 4, 5], + curr_in_tokens=5 * BLOCK_SIZE + 12, seed="s2", is_tool_result=True, ) r.turn_delta() - # Turn 3: pure growth ([1,2,3] -> [1,2,3,4]); LCP cut lands exactly on - # the assistant segment's boundary, with the tail-only segment past it. + # Turn 3: pure growth ([1,2,3,4,5] -> [1,2,3,4,99]); LCP=4 cut lands exactly + # on the assistant segment's boundary, deleting the trailing user(hash5) + # and the tail-only segment past it. r.advance_turn( - prev_hash_ids=[1, 2, 3], - prev_in_tokens=3 * BLOCK_SIZE + 12, + prev_hash_ids=[1, 2, 3, 4, 5], + prev_in_tokens=5 * BLOCK_SIZE + 12, prev_out_tokens=8, - curr_hash_ids=[1, 2, 3, 4], - curr_in_tokens=4 * BLOCK_SIZE, + curr_hash_ids=[1, 2, 3, 4, 99], + curr_in_tokens=5 * BLOCK_SIZE, seed="s3", ) delta = r.turn_delta() # Replacing the already-sent tail-only segment is a context reset. assert delta.reset_context is True # Byte accounting must hold exactly. - assert sum(len(s.tokens) for s in r._segments) == 4 * BLOCK_SIZE + assert sum(len(s.tokens) for s in r._segments) == 5 * BLOCK_SIZE # The boundary assistant segment keeps its full hash-block content. assert r._segments[1].role == "assistant" - assert r._segments[1].tokens == _stub_decode_block_tokens([3]) + assert r._segments[1].tokens == _stub_decode_block_tokens([4]) + # The turn still ends with a user segment (the lone new block went to it). + assert r._segments[-1].role == "user" # Re-emitted messages mirror the (uncorrupted) segment contents 1:1. for msg, seg in zip(delta.delta_messages, r._segments, strict=True): assert msg["content"] == seg.content From e165d787faa8ba8de34feb6e5e3525eaa33a82f0 Mon Sep 17 00:00:00 2001 From: Anthony Casagrande Date: Tue, 30 Jun 2026 10:29:58 -0700 Subject: [PATCH 2/2] fix(raw-payload): drop per-record orjson re-validation; fix batch-flush shutdown loss Remove the per-send orjson.loads round-trip in InferenceClient and the per-record orjson.loads validation in RawRecordWriterProcessor. payload_bytes are validated at dataset-load time (or produced by orjson.dumps), so re-parsing every request/record only reintroduces the decode cost the verbatim / orjson.Fragment paths exist to avoid. Invalid bytes now forward/splice verbatim. Also fix a pre-existing shutdown data-loss race: RawRecordWriterProcessor's batch-trigger flush task was scheduled via execute_async without being registered in _flush_tasks, so _stop_all_tasks cancelled it before _close_file awaited it, losing the whole batch on stop. Register the task like the parent BufferedJSONLWriterMixin does. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Anthony Casagrande --- .../raw_record_writer_processor.py | 34 +++---- src/aiperf/workers/inference_client.py | 11 +-- .../test_raw_payload_replay_adversarial.py | 57 ++++++------ .../test_raw_record_writer_adversarial.py | 91 ++++++++----------- 4 files changed, 82 insertions(+), 111 deletions(-) diff --git a/src/aiperf/post_processors/raw_record_writer_processor.py b/src/aiperf/post_processors/raw_record_writer_processor.py index 710890480..2ebf8d179 100644 --- a/src/aiperf/post_processors/raw_record_writer_processor.py +++ b/src/aiperf/post_processors/raw_record_writer_processor.py @@ -116,31 +116,16 @@ async def buffered_write(self, record: RawRecordInfo) -> None: ``payload=None, payload_bytes=None`` when the enriched ``RecordContext`` carries no ``payload_bytes``). - Validates ``payload_bytes`` round-trips as JSON before splicing — - ``orjson.Fragment`` would otherwise embed invalid bytes verbatim and - silently corrupt the output JSONL. Drop + count any record whose - payload won't parse or whose serialisation fails so operators see - the failure volume via ``dropped_record_count``. + ``payload_bytes`` is spliced verbatim with no per-record JSON + re-parse: the bytes are either produced by ``orjson.dumps`` upstream + (valid by construction) or loaded from a raw dataset that was already + parsed at dataset-load time, so re-validating every record here would + only reintroduce the decode cost the ``Fragment`` path exists to avoid. """ if record.payload_bytes is None: await super().buffered_write(record) return - try: - orjson.loads(record.payload_bytes) - except (orjson.JSONDecodeError, TypeError) as e: - size = ( - len(record.payload_bytes) - if isinstance(record.payload_bytes, bytes | bytearray | memoryview) - else -1 - ) - self.warning( - f"Dropping raw record: payload_bytes does not parse as JSON " - f"(size={size}): {e!r}" - ) - self.dropped_record_count += 1 - return - try: dumped = record.model_dump(exclude_none=True, mode="json") # ``payload_bytes`` carries the wire-exact JSON; substitute it @@ -156,7 +141,14 @@ async def buffered_write(self, record: RawRecordInfo) -> None: buffer_to_flush = self._buffer self._buffer = [] if buffer_to_flush: - self.execute_async(self._flush_buffer(buffer_to_flush)) + # Register the flush task in ``_flush_tasks`` (mirroring the + # parent mixin) so ``_close_file`` awaits it on shutdown. + # Without this the fire-and-forget task lives only in + # ``self.tasks``, which ``_stop_all_tasks`` cancels before the + # file is closed — losing the whole batch on stop. + task = self.execute_async(self._flush_buffer(buffer_to_flush)) + self._flush_tasks.add(task) + task.add_done_callback(self._flush_tasks.discard) except Exception as e: self.error(f"Failed to write raw record: {e!r}") self.dropped_record_count += 1 diff --git a/src/aiperf/workers/inference_client.py b/src/aiperf/workers/inference_client.py index 3b10d6925..8991aa3c8 100644 --- a/src/aiperf/workers/inference_client.py +++ b/src/aiperf/workers/inference_client.py @@ -128,18 +128,9 @@ async def _send_request_to_transport( request_info.endpoint_params = self.endpoint.get_endpoint_params(request_info) if request_info.payload_bytes is not None: # PAYLOAD_BYTES fast path: bytes were validated at dataset-load time - # by the mmap loader / DatasetManager. Defensive guard against any - # invalid bytes that bypass upstream validation — round-trip - # through orjson.loads so a malformed payload turns into an error - # RequestRecord rather than reaching the wire. Body-mutating features + # by the mmap loader / DatasetManager, and body-mutating features # (cache-bust, Dynamo session_control) are refused against this # verbatim-bytes path at dataset load, so nothing is injected here. - try: - orjson.loads(request_info.payload_bytes) - except (orjson.JSONDecodeError, ValueError, TypeError) as e: - raise ValueError( - f"invalid JSON in pre-serialised payload_bytes: {e}" - ) from e formatted_payload = request_info.payload_bytes else: current_turn = request_info.turns[-1] if request_info.turns else None diff --git a/tests/component_integration/dataset/test_raw_payload_replay_adversarial.py b/tests/component_integration/dataset/test_raw_payload_replay_adversarial.py index 007b90300..443c49072 100644 --- a/tests/component_integration/dataset/test_raw_payload_replay_adversarial.py +++ b/tests/component_integration/dataset/test_raw_payload_replay_adversarial.py @@ -53,7 +53,11 @@ ModelInfo, ModelListInfo, ) -from aiperf.common.models.record_models import RawRecordInfo, RequestInfo +from aiperf.common.models.record_models import ( + RawRecordInfo, + RequestInfo, + RequestRecord, +) from aiperf.dataset.dataset_manager import DatasetManager from aiperf.dataset.loader.raw_payload import RawPayloadDatasetLoader from aiperf.plugin.enums import CustomDatasetType, EndpointType, TransportType @@ -372,14 +376,12 @@ def test_mixed_state_conversation_raises_during_generate_input_payloads_end_to_e @pytest.mark.asyncio -async def test_inference_client_rejects_invalid_json_payload_bytes_end_to_end() -> None: - """Post-W2-E: InferenceClient validates pre-serialised payload_bytes by - round-tripping through ``orjson.loads`` before the transport call. - Invalid JSON must never hit the wire — the broad catch in - ``_send_request_internal`` turns the ValueError into an error - RequestRecord whose message mentions 'invalid JSON'.""" +async def test_inference_client_forwards_invalid_json_payload_bytes_verbatim() -> None: + """Per-request ``orjson.loads`` validation of pre-serialised + ``payload_bytes`` was removed — invalid-JSON detection happens at + dataset-load time, not on every send. Unparsable bytes are forwarded to + the transport verbatim rather than being turned into an error record.""" client = _make_inference_client() - client.transport.send_request = AsyncMock() turn = Turn(texts=[Text(contents=["x"])], role="user", model="test-model") info = RequestInfo( @@ -393,27 +395,30 @@ async def test_inference_client_rejects_invalid_json_payload_bytes_end_to_end() conversation_id="conv", payload_bytes=b"}", ) + client.transport.send_request = AsyncMock( + return_value=RequestRecord(request_info=info) + ) - record = await client.send_request(info) + await client.send_request(info) - client.transport.send_request.assert_not_called() - assert record.error is not None - assert "invalid JSON" in record.error.message + client.transport.send_request.assert_called_once() + assert client.transport.send_request.call_args.kwargs["payload"] == b"}" # --------------------------------------------------------------------------- -# 7: RawRecordWriter drops invalid fragment with counter +# 7: RawRecordWriter splices payload_bytes verbatim (no per-record re-parse) # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_raw_record_writer_drops_invalid_fragment_with_counter_end_to_end( +async def test_raw_record_writer_splices_invalid_fragment_verbatim_end_to_end( tmp_path: Path, ) -> None: - """Post-W2-D: RawRecordWriterProcessor validates ``payload_bytes`` via - ``orjson.loads`` before the Fragment splice. Invalid JSON bytes are - dropped, ``dropped_record_count`` increments, and the output file - contains no corrupt lines.""" + """The per-record ``orjson.loads`` re-validation was removed: + ``RawRecordWriterProcessor`` splices ``payload_bytes`` verbatim via + ``orjson.Fragment``. Invalid JSON bytes are written (no drop, no counter) + and the resulting line is corrupt — the accepted tradeoff of trusting + dataset-load-time validation instead of re-parsing every exported record.""" user_config = _user_config_raw(tmp_path) processor = RawRecordWriterProcessor(service_id="rrw-ci", user_config=user_config) await processor.initialize() @@ -432,18 +437,16 @@ async def test_raw_record_writer_drops_invalid_fragment_with_counter_end_to_end( ) await processor.buffered_write(bad) - assert processor.dropped_record_count == 1 - assert processor.lines_written == 0 + assert processor.dropped_record_count == 0 + assert processor.lines_written == 1 finally: await processor.stop() - # Any lines that *did* make it to disk must parse cleanly; no corrupt splice. - if processor.output_file.exists(): - raw = processor.output_file.read_bytes() - assert b'"payload":}' not in raw - for line in raw.splitlines(): - if line.strip(): - orjson.loads(line) + # The verbatim splice produced a line that no longer parses as JSON. + raw = processor.output_file.read_bytes() + line = next(line for line in raw.splitlines() if line.strip()) + with pytest.raises(orjson.JSONDecodeError): + orjson.loads(line) def _metric_metadata(): diff --git a/tests/unit/post_processors/test_raw_record_writer_adversarial.py b/tests/unit/post_processors/test_raw_record_writer_adversarial.py index 526c9cee3..063cdfc77 100644 --- a/tests/unit/post_processors/test_raw_record_writer_adversarial.py +++ b/tests/unit/post_processors/test_raw_record_writer_adversarial.py @@ -164,13 +164,15 @@ async def spy(self, record): assert parsed["payload"] == {"k": "v"} @pytest.mark.asyncio - async def test_buffered_write_empty_bytes_payload_bytes_dropped_with_counter( + async def test_buffered_write_empty_bytes_payload_bytes_spliced_verbatim( self, user_config_raw: UserConfig, ): - """``payload_bytes=b""`` is not valid JSON — post-Wave-2 fix drops it - at the ingest check rather than splicing an empty Fragment and - emitting a ``"payload":`` with no value. Counter bumps. + """``payload_bytes=b""`` — the per-record JSON re-validation was removed, + so empty bytes splice verbatim (no drop, no counter). orjson emits + ``"payload":`` with no value, i.e. a deliberately corrupt line: the + accepted tradeoff of trusting dataset-load-time validation instead of + re-parsing every record on the export hot path. """ record = _make_raw_record(payload_bytes=b"") @@ -178,28 +180,23 @@ async def test_buffered_write_empty_bytes_payload_bytes_dropped_with_counter( "processor-empty", user_config_raw ) as processor: await processor.buffered_write(record) - assert processor.dropped_record_count == 1 - assert processor.lines_written == 0 + assert processor.dropped_record_count == 0 + assert processor.lines_written == 1 - raw = ( - processor.output_file.read_bytes() - if processor.output_file.exists() - else b"" - ) - assert b'"payload":,' not in raw and b'"payload":}' not in raw - for line in raw.splitlines(): - if line.strip(): - orjson.loads(line) + line = processor.output_file.read_bytes().splitlines()[0] + # Spliced verbatim -> the line is no longer valid JSON. + with pytest.raises(orjson.JSONDecodeError): + orjson.loads(line) @pytest.mark.asyncio - async def test_buffered_write_invalid_json_payload_bytes_dropped_with_counter( + async def test_buffered_write_invalid_json_payload_bytes_spliced_verbatim( self, user_config_raw: UserConfig, ): - """``payload_bytes=b"}"`` — post-Wave-2 fix: invalid JSON bytes are - rejected at ingest via an ``orjson.loads`` round-trip check so the - Fragment splice never emits corrupt bytes. The record is dropped - and ``dropped_record_count`` increments. + """``payload_bytes=b"}"`` — with the per-record ``orjson.loads`` ingest + check removed, invalid bytes splice verbatim via ``orjson.Fragment`` + (no drop, no counter). The emitted line is corrupt; this documents the + accepted tradeoff of trusting upstream/dataset-load validation. """ record = _make_raw_record(payload_bytes=b"}") @@ -207,29 +204,22 @@ async def test_buffered_write_invalid_json_payload_bytes_dropped_with_counter( "processor-bad-json", user_config_raw ) as processor: await processor.buffered_write(record) - assert processor.dropped_record_count == 1 - assert processor.lines_written == 0 + assert processor.dropped_record_count == 0 + assert processor.lines_written == 1 - # Output must not contain the corrupt splice artefact. - raw = ( - processor.output_file.read_bytes() - if processor.output_file.exists() - else b"" - ) - assert b'"payload":}' not in raw - # Every surviving line (if any) must parse cleanly. - for line in raw.splitlines(): - if line.strip(): - orjson.loads(line) + line = processor.output_file.read_bytes().splitlines()[0] + with pytest.raises(orjson.JSONDecodeError): + orjson.loads(line) @pytest.mark.asyncio - async def test_buffered_write_truncated_json_payload_bytes_dropped_with_counter( + async def test_buffered_write_truncated_json_payload_bytes_spliced_verbatim( self, user_config_raw: UserConfig, ): - """Truncated JSON ``b'{"a":1'`` — post-Wave-2 fix: the ingest-time - ``orjson.loads`` check rejects the partial bytes before the Fragment - splice, so no corrupt line is emitted and the drop counter bumps. + """Truncated JSON ``b'{"a":1'`` — with the ingest ``orjson.loads`` check + removed, the partial bytes splice verbatim (no drop, no counter) and the + emitted line is corrupt: the accepted tradeoff of skipping a per-record + re-parse on the export hot path. """ record = _make_raw_record(payload_bytes=b'{"a":1') @@ -237,19 +227,12 @@ async def test_buffered_write_truncated_json_payload_bytes_dropped_with_counter( "processor-trunc", user_config_raw ) as processor: await processor.buffered_write(record) - assert processor.dropped_record_count == 1 - assert processor.lines_written == 0 + assert processor.dropped_record_count == 0 + assert processor.lines_written == 1 - raw = ( - processor.output_file.read_bytes() - if processor.output_file.exists() - else b"" - ) - # No truncated splice artefact. - assert b'"payload":{"a":1' not in raw - for line in raw.splitlines(): - if line.strip(): - orjson.loads(line) + line = processor.output_file.read_bytes().splitlines()[0] + with pytest.raises(orjson.JSONDecodeError): + orjson.loads(line) @pytest.mark.asyncio async def test_buffered_write_payload_bytes_with_trailing_whitespace_still_valid_fragment( @@ -322,9 +305,11 @@ async def test_buffered_write_non_json_non_bytes_payload_bytes_dropped_with_coun self, user_config_raw: UserConfig, ): - """``payload_bytes=123`` (int) — post-Wave-2 fix: ``orjson.loads(123)`` - raises ``TypeError`` at the ingest validation, which is caught and - the record is dropped with the counter bumped. + """``payload_bytes=123`` (int) — even without the per-record JSON + re-validation, ``orjson.Fragment(123)`` raises ``TypeError`` at + construction, which the serialisation ``except`` catches so the record + is dropped with the counter bumped (genuine serialisation failures are + still surfaced). We construct the ``RawRecordInfo`` via ``model_construct`` because pydantic validation would reject ``payload_bytes=123``. @@ -511,7 +496,7 @@ async def test_buffered_write_invalid_json_payload_bytes_raises_or_increments_co attribute so operators can see drops. """ # Use the same shape as test_non_json_non_bytes which hits the - # TypeError path (orjson.loads rejects int). + # TypeError path (orjson.Fragment rejects a non-bytes/str int). record = RawRecordInfo.model_construct( metadata=create_metric_metadata(), start_perf_ns=1_000_000_000,