
Commit 3c2e9a2

unamedkr and claude committed
wip(prefill): add Q2 residual handling, biases, QK-norm, sanity mode + handoff
Continued debugging of tq_forward_batch numerical mismatch. Added (all needed for correctness across architectures):

- Q/K/V bias application after batched matmul (Qwen2/2.5/3 — NULL for Llama)
- QK-norm (Qwen3 — NULL for Llama)
- Q2 residual correction per-token after Q4 batched matmul (matches tq_matmul_q4q2_preq math; Q2 weights are NULL on Llama 3.2 anyway, confirmed via debug print)
- rope_n_dims selection (was using head_dim, now uses c->rope_n_dims if set)
- SANITY mode (TQ_BATCH_SANITY=1): make tq_forward_batch just call tq_forward N times — VERIFIED PASSES, proving the integration with tq_generate is correct and the bug is purely in the batched math path.

Status: sanity passes, batched path still diverges. Next session needs intermediate-state diff between baseline and batched at layer 0 to isolate which sub-op is wrong. Strong suspect documented in handoff: tq_forward uses tq_matmul_q4 (FP32 input) for wo/gate/up/down vs tq_matmul_q4q2_preq (Q8 input) for wq/wk/wv — subtle quantization rounding differences may compound across layers.

Default behavior unchanged (TQ_BATCH_PREFILL gates the new path; off by default until verified). 11/11 STRICT tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 43b4148 commit 3c2e9a2

2 files changed: 152 additions & 8 deletions


docs/dev/batched_prefill_handoff.md

Lines changed: 25 additions & 0 deletions
````diff
@@ -27,6 +27,31 @@ DYLD_LIBRARY_PATH=build \
 # prints: " I'm so excited" ← baseline (correct)
 ```
 
+## Latest session findings (2026-04-15 evening)
+
+- **SANITY mode confirms orchestration is correct**. Setting
+  `TQ_BATCH_SANITY=1` makes `tq_forward_batch` simply call `tq_forward`
+  N times and the output matches baseline ("I'm so excited"). The bug
+  is purely in the per-token unrolled batched code, not in the
+  integration with `tq_generate`.
+
+- **Q4 matmul primitive verified at runtime** with both bias and Q-norm
+  fixes added (NULL for Llama anyway). Q2 residual handling added too —
+  but Llama 3.2 1B's load-time Q4 conversion does NOT produce Q2
+  residuals (`wq_q2 == NULL` confirmed by debug print), so Q2 isn't the
+  culprit either.
+
+- **Bug still present in actual batched path**. Output remains
+  "hell hel hell..." at N=2.
+
+- 🔍 **Strong suspect**: tq_forward uses different matmul function
+  variants for different projections (`tq_matmul_q4q2_preq` for wq/wk/wv,
+  `tq_matmul_q4` for wo/gate/up/down). Although they should be
+  functionally equivalent, there may be subtle differences (e.g.,
+  rounding mode in input quantization, scale normalization). The
+  systematic next step is to **dump s->x[0..3] after layer 0 from both
+  paths** — this isolates which sub-op diverges.
+
 ## Debugging plan for next session
 
 1. **Add intermediate-state dumps** to `tq_forward` and `tq_forward_batch`.
````

src/engine/tq_transformer.c

Lines changed: 127 additions & 8 deletions
```diff
@@ -3035,6 +3035,16 @@ float* tq_forward(tq_model_t* model, tq_state_t* s, int token, int pos) {
 int tq_forward_batch(tq_model_t* model, tq_state_t* s,
                      const int* tokens, int N, int pos_start) {
     if (N <= 0) return pos_start;
+    /* SANITY CHECK MODE: just call tq_forward N times. If THIS gives
+     * different results than the per-token tq_generate loop, the bug
+     * is in the orchestration outside the matmul work. Set
+     * TQ_BATCH_SANITY=1 to enable. */
+    if (getenv("TQ_BATCH_SANITY")) {
+        for (int n = 0; n < N; n++) {
+            tq_forward(model, s, tokens[n], pos_start + n);
+        }
+        return pos_start + N;
+    }
     tq_model_config_t* c = &model->config;
 
     /* Architectural gating: only standard Llama for now. */
```
```diff
@@ -3120,18 +3130,86 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
         free(OB); free(GB); free(UB);
         return -1;
     }
+    if (l == 0 && dbg) {
+        fprintf(stderr, "[batch] layer 0 q2 presence: wq=%p wk=%p wv=%p wo=%p g=%p u=%p d=%p\n",
+                (void*)layer->wq_q2, (void*)layer->wk_q2, (void*)layer->wv_q2,
+                (void*)layer->wo_q2, (void*)layer->w_gate_q2, (void*)layer->w_up_q2, (void*)layer->w_down_q2);
+    }
 
     /* 1. attn RMSNorm (per-row) */
     for (int n = 0; n < N; n++) {
         tq_rmsnorm(XBN + (size_t)n * dim, Xres + (size_t)n * dim,
                    layer->attn_norm, dim, c->rms_norm_eps);
     }
 
-    /* 2. Q, K, V batched matmul */
+    /* 2. Q, K, V batched matmul (Q4 main weights) */
     tq_batched_matmul_q4(QB, layer->wq_q4, layer->wq_q4s, XBN, q_dim, dim, N, NULL);
     tq_batched_matmul_q4(KB, layer->wk_q4, layer->wk_q4s, XBN, kv_dim, dim, N, NULL);
     tq_batched_matmul_q4(VB, layer->wv_q4, layer->wv_q4s, XBN, kv_dim, dim, N, NULL);
 
+    /* 2-r. Add Q2 residual correction per-token (matches tq_matmul_q4q2_preq).
+     * Load-time Q4 conversion stores BOTH Q4 main + Q2 residual. Skipping the
+     * Q2 part causes large numerical drift. We do the Q2 part per-token using
+     * the existing primitive — Q2 is small (2 bits) so the per-token cost is
+     * a fraction of the Q4 batched savings. */
+    if (layer->wq_q2 || layer->wk_q2 || layer->wv_q2) {
+        int n_blocks_d = dim / 32;
+        int8_t* xq = s->xb_q8;   /* reuse state's per-token Q8 buffer */
+        float* xs = s->xb_q8s;
+        float* tmp_q = (float*)malloc((size_t)q_dim * sizeof(float));
+        float* tmp_k = (float*)malloc((size_t)kv_dim * sizeof(float));
+        float* tmp_v = (float*)malloc((size_t)kv_dim * sizeof(float));
+        for (int n = 0; n < N; n++) {
+            /* Quantize this row's XBN to Q8 once. */
+            tq_quantize_row_q8(XBN + (size_t)n * dim, xq, xs, dim);
+            if (layer->wq_q2) {
+                tq_matmul_q2_preq(tmp_q, layer->wq_q2, layer->wq_q2s, xq, xs, q_dim, dim);
+                for (int i = 0; i < q_dim; i++) QB[(size_t)n * q_dim + i] += tmp_q[i];
+            }
+            if (layer->wk_q2) {
+                tq_matmul_q2_preq(tmp_k, layer->wk_q2, layer->wk_q2s, xq, xs, kv_dim, dim);
+                for (int i = 0; i < kv_dim; i++) KB[(size_t)n * kv_dim + i] += tmp_k[i];
+            }
+            if (layer->wv_q2) {
+                tq_matmul_q2_preq(tmp_v, layer->wv_q2, layer->wv_q2s, xq, xs, kv_dim, dim);
+                for (int i = 0; i < kv_dim; i++) VB[(size_t)n * kv_dim + i] += tmp_v[i];
+            }
+        }
+        free(tmp_q); free(tmp_k); free(tmp_v);
+        (void)n_blocks_d;
+    }
+
+    /* 2a. Apply Q/K/V biases (Qwen2/2.5/3 — NULL for Llama). */
+    if (layer->q_bias) {
+        for (int n = 0; n < N; n++)
+            for (int i = 0; i < q_dim; i++) QB[(size_t)n * q_dim + i] += layer->q_bias[i];
+    }
+    if (layer->k_bias) {
+        for (int n = 0; n < N; n++)
+            for (int i = 0; i < kv_dim; i++) KB[(size_t)n * kv_dim + i] += layer->k_bias[i];
+    }
+    if (layer->v_bias) {
+        for (int n = 0; n < N; n++)
+            for (int i = 0; i < kv_dim; i++) VB[(size_t)n * kv_dim + i] += layer->v_bias[i];
+    }
+    /* 2b. QK-norm (Qwen3 — NULL for Llama). */
+    if (layer->q_norm) {
+        for (int n = 0; n < N; n++) {
+            for (int h = 0; h < c->n_heads; h++) {
+                float* qh = QB + (size_t)n * q_dim + h * c->head_dim;
+                tq_rmsnorm(qh, qh, layer->q_norm, c->head_dim, c->rms_norm_eps);
+            }
+        }
+    }
+    if (layer->k_norm) {
+        for (int n = 0; n < N; n++) {
+            for (int h = 0; h < c->n_kv_heads; h++) {
+                float* kh = KB + (size_t)n * kv_dim + h * c->head_dim;
+                tq_rmsnorm(kh, kh, layer->k_norm, c->head_dim, c->rms_norm_eps);
+            }
+        }
+    }
 
     /* 3. RoPE + KV cache write (per-token).
      * Mirror tq_forward's RoPE selection: if model->rope_freqs is set
      * (Llama 3.x learned RoPE scaling, 64 freq factors), apply per-pair
```
```diff
@@ -3141,13 +3219,15 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
         float* kn = KB + (size_t)n * kv_dim;
         int pos = pos_start + n;
         if (model->rope_freqs && model->rope_freqs_len > 0) {
-            int rope_pairs = c->head_dim / 2;
+            /* Match tq_forward's rope_n_dims selection: c->rope_n_dims may
+             * differ from head_dim (e.g., Gemma partial RoPE). */
+            int rope_n_dims = (c->rope_n_dims > 0) ? c->rope_n_dims : c->head_dim;
+            int rope_pairs = rope_n_dims / 2;
             if (rope_pairs > model->rope_freqs_len) rope_pairs = model->rope_freqs_len;
-            /* Llama 3 uses interleaved layout (a=2i, b=2i+1) */
             for (int h = 0; h < c->n_heads; h++) {
                 float* qh = qn + h * c->head_dim;
                 for (int i = 0; i < rope_pairs; i++) {
-                    float base = 1.0f / powf(c->rope_freq_base, 2.0f * i / (float)c->head_dim);
+                    float base = 1.0f / powf(c->rope_freq_base, 2.0f * i / (float)rope_n_dims);
                     float freq = base / model->rope_freqs[i];
                     float theta = pos * freq;
                     float ct = cosf(theta), st = sinf(theta);
```
```diff
@@ -3159,7 +3239,7 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
             for (int h = 0; h < c->n_kv_heads; h++) {
                 float* kh = kn + h * c->head_dim;
                 for (int i = 0; i < rope_pairs; i++) {
-                    float base = 1.0f / powf(c->rope_freq_base, 2.0f * i / (float)c->head_dim);
+                    float base = 1.0f / powf(c->rope_freq_base, 2.0f * i / (float)rope_n_dims);
                     float freq = base / model->rope_freqs[i];
                     float theta = pos * freq;
                     float ct = cosf(theta), st = sinf(theta);
```
```diff
@@ -3284,8 +3364,19 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
         }
     }
 
-    /* 5. O matmul batched */
+    /* 5. O matmul batched + Q2 residual */
     tq_batched_matmul_q4(X, layer->wo_q4, layer->wo_q4s, OB, dim, q_dim, N, NULL);
+    if (layer->wo_q2) {
+        int8_t* xq = s->xb_q8;
+        float* xs = s->xb_q8s;
+        float* tmp = (float*)malloc((size_t)dim * sizeof(float));
+        for (int n = 0; n < N; n++) {
+            tq_quantize_row_q8(OB + (size_t)n * q_dim, xq, xs, q_dim);
+            tq_matmul_q2_preq(tmp, layer->wo_q2, layer->wo_q2s, xq, xs, dim, q_dim);
+            for (int i = 0; i < dim; i++) X[(size_t)n * dim + i] += tmp[i];
+        }
+        free(tmp);
+    }
 
     /* 6. Residual: Xres += X */
     for (size_t i = 0; i < (size_t)N * dim; i++) Xres[i] += X[i];
```
```diff
@@ -3296,9 +3387,26 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
                    layer->ffn_norm, dim, c->rms_norm_eps);
     }
 
-    /* 8. gate, up batched matmul */
+    /* 8. gate, up batched matmul + Q2 residuals */
     tq_batched_matmul_q4(GB, layer->w_gate_q4, layer->w_gate_q4s, XBN, inter, dim, N, NULL);
     tq_batched_matmul_q4(UB, layer->w_up_q4, layer->w_up_q4s, XBN, inter, dim, N, NULL);
+    if (layer->w_gate_q2 || layer->w_up_q2) {
+        int8_t* xq = s->xb_q8;
+        float* xs = s->xb_q8s;
+        float* tmp = (float*)malloc((size_t)inter * sizeof(float));
+        for (int n = 0; n < N; n++) {
+            tq_quantize_row_q8(XBN + (size_t)n * dim, xq, xs, dim);
+            if (layer->w_gate_q2) {
+                tq_matmul_q2_preq(tmp, layer->w_gate_q2, layer->w_gate_q2s, xq, xs, inter, dim);
+                for (int i = 0; i < inter; i++) GB[(size_t)n * inter + i] += tmp[i];
+            }
+            if (layer->w_up_q2) {
+                tq_matmul_q2_preq(tmp, layer->w_up_q2, layer->w_up_q2s, xq, xs, inter, dim);
+                for (int i = 0; i < inter; i++) UB[(size_t)n * inter + i] += tmp[i];
+            }
+        }
+        free(tmp);
+    }
 
     /* 9. SiLU(gate) * up (per-element) */
     for (size_t i = 0; i < (size_t)N * inter; i++) {
```
```diff
@@ -3307,8 +3415,19 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
         GB[i] = silu * UB[i];
     }
 
-    /* 10. down matmul batched (output back into X) */
+    /* 10. down matmul batched (output back into X) + Q2 residual */
     tq_batched_matmul_q4(X, layer->w_down_q4, layer->w_down_q4s, GB, dim, inter, N, NULL);
+    if (layer->w_down_q2) {
+        int8_t* xq = s->xb_q8;
+        float* xs = s->xb_q8s;
+        float* tmp = (float*)malloc((size_t)dim * sizeof(float));
+        for (int n = 0; n < N; n++) {
+            tq_quantize_row_q8(GB + (size_t)n * inter, xq, xs, inter);
+            tq_matmul_q2_preq(tmp, layer->w_down_q2, layer->w_down_q2s, xq, xs, dim, inter);
+            for (int i = 0; i < dim; i++) X[(size_t)n * dim + i] += tmp[i];
+        }
+        free(tmp);
+    }
 
     /* 11. Residual: Xres += X */
     for (size_t i = 0; i < (size_t)N * dim; i++) Xres[i] += X[i];
```
