Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/aiperf/dataset/loader/weka_parallel_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def _process_task(task: _WekaTraceTask) -> _WekaProcessTaskResult:
assert _worker_state is not None
from aiperf.dataset.loader.weka_synth_buf import (
ConversationReconstructor,
compute_asst_block_caps,
)

state = _worker_state
Expand All @@ -382,6 +383,10 @@ def _process_task(task: _WekaTraceTask) -> _WekaProcessTaskResult:
parent_turns: list[_WekaParentTurnDict] = []
outer_to_turn_pos: dict[int, int] = {}
normals: list[tuple[int, _WekaNormalRequestPayload]] = parent["normals"]
parent_asst_block_caps = compute_asst_block_caps(
[(r["hash_ids"], r["input_length"]) for _, r in normals],
bs,
)
for k, (outer_idx, req) in enumerate(normals):
seed = f"{task.trace_id}:turn_{k}:partial_tail"
is_tool_result = req.get("input_kind") == "tool_result"
Expand All @@ -404,6 +409,7 @@ def _process_task(task: _WekaTraceTask) -> _WekaProcessTaskResult:
curr_in_tokens=req["input_length"],
seed=seed,
is_tool_result=is_tool_result,
max_asst_blocks=parent_asst_block_caps[k],
)

if "effective_t" in req:
Expand Down Expand Up @@ -617,6 +623,10 @@ def _process_task(task: _WekaTraceTask) -> _WekaProcessTaskResult:

child_turns: list[_WekaParentTurnDict] = []
creqs: list[_WekaNormalRequestPayload] = cp["requests"]
child_asst_block_caps = compute_asst_block_caps(
[(r["hash_ids"], r["input_length"]) for r in creqs],
bs,
)
for k, creq in enumerate(creqs):
seed = f"{cp['session_id']}:turn_{k}:partial_tail"
is_tool_result = creq.get("input_kind") == "tool_result"
Expand All @@ -639,6 +649,7 @@ def _process_task(task: _WekaTraceTask) -> _WekaProcessTaskResult:
curr_in_tokens=creq["input_length"],
seed=seed,
is_tool_result=is_tool_result,
max_asst_blocks=child_asst_block_caps[k],
)
if "effective_t" in creq:
t_ms = creq["effective_t"] * 1000.0
Expand Down
237 changes: 222 additions & 15 deletions src/aiperf/dataset/loader/weka_synth_buf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from dataclasses import dataclass, field
from typing import Literal

from aiperf.common.aiperf_logger import AIPerfLogger

# Re-exported for backwards compatibility; the composer lives in its own
# module since it is independent of the synthesis-buffer state machine.
from aiperf.dataset.loader.weka_prompt_compose import ( # noqa: E402
Expand All @@ -24,6 +26,8 @@
tool_shape_segment_messages,
)

_logger = AIPerfLogger(__name__)


@dataclass
class TurnDelta:
Expand Down Expand Up @@ -116,6 +120,17 @@ class ConversationReconstructor:
_emitted_segment_count: int = 0
_last_disturbance_at: int | None = None
_turn_index: int = 0
_trailing_non_user_turns: list[int] = field(default_factory=list)
"""Turn indices whose segment list could NOT be made to end with a user
block. The wire invariant (every turn ends with a user message) is
enforced for every turn that contributes any new content, but a turn that
appends zero new tokens (a fully block-aligned pull-back/compaction:
``new_blocks_count == 0`` with no partial/missing tail) has nothing to
relabel as user, and a turn-0 prompt entirely consumed by the tool/system
prefix has only the cached system segment. Those degenerate shapes leave a
trailing non-user segment; they are recorded here (and warned once) rather
than faked, since synthesizing tokens would break the byte-exact
``sum(seg.tokens) == in_tokens`` invariant."""

def init_turn_0(
self,
Expand Down Expand Up @@ -241,6 +256,7 @@ def init_turn_0(
self._segments = segs
self._emitted_segment_count = 0
self._last_disturbance_at = None
self._assert_trailing_user()

def advance_turn(
self,
Expand All @@ -251,6 +267,8 @@ def advance_turn(
curr_in_tokens: int,
seed: str,
is_tool_result: bool = False,
*,
max_asst_blocks: int | None = None,
) -> None:
"""Advance synth_buf to turn k via LCP-driven symmetric attribution.

Expand All @@ -276,25 +294,41 @@ def advance_turn(
hash-content invariant — every cached block emits its full content,
unmodified by any terminator stamp.

Wire invariant: the segment list always ends with a user segment. A
chat request cannot ask the model to continue from a trailing
assistant message, so when the assistant target would consume the
entire new region and no partial/missing tail remains to seed a
trailing user segment, the final new block is handed back to the user
(relabeled, not resized — byte-exactness holds). The only turns that
cannot satisfy this are recorded on ``_trailing_non_user_turns`` (see
:meth:`_assert_trailing_user`): a turn that appends zero new tokens has
nothing to relabel.

``prev_in_tokens`` is retained for the recorded-request contract
(mirroring ``prev_hash_ids`` / ``prev_out_tokens``) but no longer
feeds the truncation: the buffer self-describes its trailing
overhang, which is also correct when the boundary segment is not the
trailing one.

``max_asst_blocks`` is the optional two-pass planning cap from
:func:`compute_asst_block_caps`: an upper bound on this turn's assistant
block count chosen so a future block-aligned pull-back truncates onto a
user block instead of an assistant one. ``None`` means no cap. It only
ever lowers ``asst_blocks`` (an extra ``min``), so it composes with the
context-loss rule and the trailing-user reservation and never breaks the
byte-exact ``sum(seg.tokens) == curr_in_tokens`` invariant.
"""
bs = self.block_size
m_curr = len(curr_hash_ids)
m_curr_full = curr_in_tokens // bs
# Mirror init_turn_0's ``covered_blocks = min(m_full, len(hash_ids))``:
# a hashed-but-partial last block (len(curr_hash_ids) > curr_in_tokens
# // bs, e.g. in=250 with hash_ids=[..,4] at bs=64) covers only its
# ``curr_in_tokens % bs`` partial tail, not a full ``bs``-token block.
# Decoding it as a full block AND appending the partial tail below
# double-counts ~bs tokens and breaks the sum(seg.tokens) ==
# curr_in_tokens invariant the byte-exact contract depends on.
m_curr_covered = min(m_curr, m_curr_full)
missing_block_tokens = max(0, (m_curr_full - m_curr) * bs)
lcp = longest_common_prefix(prev_hash_ids, curr_hash_ids)
# Single source of truth for this turn's block accounting (covered
# blocks, LCP, new-region size, synth tail), shared with the Pass-1
# planner ``compute_asst_block_caps`` so the two never drift.
geo = compute_turn_block_geometry(
prev_hash_ids, curr_hash_ids, curr_in_tokens, bs
)
lcp = geo.lcp
m_curr_covered = geo.m_curr_covered
synth_tail_n = geo.synth_tail_n
new_blocks_count = geo.new_blocks_count

truncate_disturbance = truncate_synth_buf_at_block(
self._segments,
Expand All @@ -305,14 +339,11 @@ def advance_turn(
self._last_disturbance_at = truncate_disturbance

new_blocks = curr_hash_ids[lcp:m_curr_covered]
new_partial_tail_n = curr_in_tokens % bs
new_region_tokens = self.decode_block_tokens(new_blocks)
synth_tail_n = missing_block_tokens + new_partial_tail_n
if synth_tail_n > 0:
new_region_tokens.extend(
self.sample_partial_tail_tokens(synth_tail_n, seed)
)
new_blocks_count = max(0, m_curr_covered - lcp)

self._turn_index += 1
asst_blocks_target = (
Expand All @@ -325,6 +356,25 @@ def advance_turn(
# any user input. The whole new region becomes user content.
asst_blocks_target = 0
asst_blocks = min(asst_blocks_target, new_blocks_count)
if max_asst_blocks is not None:
# Two-pass planning cap (see compute_asst_block_caps): a future
# block-aligned pull-back will truncate exactly to a boundary inside
# this turn's region. Shrinking the assistant here so that boundary
# block lands in THIS turn's user segment makes that future turn end
# with a user block — decided up front and stable across every turn,
# so the rendered prefix never diverges and KV-cache reuse holds.
asst_blocks = min(asst_blocks, max_asst_blocks)
if synth_tail_n == 0 and asst_blocks == new_blocks_count and asst_blocks > 0:
# Wire invariant: every turn must end with a user message — a chat
# request cannot ask the model to continue from a trailing
# assistant message. When the assistant would consume the entire
# new region and no partial/missing tail remains to seed a trailing
# user segment, hand the final new block back to the user. This
# relabels block content (role only), so the byte-exact
# sum(seg.tokens) == curr_in_tokens invariant is preserved. When
# the region is a single block the assistant segment vanishes and
# the turn becomes user-only.
asst_blocks -= 1
asst_emit_size = asst_blocks * bs

cursor = lcp
Expand Down Expand Up @@ -361,6 +411,34 @@ def advance_turn(
)
)

self._assert_trailing_user()

def _assert_trailing_user(self) -> None:
    """Note any turn whose segment list does not finish on a user segment.

    Despite the name, nothing is asserted or raised. When the segment list
    is empty, or its last segment has role ``"user"``, this is a no-op.
    Otherwise the current ``_turn_index`` is appended to
    ``_trailing_non_user_turns`` and one warning naming the trailing role
    is logged for that call.

    The shapes expected to reach the warning are turns that contributed no
    new content to relabel as a trailing user block: a block-aligned
    pull-back, or a turn-0 prompt made up only of the system segment.
    Nothing is synthesized to paper over it, because adding tokens would
    break the byte-exact ISL accounting the warning text refers to.
    """
    segments = self._segments
    if not segments:
        # Nothing emitted yet, so there is no trailing role to check.
        return

    trailing_role = segments[-1].role
    if trailing_role == "user":
        return

    # Record first, then surface the caveat to the operator.
    self._trailing_non_user_turns.append(self._turn_index)
    _logger.warning(
        f"weka reconstructor: turn {self._turn_index} ends with a "
        f"'{trailing_role}' segment, not 'user'. The turn added no new "
        f"content to relabel as a trailing user block (block-aligned "
        f"pull-back or system-only turn-0); synthesizing one would break "
        f"the byte-exact ISL invariant, so it is left as-is."
    )

def turn_delta(self) -> TurnDelta:
"""Compute the raw_messages to emit for the just-completed turn.

Expand Down Expand Up @@ -435,6 +513,135 @@ def longest_common_prefix(prev_hash_ids: list[int], curr_hash_ids: list[int]) ->
return n


@dataclass(frozen=True)
class TurnBlockGeometry:
    """Immutable block arithmetic for a single turn transition.

    Every value is derived from hash ids and token counts only, never from
    how blocks are labeled by role. Both
    :meth:`ConversationReconstructor.advance_turn` and
    :func:`compute_asst_block_caps` read their numbers from this one
    record, which keeps the reconstructor and the planner in agreement.

    Attributes:
        lcp: Length, in blocks, of the longest common prefix of the
            previous and current hash ids. This is the block boundary the
            buffer is truncated back to before the new region is appended.
        m_curr_covered: Number of covered blocks,
            ``min(len(curr_hash_ids), curr_in_tokens // bs)``. A hashed
            but partial last block (``len > in // bs``) contributes only
            its tail, not a whole block.
        new_blocks_count: Full blocks appended after truncating to
            ``lcp``, i.e. ``max(0, m_curr_covered - lcp)``.
        synth_tail_n: Count of synthesized trailing tokens: the partial
            tail (``in % bs``) plus the missing-block region when fewer
            hash blocks were recorded than ``in // bs``.
    """

    lcp: int
    m_curr_covered: int
    new_blocks_count: int
    synth_tail_n: int


def compute_turn_block_geometry(
    prev_hash_ids: list[int],
    curr_hash_ids: list[int],
    curr_in_tokens: int,
    block_size: int,
) -> TurnBlockGeometry:
    """Derive the block counts for moving from the previous turn to this one.

    Args:
        prev_hash_ids: Block hash ids recorded for the previous turn.
        curr_hash_ids: Block hash ids recorded for the current turn.
        curr_in_tokens: Input token count of the current turn.
        block_size: Tokens per block.

    Returns:
        A :class:`TurnBlockGeometry` holding the common-prefix length, the
        covered block count, the number of full blocks appended past the
        prefix, and the number of tail tokens that must be synthesized.

    The covered count is capped at ``curr_in_tokens // block_size``. When
    more hash ids are recorded than the token count fills (for example
    ``in=250`` with a fourth hash id at ``block_size=64``), the last hashed
    block is only a partial tail. Treating it as a full block and also
    adding the ``curr_in_tokens % block_size`` tail would count roughly one
    block of tokens twice and break ``sum(seg.tokens) == curr_in_tokens``.
    """
    hashed_blocks = len(curr_hash_ids)
    whole_blocks, partial_tail = divmod(curr_in_tokens, block_size)

    covered = min(hashed_blocks, whole_blocks)
    # Tokens belonging to whole blocks that the recording has no hash for.
    unhashed_tokens = max(0, (whole_blocks - hashed_blocks) * block_size)
    shared_prefix = longest_common_prefix(prev_hash_ids, curr_hash_ids)
    appended_blocks = max(0, covered - shared_prefix)

    return TurnBlockGeometry(
        lcp=shared_prefix,
        m_curr_covered=covered,
        new_blocks_count=appended_blocks,
        synth_tail_n=partial_tail + unhashed_tokens,
    )


def compute_asst_block_caps(
turns: list[tuple[list[int], int]],
block_size: int,
) -> list[int | None]:
"""Pass-1 planner: per-turn upper bounds on assistant block count.

``turns`` is one ``(hash_ids, in_tokens)`` pair per turn, in order.

Returns ``caps`` where ``caps[k]`` bounds the assistant blocks
:meth:`ConversationReconstructor.advance_turn` may attribute at turn ``k``
(``None`` = no constraint, including turn 0). The bounds make every future
block-aligned pull-back truncate onto a user block rather than an assistant
one, so each such turn ends with a user message.

Pure and role-independent: LCP, the per-turn covered-block region, and the
block tile depend only on ``hash_ids`` / ``in_tokens``, never on how blocks
are later labeled. So the caps can be computed once up front and applied at
the turn that *creates* each block — the boundary is fixed on first emission
and never relabeled, preserving cross-turn KV-cache reuse.

Mirrors ``advance_turn`` + ``truncate_synth_buf_at_block``: replays the
block tile (owning turn per block) via ``compute_turn_block_geometry``,
clamping ``eff_lcp = min(lcp, len(tile))`` since truncate can't grow the
buffer. A turn that appends nothing (``new_blocks_count == 0`` and
``synth_tail_n == 0``) truncates onto block ``T-1`` (``T = eff_lcp``); its
owning turn ``j = tile[T-1]`` must keep that block in its user segment, so
``cap_j = min((T-1) - eff_lcp_j)``. Turn-0 owners are skipped (no assistant).
"""
n_turns = len(turns)
caps: list[int | None] = [None] * n_turns
if n_turns == 0:
return caps

bs = block_size
tile: list[int] = []
eff_lcp_per_turn: list[int] = [0] * n_turns

for k in range(n_turns):
hash_ids, in_tokens = turns[k]
prev_hash_ids = turns[k - 1][0] if k else []
# Same geometry function advance_turn uses, so the two never drift.
geo = compute_turn_block_geometry(prev_hash_ids, hash_ids, in_tokens, bs)
if k == 0:
eff_lcp_per_turn[0] = 0
tile = [0] * geo.m_curr_covered
continue

# eff_lcp clamps the raw LCP to the prior buffer's real block count:
# truncate is a no-op past the buffer end and cannot grow the tile.
eff_lcp = min(geo.lcp, len(tile))
eff_lcp_per_turn[k] = eff_lcp
new_blocks_count = max(0, geo.m_curr_covered - eff_lcp)
synth_tail_n = geo.synth_tail_n

if new_blocks_count == 0 and synth_tail_n == 0:
# Degenerate pull-back: truncate to T=eff_lcp exposes block T-1.
# target == eff_lcp <= len(tile) by construction, so the only guard
# needed is target >= 1 (a turn-0-boundary pull-back has nothing to
# cap).
target = eff_lcp
if target >= 1:
owner = tile[target - 1]
if owner != 0:
bound = (target - 1) - eff_lcp_per_turn[owner]
if bound >= 0:
prev = caps[owner]
caps[owner] = bound if prev is None else min(prev, bound)

# Mutate the tile suffix in place (O(new)) instead of rebuilding it.
del tile[eff_lcp:]
tile.extend([k] * new_blocks_count)

return caps


def truncate_synth_buf_at_block(
segments: list[RoleSegment],
target_blocks: int,
Expand Down
Loading
Loading