
Commit 43b4148

unamedkr and claude committed
docs: handoff brief for batched prefill debugging

Documents the state of the batched prefill effort:

- Primitive (tq_batched_matmul_q4) verified correct and fast (12/12, 1.5-3×)
- Microbench validates 30-100× ceiling via Apple AMX (cblas_sgemm @ 1.2 TFLOPS)
- tq_forward_batch integration scaffolded but numerically diverges
- Specific debug plan + likely-suspects ranking for next session
- Architectural targets and out-of-scope guardrails

Reading this brief in the next session should give a 1-hour head start vs re-deriving the strategy.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 174f891 commit 43b4148

1 file changed

Lines changed: 113 additions & 0 deletions

# Batched Prefill — Implementation Handoff (2026-04-15)

## Status

- **Strategy**: documented in `bench/results/2026-04-15_accelerate_gemm_microbench.md`.
  Apple AMX delivers 1.2 TFLOPS via `cblas_sgemm`; the GEMV path peaks at ~15 GFLOPS.
  The 100× speedup is real but requires the workload to be batched (see the
  `cblas_sgemm` sketch after this list).
- **Primitive**: `tq_batched_matmul_q4()` in `src/engine/tq_ops.c`.
  Unit-tested in `tools/test_batched_matmul.c` — 12/12 PASS, max_rel=0.0000,
  observed speedups 1.2-3.0× across realistic shapes (Phi-3.5, Llama 3.x).
- **Integration scaffolding**: `tq_forward_batch()` in `src/engine/tq_transformer.c`,
  plus the opt-in flag `TQ_BATCH_PREFILL=1` in `src/engine/tq_generate.c`.
  Compiles, runs, and falls back gracefully on unsupported architectures.
- **Numerical correctness of `tq_forward_batch`**: end-to-end output diverges
  from the baseline. The matmul primitive is bit-identical (verified at the
  primitive level), so the bug is somewhere in the surrounding orchestration
  (state, RoPE, KV cache layout, residual flow, or embedding source).
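
What the microbench measures, in essence, is one `cblas_sgemm` call over an N-token activation block instead of N separate GEMVs, which is what lets AMX reach its throughput ceiling. A minimal illustration (the shapes and function name here are mine, not the benchmark's code):

```c
#include <Accelerate/Accelerate.h>

/* Y (N x M) = X (N x K) * W (K x M): one GEMM over N batched tokens.
 * For N = 1 this degenerates to the ~15 GFLOPS GEMV case. */
static void batched_proj_f32(const float *X, const float *W, float *Y,
                             int N, int K, int M) {
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                N, M, K, 1.0f, X, K, W, M, 0.0f, Y, M);
}
```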

Reproduce divergence:
```bash
DYLD_LIBRARY_PATH=build TQ_BATCH_PREFILL=1 \
  build/quant models/Llama-3.2-1B-Instruct-Q8_0.gguf -p "Hello world" -n 5 -T 0
# prints: " hell hel hell h hel"

DYLD_LIBRARY_PATH=build \
  build/quant models/Llama-3.2-1B-Instruct-Q8_0.gguf -p "Hello world" -n 5 -T 0
# prints: " I'm so excited" ← baseline (correct)
```

## Debugging plan for next session

1. **Add intermediate-state dumps** to `tq_forward` and `tq_forward_batch`.
   Compare layer-0 outputs (Xres after attention residual) byte-by-byte
   for the same single token. If they differ, the bug is at layer 0
   before any batching matters.
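
   A dump-to-file helper is enough to make the comparison mechanical. This is a minimal sketch under assumed names (`dump_f32` and the buffer arguments are illustrative, not the project's existing API):

   ```c
   #include <stdio.h>

   /* Write a float buffer to disk so the layer-0 outputs of tq_forward and
    * tq_forward_batch can be diffed byte-by-byte (e.g. with `cmp`). */
   static void dump_f32(const char *path, const float *buf, size_t n) {
       FILE *f = fopen(path, "wb");
       if (!f) return;
       fwrite(buf, sizeof(float), n, f);
       fclose(f);
   }

   /* Usage sketch (buffer names illustrative):
    *   dump_f32("/tmp/l0_xres_single.bin", s->x,           dim);  // tq_forward
    *   dump_f32("/tmp/l0_xres_batch.bin",  Xres + 0 * dim, dim);  // tq_forward_batch
    *   $ cmp /tmp/l0_xres_single.bin /tmp/l0_xres_batch.bin
    */
   ```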

2. **Likely suspects ranked by probability**:

   **(a) RMSNorm input vs output buffer.** My code calls
   `tq_rmsnorm(XBN+n*dim, Xres+n*dim, ...)`. tq_forward calls
   `tq_rmsnorm(s->xb, s->x, ...)`. The pattern is the same, but verify
   the eps value is exactly `c->rms_norm_eps` and the weight pointer is
   `layer->attn_norm` (not `layer->ffn_norm`).

   **(b) KV cache stride.** Forward uses `cache_kv_dim` (computed via the
   sliding/full max). For Llama 3.x (non-Gemma) this should equal `kv_dim`,
   but it is worth printing both at write time to confirm.

   **(c) attn_output_gate.** Llama doesn't have it, but verify
   `c->attn_output_gate == 0`.

   **(d) Output deinterleave.** When attn_output_gate is set, Q is
   interleaved with a gate. We don't handle this in tq_forward_batch
   because we bail when the flag is set... but is the bail check there?
   (Currently no — it should be added.)
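
   A quick pre-flight check covering (a)-(d) would make these easy to rule out. Sketch only; the function is hypothetical and the parameters mirror the fields named above:

   ```c
   #include <stdio.h>

   /* Print the values suspects (a)-(c) say to verify, and report whether the
    * batched path must bail because of (d) (no gate deinterleave yet).
    * Returns 0 when the caller should fall back to the per-token path. */
   static int batch_preflight(int attn_output_gate, int kv_dim,
                              int cache_kv_dim, float rms_norm_eps) {
       fprintf(stderr, "attn_output_gate=%d kv_dim=%d cache_kv_dim=%d eps=%g\n",
               attn_output_gate, kv_dim, cache_kv_dim, rms_norm_eps);
       return attn_output_gate == 0;
   }
   ```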

3. **RoPE freq formula.** My batched code:
   ```c
   float base = 1.0f / powf(c->rope_freq_base, 2.0f * i / (float)c->head_dim);
   float freq = base / model->rope_freqs[i];
   ```
   Compare to tq_forward lines 1217-1219:
   ```c
   float base_freq = 1.0f / powf(rope_base, 2.0f * i / (float)rope_n_dims);
   float freq = base_freq / model->rope_freqs[i];
   ```
   Note: `rope_n_dims` may not equal `head_dim`! For Gemma 4 they differ.
   For Llama 3 they should be the same, but verify `c->rope_n_dims` and
   use it instead of `head_dim`.
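
   A minimal sketch of the corrected frequency computation for the batched path, keyed on `rope_n_dims` to mirror tq_forward (only the function wrapper is illustrative; the formula is the one quoted above):

   ```c
   #include <math.h>

   /* Per-pair RoPE frequency derived from rope_n_dims, not head_dim. */
   static float rope_pair_freq(float rope_freq_base, int rope_n_dims, int i,
                               const float *rope_freqs) {
       float base = 1.0f / powf(rope_freq_base, 2.0f * i / (float)rope_n_dims);
       return base / rope_freqs[i];
   }
   ```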

4. **The attention computation.** Mine uses the simplest causal scan over
   K/V cache positions. tq_forward uses the same but might apply scaling
   factors (logit_softcap for Gemma, attention_bias for some, etc.).
   Llama 3 should be plain — verify no scale factor is missed.
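
   For reference, "plain" here means nothing beyond the 1/sqrt(head_dim) scale. A self-contained sketch of the score pass for one query head (all names and the cache layout are illustrative, not the engine's actual state):

   ```c
   #include <math.h>
   #include <stddef.h>

   /* Causal attention scores for one head: dot(q, k_t) / sqrt(head_dim) for
    * t = 0..pos, with no softcap and no bias. Softmax and the V-weighted sum
    * follow separately. */
   static void causal_scores(const float *q, const float *k_cache, float *att,
                             int pos, int head_dim, int kv_stride) {
       const float scale = 1.0f / sqrtf((float)head_dim);
       for (int t = 0; t <= pos; t++) {
           const float *k = k_cache + (size_t)t * kv_stride;
           float s = 0.0f;
           for (int i = 0; i < head_dim; i++) s += q[i] * k[i];
           att[t] = s * scale;
       }
   }
   ```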

## Architectural targets

Once correctness is achieved, expected gains (per microbench):
- N=8 prefill chunk: ~3× per-matmul vs single
- N=32 prefill chunk: ~30× per-matmul vs single
- N=128 prefill chunk: ~60-100× per-matmul vs single

The `tq_batched_matmul_q4` primitive currently gives 1.5-3× because it
works over Q4 weights and dequant overhead caps the win below the AMX
ceiling. Future optimizations (in priority order):

1. **Persistent FP16 lm_head** — ~1.5× decode + huge prefill win for Qwen3.5-4B
2. **BNNS quantized GEMM** — direct AMX with int8 weights, no dequant overhead
3. **MPSGraph for multi-layer fused forward** — entire layer on GPU

## Out-of-scope for this implementation

Don't try to make `tq_forward_batch` handle every architecture in v1. Bail on
the following (a support-check sketch follows at the end of this section):

- `is_gemma4`, `is_moe`, `has_fused_qkv`, `has_fused_up_gate`
- `n_kv_shared_layers > 0`
- Any layer with `delta_a_log` (DeltaNet)
- `attn_output_gate` (rare; just bail)
- `partial_rotary_factor > 0` (Phi-3 LongRoPE)

The fast path should cover Llama 1B/3B/8B Q8_0 and Q4_K_M with load-time
Q4 conversion. Other models stay on the per-token path. We can extend
case-by-case.
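
A minimal sketch of what the v1 support gate could look like, matching the bail list above (the struct and function names are assumptions; only the field names come from this document):

```c
#include <stdbool.h>

/* Illustrative config fields only — the real struct lives in the engine. */
struct tq_config_view {
    int is_gemma4, is_moe, has_fused_qkv, has_fused_up_gate;
    int n_kv_shared_layers, attn_output_gate;
    float partial_rotary_factor;
};

/* Hypothetical v1 gate: return false for anything the batched prefill path
 * does not handle, so callers fall back to the per-token path. */
static bool tq_batch_prefill_supported(const struct tq_config_view *c) {
    if (c->is_gemma4 || c->is_moe || c->has_fused_qkv || c->has_fused_up_gate)
        return false;                                    /* fused/MoE/Gemma-4 paths */
    if (c->n_kv_shared_layers > 0) return false;         /* shared KV layers */
    if (c->attn_output_gate) return false;               /* no gate deinterleave */
    if (c->partial_rotary_factor > 0) return false;      /* Phi-3 LongRoPE */
    /* DeltaNet layers (delta_a_log) would need a per-layer check as well. */
    return true;
}
```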

## Files touched

- `include/turboquant/tq_engine.h` — declared `tq_forward_batch` and
  `tq_batched_matmul_q4`.
- `src/engine/tq_ops.c` — implemented `tq_batched_matmul_q4` and worker.
- `src/engine/tq_transformer.c` — implemented `tq_forward_batch` (WIP).
- `src/engine/tq_generate.c` — gated integration behind TQ_BATCH_PREFILL.
- `tools/test_batched_matmul.c` — primitive correctness + speed test.
- `bench/results/2026-04-15_accelerate_gemm_microbench.md` — strategy doc.
