
Commit ece4185

unamedkr and claude committed
wip(prefill): bit-identical Layer 0 — eliminated quantization/FP16/NEON drift
Multi-round debugging of tq_forward_batch numerical mismatch. Traced the divergence chain layer-by-layer with intermediate-state dumps:

Layer 0 — XBN (after attn_norm)           ✓ bit-identical
Layer 0 — Q (post matmul + bias)          ✓ bit-identical
Layer 0 — K, V (post matmul)              ✓ bit-identical
Layer 0 — Q, K (post-RoPE)                ✓ bit-identical (pos=0 → identity)
Layer 0 — attention output (xb/OB)        ✓ bit-identical AFTER fixes:
  - tq_quantize_row_q8 instead of inline roundf (RNE matters)
  - f32_to_fp16_vec instead of inline IEEE-754 conversion
  - vcvt_f32_f16 NEON intrinsic for V cache reads
  - vfmaq_f32 NEON for attention score accumulation
Layer 0 — wo matmul output (X/xb2)        ✓ bit-identical
Layer 0 — final residual stream (Xres/x)  ✓ bit-identical

So Layer 0 is now byte-identical between batched and per-token paths. Output still diverges, meaning the bug is in Layer 1+ — most likely a similar 1-ULP drift compounding across 16 layers. Strong candidates: attention V-accumulator order on multi-position cases (pos=1+ has more than one t to sum), or one of the FFN ops (silu, mul, down matmul).

The hard work of stamping out per-op rounding mismatches at Layer 0 is done. Next session should add the same dumps at Layer 1 (and possibly Layer 2, 4, 8) to pinpoint exactly which sub-op introduces the first divergence on the multi-position path.

Default still TQ_BATCH_PREFILL gated. 11/11 STRICT tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3c2e9a2 commit ece4185

2 files changed

Lines changed: 93 additions & 69 deletions
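For context on the "RNE matters" fix: roundf() breaks ties away from zero, while the NEON vcvtnq_s32_f32 conversion inside tq_quantize_row_q8 breaks ties to even. A minimal standalone sketch of the difference (not part of this commit; rintf() under the default rounding mode stands in for the NEON intrinsic):

#include <math.h>
#include <stdio.h>

int main(void) {
    /* Halfway cases: roundf breaks ties away from zero, rintf (default
     * FP environment) breaks ties to even, like vcvtnq_s32_f32. */
    const float xs[] = { 0.5f, 1.5f, 2.5f, -0.5f, -2.5f };
    for (int i = 0; i < 5; i++)
        printf("x=%5.1f  roundf=%5.1f  rintf=%5.1f\n",
               xs[i], roundf(xs[i]), rintf(xs[i]));
    /* e.g. x=2.5: roundf -> 3, rintf -> 2. In int8 quantization that is a
     * one-count difference which the Q4xQ8 dot products then amplify
     * layer after layer. */
    return 0;
}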


src/engine/tq_ops.c

Lines changed: 8 additions & 20 deletions
@@ -1182,31 +1182,19 @@ void tq_batched_matmul_q4(float* out, const uint8_t* w_qs, const float* w_scales
        return;
    }

-   /* Pre-quantize all N input rows to int8 with per-block scales. */
+   /* Pre-quantize all N input rows to int8 with per-block scales.
+    * Use tq_quantize_row_q8 (NEON round-to-nearest-even via vcvtnq_s32_f32)
+    * to exactly match the per-token path's quantization — otherwise tiny
+    * rounding differences (1 ULP) propagate through 16+ layers and produce
+    * garbage output even though the matmul math is identical. */
    int n_blocks = d / 32;
    int8_t* X_q = (int8_t*)malloc((size_t)N * d * sizeof(int8_t));
    float* X_d = (float*)malloc((size_t)N * n_blocks * sizeof(float));
    if (!X_q || !X_d) { free(X_q); free(X_d); return; }
    for (int n = 0; n < N; n++) {
-       for (int b = 0; b < n_blocks; b++) {
-           const float* xp = x + (size_t)n * d + b * 32;
-           float amax = 0.0f;
-           for (int j = 0; j < 32; j++) {
-               float a = xp[j] < 0 ? -xp[j] : xp[j];
-               if (a > amax) amax = a;
-           }
-           float dq = amax / 127.0f;
-           X_d[(size_t)n * n_blocks + b] = dq;
-           if (dq > 0.0f) {
-               float id = 1.0f / dq;
-               for (int j = 0; j < 32; j++) {
-                   int v = (int)roundf(xp[j] * id);
-                   X_q[(size_t)n * d + b*32 + j] = (int8_t)(v < -128 ? -128 : (v > 127 ? 127 : v));
-               }
-           } else {
-               memset(X_q + (size_t)n * d + b*32, 0, 32);
-           }
-       }
+       tq_quantize_row_q8(x + (size_t)n * d,
+                          X_q + (size_t)n * d,
+                          X_d + (size_t)n * n_blocks, d);
    }

    /* Parallel across rows. */
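tq_quantize_row_q8 itself is not shown in this diff. Inferred from the call site above (d input floats, d int8 outputs, d/32 per-block scales), a scalar reference for what it presumably computes might look like the sketch below; the name quantize_row_q8_ref and the use of nearbyintf() as a stand-in for vcvtnq_s32_f32 are assumptions, not the repo's code.

#include <math.h>
#include <stdint.h>

/* Assumed reference behaviour for tq_quantize_row_q8: per-32-element blocks,
 * symmetric int8 with scale amax/127, round-to-nearest-even on the quantized
 * values (nearbyintf under the default rounding mode). */
static void quantize_row_q8_ref(const float* x, int8_t* q, float* scales, int d) {
    for (int b = 0; b < d / 32; b++) {
        const float* xb = x + b * 32;
        float amax = 0.0f;
        for (int j = 0; j < 32; j++) {
            float a = fabsf(xb[j]);
            if (a > amax) amax = a;
        }
        float s = amax / 127.0f;
        scales[b] = s;
        float id = (s > 0.0f) ? 1.0f / s : 0.0f;
        for (int j = 0; j < 32; j++) {
            int v = (int)nearbyintf(xb[j] * id);
            if (v < -128) v = -128;
            if (v > 127) v = 127;
            q[b * 32 + j] = (int8_t)v;
        }
    }
}

The only behavioural difference from the removed inline loop is the tie-breaking rule of the rounding step, which is what the commit message calls out as "RNE matters".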

src/engine/tq_transformer.c

Lines changed: 85 additions & 49 deletions
@@ -3141,12 +3141,22 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
            tq_rmsnorm(XBN + (size_t)n * dim, Xres + (size_t)n * dim,
                       layer->attn_norm, dim, c->rms_norm_eps);
        }
+       if (l == 0 && dbg) {
+           fprintf(stderr, "[batch] L0 XBN (after attn_norm) tok0 [0:8] = ");
+           for (int i = 0; i < 8; i++) fprintf(stderr, "%.4f ", XBN[i]);
+           fprintf(stderr, "\n");
+       }

        /* 2. Q, K, V batched matmul (Q4 main weights) */
        tq_batched_matmul_q4(QB, layer->wq_q4, layer->wq_q4s, XBN, q_dim, dim, N, NULL);
        tq_batched_matmul_q4(KB, layer->wk_q4, layer->wk_q4s, XBN, kv_dim, dim, N, NULL);
        tq_batched_matmul_q4(VB, layer->wv_q4, layer->wv_q4s, XBN, kv_dim, dim, N, NULL);

+       if (l == 0 && dbg) {
+           fprintf(stderr, "[batch] L0 VB (post-matmul) tok0 [0:8] = ");
+           for (int i = 0; i < 8; i++) fprintf(stderr, "%.4f ", VB[i]);
+           fprintf(stderr, "\n");
+       }
        /* 2-r. Add Q2 residual correction per-token (matches tq_matmul_q4q2_preq).
         * Load-time Q4 conversion stores BOTH Q4 main + Q2 residual. Skipping the
         * Q2 part causes large numerical drift. We do the Q2 part per-token using
@@ -3252,43 +3262,27 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
                tq_rope(qn, kn, pos, c->head_dim, c->n_heads, c->n_kv_heads,
                        c->rope_freq_base);
            }
+           if (n == 0 && l == 0 && dbg) {
+               fprintf(stderr, "[batch] L0 QB (post-RoPE) tok0 [0:8] = ");
+               for (int i = 0; i < 8; i++) fprintf(stderr, "%.4f ", qn[i]);
+               fprintf(stderr, "\n");
+               fprintf(stderr, "[batch] L0 KB (post-RoPE) tok0 [0:8] = ");
+               for (int i = 0; i < 8; i++) fprintf(stderr, "%.4f ", kn[i]);
+               fprintf(stderr, "\n");
+           }
            /* Write to cache */
            memcpy(s->key_cache + (size_t)l * kv_layer_stride + (size_t)pos * kv_dim,
                   kn, (size_t)kv_dim * sizeof(float));
            if (s->value_cache) {
                memcpy(s->value_cache + (size_t)l * kv_layer_stride + (size_t)pos * kv_dim,
                       VB + (size_t)n * kv_dim, (size_t)kv_dim * sizeof(float));
            } else if (s->value_cache_fp16) {
-               /* FP32 → FP16 conversion for storage. */
+               /* Match tq_forward exactly: hardware FP16 conversion via NEON
+                * vcvt_f16_f32. Inline manual conversion gave subtly different
+                * rounding which propagated through attention and broke output. */
                uint16_t* dst = s->value_cache_fp16
                    + (size_t)l * kv_layer_stride + (size_t)pos * kv_dim;
-               const float* src = VB + (size_t)n * kv_dim;
-               for (int i = 0; i < kv_dim; i++) {
-                   /* Use round-to-nearest IEEE 754 binary16 conversion via union */
-                   union { float f; uint32_t u; } v = { .f = src[i] };
-                   uint32_t b = v.u;
-                   uint16_t sign = (b >> 16) & 0x8000;
-                   int32_t e = (int32_t)((b >> 23) & 0xff) - 127 + 15;
-                   uint32_t m = b & 0x7fffff;
-                   uint16_t out;
-                   if (e <= 0) {
-                       if (e < -10) out = sign;
-                       else {
-                           m = (m | 0x800000) >> (1 - e);
-                           if (m & 0x1000) m += 0x2000;
-                           out = sign | (uint16_t)(m >> 13);
-                       }
-                   } else if (e >= 31) {
-                       out = sign | 0x7c00 | (m ? (uint16_t)(m >> 13) : 0);
-                   } else {
-                       if (m & 0x1000) {
-                           m += 0x2000;
-                           if (m & 0x800000) { m = 0; e++; }
-                       }
-                       out = sign | ((uint16_t)e << 10) | (uint16_t)(m >> 13);
-                   }
-                   dst[i] = out;
-               }
+               f32_to_fp16_vec(VB + (size_t)n * kv_dim, dst, kv_dim);
            } else {
                if (dbg) fprintf(stderr, "[batch] bail: no FP32/FP16 V cache\n");
                free(X); free(Xres); free(XBN); free(QB); free(KB); free(VB);
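f32_to_fp16_vec is likewise only visible here as a call. Assuming an AArch64 target, its body presumably reduces to the hardware converter applied in 4-wide chunks plus a scalar tail; the sketch below illustrates that assumption and is not the repo's implementation.

#include <arm_neon.h>
#include <stdint.h>
#include <string.h>

/* Assumed shape of f32_to_fp16_vec on AArch64: the hardware FCVT does the
 * IEEE round-to-nearest-even FP32->FP16 rounding, so the batched path stores
 * exactly the same half values as the per-token path. */
static void f32_to_fp16_ref(const float* src, uint16_t* dst, int n) {
    int i = 0;
    for (; i + 3 < n; i += 4) {
        float16x4_t h = vcvt_f16_f32(vld1q_f32(src + i));
        vst1_u16(dst + i, vreinterpret_u16_f16(h));
    }
    for (; i < n; i++) {
        __fp16 h = (__fp16)src[i];               /* scalar tail, same rounding */
        memcpy(&dst[i], &h, sizeof(uint16_t));
    }
}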
@@ -3319,7 +3313,19 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
                for (int t = 0; t <= pos; t++) {
                    float* kh = K_layer + (size_t)t * kv_dim + kvh * head_dim;
                    float score = 0.0f;
+#ifdef __ARM_NEON
+                   float32x4_t vsum = vdupq_n_f32(0.0f);
+                   int d = 0;
+                   for (; d + 3 < head_dim; d += 4) {
+                       float32x4_t vq = vld1q_f32(qh + d);
+                       float32x4_t vk = vld1q_f32(kh + d);
+                       vsum = vfmaq_f32(vsum, vq, vk);
+                   }
+                   score = vaddvq_f32(vsum);
+                   for (; d < head_dim; d++) score += qh[d] * kh[d];
+#else
                    for (int i = 0; i < head_dim; i++) score += qh[i] * kh[i];
+#endif
                    att[t] = score * scale;
                }
                tq_softmax(att, pos + 1);
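The scalar dot product here was swapped for the NEON vfmaq_f32 form presumably to mirror the per-token path's accumulation exactly: an FMA rounds once where a separate multiply and add round twice, so even one element pair can land a bit apart. A small standalone demonstration of that effect (not from the repo):

#include <math.h>
#include <stdio.h>

int main(void) {
    /* a*b is not exactly representable in float, so the fused and the
     * split evaluation round differently. */
    float a = 1.0f + 0x1p-12f, b = 1.0f + 0x1p-12f, c = -1.0f;
    volatile float p = a * b;       /* volatile blocks FMA contraction here */
    float split = p + c;            /* two roundings */
    float fused = fmaf(a, b, c);    /* one rounding  */
    printf("fused = %.9g, split = %.9g\n", fused, split);
    return 0;
}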
@@ -3332,38 +3338,45 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
                        for (int i = 0; i < head_dim; i++) oh[i] += w * vh[i];
                    }
                } else {
-                   /* FP16 V cache: dequant per element via shift. */
+                   /* FP16 V cache: use NEON vcvt_f32_f16 to exactly match the
+                    * per-token attention path. Inline IEEE-754 conversion gave
+                    * subtly different rounding (1 ULP) which compounded across
+                    * 16 layers into garbage output. */
                    for (int t = 0; t <= pos; t++) {
                        uint16_t* vh = V_layer_fp16 + (size_t)t * kv_dim + kvh * head_dim;
                        float w = att[t];
+                       if (w == 0.0f) continue;
+#ifdef __ARM_NEON
+                       float32x4_t va = vdupq_n_f32(w);
+                       int i = 0;
+                       for (; i + 3 < head_dim; i += 4) {
+                           uint16x4_t vh4 = vld1_u16(vh + i);
+                           float32x4_t vf = vcvt_f32_f16(vreinterpret_f16_u16(vh4));
+                           float32x4_t vx = vld1q_f32(oh + i);
+                           vst1q_f32(oh + i, vfmaq_f32(vx, va, vf));
+                       }
+                       for (; i < head_dim; i++) {
+                           uint16_t h16 = vh[i];
+                           __fp16 hf = *(const __fp16*)&h16;
+                           oh[i] += w * (float)hf;
+                       }
+#else
                        for (int i = 0; i < head_dim; i++) {
                            uint16_t h16 = vh[i];
-                           uint32_t sign = (uint32_t)(h16 >> 15) << 31;
-                           uint32_t exp = (h16 >> 10) & 0x1f;
-                           uint32_t mant = h16 & 0x3ff;
-                           uint32_t bits;
-                           if (exp == 0) {
-                               if (mant == 0) bits = sign;
-                               else {
-                                   /* subnormal */
-                                   while (!(mant & 0x400)) { mant <<= 1; exp--; }
-                                   mant &= 0x3ff;
-                                   bits = sign | ((exp + 127 - 15 + 1) << 23) | (mant << 13);
-                               }
-                           } else if (exp == 31) {
-                               bits = sign | 0x7f800000u | (mant << 13);
-                           } else {
-                               bits = sign | ((exp + 127 - 15) << 23) | (mant << 13);
-                           }
-                           float vf;
-                           memcpy(&vf, &bits, 4);
-                           oh[i] += w * vf;
+                           __fp16 hf = *(const __fp16*)&h16;
+                           oh[i] += w * (float)hf;
                        }
+#endif
                    }
                }
            }
        }

+       if (l == 0 && dbg) {
+           fprintf(stderr, "[batch] L0 OB (post-attn) tok0 [0:8] = ");
+           for (int i = 0; i < 8; i++) fprintf(stderr, "%.4f ", OB[i]);
+           fprintf(stderr, "\n");
+       }
        /* 5. O matmul batched + Q2 residual */
        tq_batched_matmul_q4(X, layer->wo_q4, layer->wo_q4s, OB, dim, q_dim, N, NULL);
        if (layer->wo_q2) {
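Since the V-read now has a NEON branch and a scalar #else branch, a cheap exhaustive check (the FP16 input space is only 65,536 values) can confirm the two dequant routes agree bit-for-bit. This harness is an illustration for an AArch64 build, not part of the commit:

#include <arm_neon.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    int mismatches = 0;
    for (uint32_t u = 0; u < 0x10000; u++) {
        uint16_t h16 = (uint16_t)u;
        if (((h16 >> 10) & 0x1f) == 0x1f) continue;   /* skip inf/NaN */
        /* Scalar route: reinterpret as __fp16, widen to float (#else branch). */
        __fp16 hf;
        memcpy(&hf, &h16, sizeof hf);
        float scalar = (float)hf;
        /* NEON route: vcvt_f32_f16 (#ifdef __ARM_NEON branch). */
        float32x4_t v = vcvt_f32_f16(vreinterpret_f16_u16(vdup_n_u16(h16)));
        if (vgetq_lane_f32(v, 0) != scalar) mismatches++;
    }
    printf("FP16->FP32 scalar vs NEON mismatches: %d\n", mismatches);
    return 0;
}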
@@ -3378,9 +3391,26 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
            free(tmp);
        }

+       if (l == 0 && dbg) {
+           fprintf(stderr, "[batch] L0 X (after wo matmul) tok0 [0:8] = ");
+           for (int i = 0; i < 8; i++) fprintf(stderr, "%.4f ", X[i]);
+           fprintf(stderr, "\n");
+       }
        /* 6. Residual: Xres += X */
        for (size_t i = 0; i < (size_t)N * dim; i++) Xres[i] += X[i];

+       if (l == 0 && dbg) {
+           fprintf(stderr, "[batch] L0 after-attn-residual Xres[tok0,0:8] = ");
+           for (int i = 0; i < 8; i++) fprintf(stderr, "%.4f ", Xres[i]);
+           fprintf(stderr, "\n");
+           fprintf(stderr, "[batch] L0 after-attn-residual QB[tok0,0:8] = ");
+           for (int i = 0; i < 8; i++) fprintf(stderr, "%.4f ", QB[i]);
+           fprintf(stderr, "\n");
+           fprintf(stderr, "[batch] L0 after-attn-residual KB[tok0,0:8] = ");
+           for (int i = 0; i < 8; i++) fprintf(stderr, "%.4f ", KB[i]);
+           fprintf(stderr, "\n");
+       }
+
        /* 7. ffn_norm */
        for (int n = 0; n < N; n++) {
            tq_rmsnorm(XBN + (size_t)n * dim, Xres + (size_t)n * dim,
@@ -3431,6 +3461,12 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,

        /* 11. Residual: Xres += X */
        for (size_t i = 0; i < (size_t)N * dim; i++) Xres[i] += X[i];
+
+       if (l == 0 && dbg) {
+           fprintf(stderr, "[batch] L0 final Xres tok0 [0:8] = ");
+           for (int i = 0; i < 8; i++) fprintf(stderr, "%.4f ", Xres[i]);
+           fprintf(stderr, "\n");
+       }
    }

    free(X); free(XBN); free(QB); free(KB); free(VB); free(OB); free(GB); free(UB);
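For the planned next session (the same dumps at Layer 1, and possibly 2/4/8), the repeated fprintf blocks could be folded into one helper. dbg_dump8 below is a hypothetical sketch that just reproduces the existing "[batch] Lx ... tok0 [0:8]" format; it is not something this commit adds:

#include <stdio.h>

/* Hypothetical helper: same output format as the inline dumps above, but
 * parameterised on the layer so Layer 1/2/4/8 can be traced next. */
static inline void dbg_dump8(int dbg, int layer, int want_layer,
                             const char* tag, const float* v) {
    if (!dbg || layer != want_layer) return;
    fprintf(stderr, "[batch] L%d %s tok0 [0:8] = ", layer, tag);
    for (int i = 0; i < 8; i++) fprintf(stderr, "%.4f ", v[i]);
    fprintf(stderr, "\n");
}

/* Usage at the same points as the current Layer-0 dumps, e.g.:
 *   dbg_dump8(dbg, l, 1, "XBN (after attn_norm)", XBN);
 */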
