Skip to content

[WIP] [WS1] Add tests to ensure kv cache consistency for prefill & decode#178

Draft
zhangj1an wants to merge 2 commits into
RL-Align:mainfrom
zhangj1an:jian/kv_cache_consistency
Draft

[WIP] [WS1] Add tests to ensure kv cache consistency for prefill & decode#178
zhangj1an wants to merge 2 commits into
RL-Align:mainfrom
zhangj1an:jian/kv_cache_consistency

Conversation

@zhangj1an

@zhangj1an zhangj1an commented Jun 22, 2026

Copy link
Copy Markdown

Latest Status [22 June]

Code is mostly done, will write up PR description and make as ready tomorrow

Purpose

Closes #152.

Write a few unit tests to ensure that, given a prefill path and decode path, as long as they are using the same kernel, their output is within error threshold over a long sequence of tokens.

Open questions

  1. should the KV Cache and Parity Report class stay within rl_engine/testing/kv_consistency.py?
  2. should the offline rollout feature [FEAT][executors]: implement basic vLLM worker for offline rollout #10 be implemented first?

Key Changes

Summary by CodeRabbit

  • Tests
    • Added comprehensive KV-cache consistency testing infrastructure to verify attention implementations across different execution paths (prefill and decode modes).
    • Added test suite validating KV-cache behavior under various scenarios including padded sequences, dtype variations, and generation workflows.

Signed-off-by: Zhang Jian <jianmusings@gmail.com>
@coderabbitai

coderabbitai Bot commented Jun 22, 2026

Copy link
Copy Markdown

Review Change Stack

📝 Walkthrough

Walkthrough

Adds rl_engine/testing/kv_consistency.py, a KV-cache path-consistency harness built around a shared fixed-order attention reduction contract (AttentionSpec, attend_single_query, KVCache, fixed_order_attention, replay_decode), parity utilities, and a deterministic TinyCausalLM. Adds tests/test_kv_cache_consistency.py with eight tests covering all five equivalence claims from issue #152.

Changes

KV-cache path-consistency harness and tests

Layer / File(s) Summary
AttentionSpec contract and canonical single-query attention primitive
rl_engine/testing/kv_consistency.py
Frozen AttentionSpec dataclass (validation, scale, GQA group), expand_kv_heads helper, and attend_single_query — the shared fp32 attention primitive with optional key masking, NaN guarding, and ascending-order softmax.
KVCache storage and fixed_order_attention prefill path
rl_engine/testing/kv_consistency.py
KVCache preallocated buffer with typed append/context retrieval; fixed_order_attention iterates query positions calling attend_single_query over ascending K/V prefixes with causal and masking support.
replay_decode decode path with optional kv_dtype drift
rl_engine/testing/kv_consistency.py
replay_decode builds a KVCache, appends per-timestep K/V with optional dtype casting, and calls attend_single_query over the growing context to mirror prefill reduction order; includes _check_qkv tensor validation.
Parity utilities: ParityReport and assert_path_parity
rl_engine/testing/kv_consistency.py
ParityReport dataclass, parity_report (bitwise equality + abs-error statistics), and assert_path_parity (bitwise assertion or torch.testing.assert_close with configurable tolerances).
TinyCausalLM end-to-end deterministic test model
rl_engine/testing/kv_consistency.py
Single-layer causal LM with seed-initialized embeddings and Q/K/V projections; exposes prefill_logits, decode_logits (with optional kv_dtype), and generate (greedy decode returning token ids and per-step logprobs).
KV-cache consistency test suite (5 PR claims)
tests/test_kv_cache_consistency.py
Shared SPEC/helpers and eight tests: full-vs-chunked-prefill parity, naive SDPA mismatch boundary, decode-vs-prefill parity (batched + padded), matched/low-precision stored-KV dtype parity, generate-then-rescore equivalence, and parameterized smoke coverage over short/long/varlen/padded sequences.

Sequence Diagram(s)

sequenceDiagram
  participant Test
  participant TinyCausalLM
  participant fixed_order_attention
  participant replay_decode
  participant KVCache
  participant attend_single_query

  rect rgba(70, 130, 180, 0.5)
    Note over Test,attend_single_query: Prefill path
    Test->>TinyCausalLM: prefill_logits(input_ids, attention_mask)
    TinyCausalLM->>fixed_order_attention: q, k, v, key_mask
    loop t = 0..T-1
      fixed_order_attention->>attend_single_query: q_t, k[0..t], v[0..t]
      attend_single_query-->>fixed_order_attention: out_t
    end
    fixed_order_attention-->>TinyCausalLM: attn_out [B,T,H,D]
    TinyCausalLM-->>Test: logits [B,T,V]
  end

  rect rgba(60, 160, 100, 0.5)
    Note over Test,attend_single_query: Decode path (replay)
    Test->>TinyCausalLM: decode_logits(input_ids, kv_dtype)
    TinyCausalLM->>replay_decode: q, k, v, kv_dtype
    replay_decode->>KVCache: init(dtype=kv_dtype)
    loop t = 0..T-1
      replay_decode->>KVCache: append(k_t, v_t)
      KVCache-->>replay_decode: context k[0..t], v[0..t]
      replay_decode->>attend_single_query: q_t, cached_k, cached_v
      attend_single_query-->>replay_decode: out_t
    end
    replay_decode-->>TinyCausalLM: attn_out [B,T,H,D]
    TinyCausalLM-->>Test: logits [B,T,V]
  end

  Test->>Test: assert_path_parity(decode_logits, prefill_logits)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Poem

🐇 Hoppity-hop through prefill and decode,
The cache and the contract share one fixed road.
Each query attends in ascending key order,
No drift between writer and reader — no border!
Bitwise they match, or within tolerance fall,
The rabbit has tested and verified all. ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 38.71% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly indicates this PR adds KV cache consistency tests for prefill and decode paths, directly matching the main changes which are the testing harness and test cases for this feature.
Linked Issues check ✅ Passed The PR implements all major objectives from issue #152: fixed-order attention contract [1], prefill vs chunked-prefill tests [2], decode vs prefill parity with masking [3], stored-KV dtype handling [4], generate-then-rescore validation [5], and CI smoke tests across sequence shapes [6].
Out of Scope Changes check ✅ Passed All changes are directly scoped to KV-cache consistency testing: the testing harness in rl_engine/testing/kv_consistency.py and tests in tests/test_kv_cache_consistency.py are both core to issue #152 objectives, with no unrelated modifications.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@zhangj1an zhangj1an marked this pull request as draft June 22, 2026 15:21

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/test_kv_cache_consistency.py`:
- Line 39: Remove the global torch.manual_seed(0) call at line 39 that sets RNG
state for the entire pytest process and creates order-dependent test behavior.
Instead, apply torch.manual_seed() locally only where determinism is actually
needed (such as around line 149 for prompt sampling operations). This keeps RNG
seeding scoped to specific test operations rather than mutating global state
that can affect unrelated tests.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1bdc7bce-5b02-4472-8894-ce49f81438ab

📥 Commits

Reviewing files that changed from the base of the PR and between 51b8b21 and 87064b3.

📒 Files selected for processing (2)
  • rl_engine/testing/kv_consistency.py
  • tests/test_kv_cache_consistency.py

)
from rl_engine.testing.reference_ops import selected_logprobs_reference

torch.manual_seed(0)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Localize RNG seeding to avoid cross-test state coupling.

Line 39 sets global RNG state for the entire pytest process. That can create order-dependent behavior in unrelated tests. Keep determinism local (e.g., for Line 149 prompt sampling) instead of mutating global state.

Suggested change
- torch.manual_seed(0)
@@
 def test_generate_then_rescore_equivalence():
     model = TinyCausalLM(vocab_size=64, d_model=48, spec=SPEC, seed=1)
-    prompt = torch.randint(0, 64, (2, 5))
+    gen = torch.Generator().manual_seed(0)
+    prompt = torch.randint(0, 64, (2, 5), generator=gen)

Also applies to: 149-149

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_kv_cache_consistency.py` at line 39, Remove the global
torch.manual_seed(0) call at line 39 that sets RNG state for the entire pytest
process and creates order-dependent test behavior. Instead, apply
torch.manual_seed() locally only where determinism is actually needed (such as
around line 149 for prompt sampling operations). This keeps RNG seeding scoped
to specific test operations rather than mutating global state that can affect
unrelated tests.

Signed-off-by: Zhang Jian <jianmusings@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[WS1] KV-cache path consistency (prefill & decode)

1 participant