[WIP] [WS1] Add tests to ensure kv cache consistency for prefill & decode#178
zhangj1an wants to merge 2 commits into
Signed-off-by: Zhang Jian <jianmusings@gmail.com>
📝 Walkthrough
Changes: KV-cache path-consistency harness and tests
Sequence Diagram(s)
sequenceDiagram
participant Test
participant TinyCausalLM
participant fixed_order_attention
participant replay_decode
participant KVCache
participant attend_single_query
rect rgba(70, 130, 180, 0.5)
Note over Test,attend_single_query: Prefill path
Test->>TinyCausalLM: prefill_logits(input_ids, attention_mask)
TinyCausalLM->>fixed_order_attention: q, k, v, key_mask
loop t = 0..T-1
fixed_order_attention->>attend_single_query: q_t, k[0..t], v[0..t]
attend_single_query-->>fixed_order_attention: out_t
end
fixed_order_attention-->>TinyCausalLM: attn_out [B,T,H,D]
TinyCausalLM-->>Test: logits [B,T,V]
end
rect rgba(60, 160, 100, 0.5)
Note over Test,attend_single_query: Decode path (replay)
Test->>TinyCausalLM: decode_logits(input_ids, kv_dtype)
TinyCausalLM->>replay_decode: q, k, v, kv_dtype
replay_decode->>KVCache: init(dtype=kv_dtype)
loop t = 0..T-1
replay_decode->>KVCache: append(k_t, v_t)
KVCache-->>replay_decode: context k[0..t], v[0..t]
replay_decode->>attend_single_query: q_t, cached_k, cached_v
attend_single_query-->>replay_decode: out_t
end
replay_decode-->>TinyCausalLM: attn_out [B,T,H,D]
TinyCausalLM-->>Test: logits [B,T,V]
end
Test->>Test: assert_path_parity(decode_logits, prefill_logits)
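The two paths in the diagram can be sketched in a few lines. This is a pure-Python stand-in, not the PR's implementation: it borrows the names `attend_single_query` and `replay_decode` from the diagram, but their bodies, the list-based cache, and the toy inputs are assumptions made for illustration. Because both paths call the same per-query routine over the same keys and values in the same order, the outputs match exactly here.

```python
import math


def attend_single_query(q, keys, values):
    """Softmax attention for one query over a list of key/value vectors."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def prefill(qs, ks, vs):
    """Causal prefill: position t attends to k[0..t], v[0..t]."""
    return [attend_single_query(qs[t], ks[: t + 1], vs[: t + 1]) for t in range(len(qs))]


def replay_decode(qs, ks, vs):
    """Decode replay: append one key/value per step, then attend over the cache."""
    cache_k, cache_v, outs = [], [], []  # hypothetical stand-in for KVCache
    for q_t, k_t, v_t in zip(qs, ks, vs):
        cache_k.append(k_t)
        cache_v.append(v_t)
        outs.append(attend_single_query(q_t, cache_k, cache_v))
    return outs


# Small deterministic inputs: T = 4 positions, D = 2 dimensions.
qs = [[0.1 * (t + 1), -0.2 * t] for t in range(4)]
ks = [[0.3 * t, 0.5 - 0.1 * t] for t in range(4)]
vs = [[float(t), 1.0 - t] for t in range(4)]

prefill_out = prefill(qs, ks, vs)
decode_out = replay_decode(qs, ks, vs)
print(prefill_out == decode_out)
```

With a lower-precision cache dtype (the `kv_dtype` argument in the diagram), exact equality would no longer be expected, which is where a tolerance-based comparison comes in.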
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/test_kv_cache_consistency.py`:
- Line 39: Remove the global torch.manual_seed(0) call at line 39 that sets RNG
state for the entire pytest process and creates order-dependent test behavior.
Instead, apply torch.manual_seed() locally only where determinism is actually
needed (such as around line 149 for prompt sampling operations). This keeps RNG
seeding scoped to specific test operations rather than mutating global state
that can affect unrelated tests.
📒 Files selected for processing (2)
rl_engine/testing/kv_consistency.py
tests/test_kv_cache_consistency.py
)
from rl_engine.testing.reference_ops import selected_logprobs_reference


torch.manual_seed(0)
Localize RNG seeding to avoid cross-test state coupling.
Line 39 sets global RNG state for the entire pytest process. That can create order-dependent behavior in unrelated tests. Keep determinism local (e.g., for Line 149 prompt sampling) instead of mutating global state.
Suggested change
- torch.manual_seed(0)
@@
 def test_generate_then_rescore_equivalence():
     model = TinyCausalLM(vocab_size=64, d_model=48, spec=SPEC, seed=1)
-    prompt = torch.randint(0, 64, (2, 5))
+    gen = torch.Generator().manual_seed(0)
+    prompt = torch.randint(0, 64, (2, 5), generator=gen)
Also applies to: 149-149
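A quick way to confirm that the suggested change keeps seeding local is to compare the global RNG state before and after sampling. The sketch below assumes CPU tensors; the helper name `sample_prompt` is hypothetical and not part of the PR.

```python
import torch


def sample_prompt(seed: int = 0) -> torch.Tensor:
    # A private generator makes the draw reproducible without touching
    # the process-wide default RNG.
    gen = torch.Generator().manual_seed(seed)
    return torch.randint(0, 64, (2, 5), generator=gen)


state_before = torch.get_rng_state()
first = sample_prompt()
second = sample_prompt()
state_after = torch.get_rng_state()

# Same seed gives the same prompt, and the global state is unchanged.
same_prompt = torch.equal(first, second)
global_untouched = torch.equal(state_before, state_after)
print(same_prompt, global_untouched)
```

Since the global state is left alone, other tests in the same pytest process see the same random stream regardless of whether this test ran first.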
Signed-off-by: Zhang Jian <jianmusings@gmail.com>
Latest Status [22 June]
The code is mostly done. I will write up the PR description and mark the PR as ready for review tomorrow.
Purpose
Closes #152.
Add unit tests to ensure that, when the prefill path and the decode path use the same kernel, their outputs agree within an error threshold over a long sequence of tokens.
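An error-threshold check of this kind can be sketched as follows. Everything here is an assumption made for illustration: the `ParityReport` fields, the `compare_paths` helper, and the `atol`/`rtol` defaults are hypothetical and may differ from what the PR implements.

```python
from dataclasses import dataclass


@dataclass
class ParityReport:
    # Hypothetical fields; the PR's report class may track different data.
    max_abs_err: float
    worst_step: int
    passed: bool


def compare_paths(decode, prefill, atol=1e-6, rtol=1e-5):
    """Compare per-step outputs; a value passes if |d - p| <= atol + rtol * |p|."""
    max_err, worst, ok = 0.0, -1, True
    for step, (d_row, p_row) in enumerate(zip(decode, prefill)):
        for d, p in zip(d_row, p_row):
            err = abs(d - p)
            if err > max_err:
                max_err, worst = err, step  # remember where the drift peaks
            if err > atol + rtol * abs(p):
                ok = False
    return ParityReport(max_err, worst, ok)


# Toy per-step outputs for three decode steps.
prefill_logits = [[0.10, 0.20], [0.30, 0.40], [0.50, 0.60]]
decode_ok = [[0.10, 0.20], [0.30, 0.40], [0.50, 0.60]]
decode_bad = [[0.10, 0.20], [0.30, 0.45], [0.50, 0.60]]

report_ok = compare_paths(decode_ok, prefill_logits)
report_bad = compare_paths(decode_bad, prefill_logits)
print(report_ok.passed, report_bad.passed, report_bad.worst_step)
```

Recording the worst step alongside the maximum error makes it easier to see whether drift grows with sequence length or comes from a single position.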
Open questions
Should the KV Cache and Parity Report classes stay within `rl_engine/testing/kv_consistency.py`?
Key Changes
Summary by CodeRabbit