# Batched Prefill — Implementation Handoff (2026-04-15)

## Status

- ✅ **Strategy**: documented in `bench/results/2026-04-15_accelerate_gemm_microbench.md`.
  Apple AMX delivers 1.2 TFLOPS via `cblas_sgemm`; the GEMV path peaks at ~15 GFLOPS.
  The 100× speedup is real, but it requires the workload to be batched.
- ✅ **Primitive**: `tq_batched_matmul_q4()` in `src/engine/tq_ops.c`.
  Unit-tested in `tools/test_batched_matmul.c` — 12/12 PASS, max_rel=0.0000,
  observed speedups 1.2-3.0× across realistic shapes (Phi-3.5, Llama 3.x).
- ✅ **Integration scaffolding**: `tq_forward_batch()` in `src/engine/tq_transformer.c`
  plus an opt-in flag `TQ_BATCH_PREFILL=1` in `src/engine/tq_generate.c`.
  Compiles, runs, and falls back gracefully on unsupported architectures.
- ❌ **Numerical correctness of `tq_forward_batch`**: end-to-end output diverges
  from the baseline. The matmul primitive is bit-identical (verified at the
  primitive level), so the bug is somewhere in the surrounding orchestration
  (state, RoPE, KV cache layout, residual flow, or embedding source).

Reproduce the divergence:
```bash
DYLD_LIBRARY_PATH=build TQ_BATCH_PREFILL=1 \
  build/quant models/Llama-3.2-1B-Instruct-Q8_0.gguf -p "Hello world" -n 5 -T 0
# prints: " hell hel hell h hel"

DYLD_LIBRARY_PATH=build \
  build/quant models/Llama-3.2-1B-Instruct-Q8_0.gguf -p "Hello world" -n 5 -T 0
# prints: " I'm so excited" ← baseline (correct)
```

## Debugging plan for next session

1. **Add intermediate-state dumps** to `tq_forward` and `tq_forward_batch`.
   Compare layer-0 outputs (Xres after the attention residual) byte-by-byte
   for the same single token. If they differ, the bug is at layer 0,
   before any batching matters.
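
   A minimal dump-helper sketch (the `dump_f32` name and the exact call
   sites are hypothetical; wire it to whatever buffers `tq_transformer.c`
   actually uses):
   ```c
   #include <stdio.h>

   /* Dump the first n floats of a buffer to stderr under a tag so the
      batched and single-token runs can be diffed line-by-line. */
   static void dump_f32(const char *tag, const float *buf, int n) {
       fprintf(stderr, "%s:", tag);
       for (int i = 0; i < n; i++) fprintf(stderr, " %.8e", buf[i]);
       fprintf(stderr, "\n");
   }

   /* In tq_forward:        dump_f32("L0.xres.single", s->x, 16); */
   /* In tq_forward_batch:  dump_f32("L0.xres.batch",  Xres, 16); */
   ```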

2. **Likely suspects, ranked by probability**:

   **(a) RMSNorm input vs output buffer.** My code calls
   `tq_rmsnorm(XBN+n*dim, Xres+n*dim, ...)`; `tq_forward` calls
   `tq_rmsnorm(s->xb, s->x, ...)`. The pattern is the same, but verify
   that the eps value is exactly `c->rms_norm_eps` and that the weight
   pointer is `layer->attn_norm` (not `layer->ffn_norm`).
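
   A quick sanity print, as a sketch (the locals `eps` and `w` are
   hypothetical stand-ins for whatever the batched call actually passes;
   assumes `<stdio.h>`):
   ```c
   /* Print both sides once per call site and diff the two runs' output. */
   fprintf(stderr, "rmsnorm eps=%.9g (cfg %.9g) w=%p attn=%p ffn=%p\n",
           (double)eps, (double)c->rms_norm_eps, (void *)w,
           (void *)layer->attn_norm, (void *)layer->ffn_norm);
   ```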

   **(b) KV cache stride.** `tq_forward` uses `cache_kv_dim` (computed via
   the sliding/full max). For Llama 3.x (non-Gemma) this should equal
   `kv_dim`, but it's worth printing both at write time to confirm.
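
   A one-shot print at KV write time, as a sketch (the `kv_logged` guard
   is hypothetical):
   ```c
   /* Log the strides exactly once, at the first KV cache write. */
   static int kv_logged = 0;
   if (!kv_logged) {
       fprintf(stderr, "kv_dim=%d cache_kv_dim=%d\n", kv_dim, cache_kv_dim);
       kv_logged = 1;
   }
   ```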

   **(c) attn_output_gate.** Llama doesn't have it, but verify
   `c->attn_output_gate == 0`.

   **(d) Output deinterleave.** When `attn_output_gate` is set, Q is
   interleaved with a gate. We don't handle this in `tq_forward_batch`
   because we bail when the flag is set... except the bail check isn't
   actually there yet; it needs to be added (see the sketch below).
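
   A minimal sketch of the missing guard (the return value is a
   placeholder; use whatever signal the caller already treats as
   "fall back to the per-token path"):
   ```c
   /* tq_forward_batch, early: bail on interleaved Q+gate layouts. */
   if (c->attn_output_gate) return -1;
   ```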

3. **RoPE freq formula.** My batched code:
   ```c
   float base = 1.0f / powf(c->rope_freq_base, 2.0f * i / (float)c->head_dim);
   float freq = base / model->rope_freqs[i];
   ```
   Compare to `tq_forward` lines 1217-1219:
   ```c
   float base_freq = 1.0f / powf(rope_base, 2.0f * i / (float)rope_n_dims);
   float freq = base_freq / model->rope_freqs[i];
   ```
   Note: `rope_n_dims` may not equal `head_dim`! For Gemma 4 they differ.
   For Llama 3 they should be the same, but verify `c->rope_n_dims` and use
   it instead of `head_dim`.
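
   A sketch of the corrected batched lines (assuming `c->rope_n_dims` is
   populated for every supported architecture):
   ```c
   /* Use the model's rotary dimension count, not head_dim, so the
      batched path computes the same frequencies as tq_forward. */
   float base = 1.0f / powf(c->rope_freq_base,
                            2.0f * i / (float)c->rope_n_dims);
   float freq = base / model->rope_freqs[i];
   ```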

4. **The attention computation.** Mine uses the simplest causal scan over
   the K/V cache positions. `tq_forward` uses the same, but it may apply
   scaling factors (`logit_softcap` for Gemma, `attention_bias` for some
   models, etc.). Llama 3 should be plain — verify no scale factor is
   missed. A reference sketch of the plain scan follows.
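
   A minimal sketch of the plain causal scan for one query head (the
   helper, buffer names, and strides are hypothetical; no softcap, no
   bias, i.e. the Llama 3 case):
   ```c
   #include <math.h>

   /* q: [head_dim]; K/V rows live at stride kv_stride in the cache;
    * scores: scratch of at least pos+1 floats; out: [head_dim]. */
   static void attend_head(const float *q, const float *K, const float *V,
                           float *out, float *scores,
                           int pos, int head_dim, int kv_stride) {
       float scale = 1.0f / sqrtf((float)head_dim);
       float max_s = -INFINITY;
       for (int t = 0; t <= pos; t++) {        /* causal: keys 0..pos */
           float s = 0.0f;
           for (int d = 0; d < head_dim; d++) s += q[d] * K[t*kv_stride + d];
           scores[t] = s * scale;
           if (scores[t] > max_s) max_s = scores[t];
       }
       float sum = 0.0f;                       /* numerically stable softmax */
       for (int t = 0; t <= pos; t++) {
           scores[t] = expf(scores[t] - max_s);
           sum += scores[t];
       }
       for (int d = 0; d < head_dim; d++) out[d] = 0.0f;
       for (int t = 0; t <= pos; t++) {        /* weighted sum of V rows */
           float w = scores[t] / sum;
           for (int d = 0; d < head_dim; d++) out[d] += w * V[t*kv_stride + d];
       }
   }
   ```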

## Architectural targets

Once correctness is achieved, expected gains (per the microbench):
- N=8 prefill chunk: ~3× per-matmul vs single-token
- N=32 prefill chunk: ~30× per-matmul vs single-token
- N=128 prefill chunk: ~60-100× per-matmul vs single-token

The `tq_batched_matmul_q4` primitive currently gives 1.5-3× because it
works over Q4 weights, and the dequant overhead caps the win below the
AMX ceiling. Future optimizations (in priority order):

1. **Persistent FP16 lm_head** — ~1.5× decode + huge prefill win for Qwen3.5-4B
2. **BNNS quantized GEMM** — direct AMX with int8 weights, no dequant overhead
3. **MPSGraph for multi-layer fused forward** — entire layer on GPU

## Out-of-scope for this implementation

Don't try to make `tq_forward_batch` handle every architecture in v1.
Bail on:
- `is_gemma4`, `is_moe`, `has_fused_qkv`, `has_fused_up_gate`
- `n_kv_shared_layers > 0`
- Any layer with `delta_a_log` (DeltaNet)
- `attn_output_gate` (rare; just bail)
- `partial_rotary_factor > 0` (Phi-3 LongRoPE)

The fast path should cover Llama 1B/3B/8B Q8_0 and Q4_K_M with load-time
Q4 conversion. Other models stay on the per-token path; we can extend it
case-by-case. A consolidated guard sketch follows.
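
A minimal sketch of that guard (the `tq_config` type, function name, and
return convention are hypothetical; the flag names are the ones listed
above):
```c
/* Returns 1 if the batched fast path supports this model config,
 * else 0 and the caller stays on the per-token path. delta_a_log
 * (DeltaNet) is per-layer, so check it at layer setup instead. */
static int tq_batch_supported(const tq_config *c) {
    if (c->is_gemma4 || c->is_moe) return 0;
    if (c->has_fused_qkv || c->has_fused_up_gate) return 0;
    if (c->n_kv_shared_layers > 0) return 0;
    if (c->attn_output_gate) return 0;
    if (c->partial_rotary_factor > 0) return 0;
    return 1;
}
```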

## Files touched

- `include/turboquant/tq_engine.h` — declared `tq_forward_batch` and
  `tq_batched_matmul_q4`.
- `src/engine/tq_ops.c` — implemented `tq_batched_matmul_q4` and worker.
- `src/engine/tq_transformer.c` — implemented `tq_forward_batch` (WIP).
- `src/engine/tq_generate.c` — gated integration behind `TQ_BATCH_PREFILL`.
- `tools/test_batched_matmul.c` — primitive correctness + speed test.
- `bench/results/2026-04-15_accelerate_gemm_microbench.md` — strategy doc.