@@ -3141,12 +3141,22 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
31413141 tq_rmsnorm (XBN + (size_t )n * dim , Xres + (size_t )n * dim ,
31423142 layer -> attn_norm , dim , c -> rms_norm_eps );
31433143 }
3144+ if (l == 0 && dbg ) {
3145+ fprintf (stderr , "[batch] L0 XBN (after attn_norm) tok0 [0:8] = " );
3146+ for (int i = 0 ; i < 8 ; i ++ ) fprintf (stderr , "%.4f " , XBN [i ]);
3147+ fprintf (stderr , "\n" );
3148+ }
31443149
31453150 /* 2. Q, K, V batched matmul (Q4 main weights) */
31463151 tq_batched_matmul_q4 (QB , layer -> wq_q4 , layer -> wq_q4s , XBN , q_dim , dim , N , NULL );
31473152 tq_batched_matmul_q4 (KB , layer -> wk_q4 , layer -> wk_q4s , XBN , kv_dim , dim , N , NULL );
31483153 tq_batched_matmul_q4 (VB , layer -> wv_q4 , layer -> wv_q4s , XBN , kv_dim , dim , N , NULL );
31493154
3155+ if (l == 0 && dbg ) {
3156+ fprintf (stderr , "[batch] L0 VB (post-matmul) tok0 [0:8] = " );
3157+ for (int i = 0 ; i < 8 ; i ++ ) fprintf (stderr , "%.4f " , VB [i ]);
3158+ fprintf (stderr , "\n" );
3159+ }
31503160 /* 2-r. Add Q2 residual correction per-token (matches tq_matmul_q4q2_preq).
31513161 * Load-time Q4 conversion stores BOTH Q4 main + Q2 residual. Skipping the
31523162 * Q2 part causes large numerical drift. We do the Q2 part per-token using
@@ -3252,43 +3262,27 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
32523262 tq_rope (qn , kn , pos , c -> head_dim , c -> n_heads , c -> n_kv_heads ,
32533263 c -> rope_freq_base );
32543264 }
3265+ if (n == 0 && l == 0 && dbg ) {
3266+ fprintf (stderr , "[batch] L0 QB (post-RoPE) tok0 [0:8] = " );
3267+ for (int i = 0 ; i < 8 ; i ++ ) fprintf (stderr , "%.4f " , qn [i ]);
3268+ fprintf (stderr , "\n" );
3269+ fprintf (stderr , "[batch] L0 KB (post-RoPE) tok0 [0:8] = " );
3270+ for (int i = 0 ; i < 8 ; i ++ ) fprintf (stderr , "%.4f " , kn [i ]);
3271+ fprintf (stderr , "\n" );
3272+ }
32553273 /* Write to cache */
32563274 memcpy (s -> key_cache + (size_t )l * kv_layer_stride + (size_t )pos * kv_dim ,
32573275 kn , (size_t )kv_dim * sizeof (float ));
32583276 if (s -> value_cache ) {
32593277 memcpy (s -> value_cache + (size_t )l * kv_layer_stride + (size_t )pos * kv_dim ,
32603278 VB + (size_t )n * kv_dim , (size_t )kv_dim * sizeof (float ));
32613279 } else if (s -> value_cache_fp16 ) {
3262- /* FP32 → FP16 conversion for storage. */
3280+ /* Match tq_forward exactly: hardware FP16 conversion via NEON
3281+ * vcvt_f16_f32. Inline manual conversion gave subtly different
3282+ * rounding which propagated through attention and broke output. */
32633283 uint16_t * dst = s -> value_cache_fp16
32643284 + (size_t )l * kv_layer_stride + (size_t )pos * kv_dim ;
3265- const float * src = VB + (size_t )n * kv_dim ;
3266- for (int i = 0 ; i < kv_dim ; i ++ ) {
3267- /* Use round-to-nearest IEEE 754 binary16 conversion via union */
3268- union { float f ; uint32_t u ; } v = { .f = src [i ] };
3269- uint32_t b = v .u ;
3270- uint16_t sign = (b >> 16 ) & 0x8000 ;
3271- int32_t e = (int32_t )((b >> 23 ) & 0xff ) - 127 + 15 ;
3272- uint32_t m = b & 0x7fffff ;
3273- uint16_t out ;
3274- if (e <= 0 ) {
3275- if (e < -10 ) out = sign ;
3276- else {
3277- m = (m | 0x800000 ) >> (1 - e );
3278- if (m & 0x1000 ) m += 0x2000 ;
3279- out = sign | (uint16_t )(m >> 13 );
3280- }
3281- } else if (e >= 31 ) {
3282- out = sign | 0x7c00 | (m ? (uint16_t )(m >> 13 ) : 0 );
3283- } else {
3284- if (m & 0x1000 ) {
3285- m += 0x2000 ;
3286- if (m & 0x800000 ) { m = 0 ; e ++ ; }
3287- }
3288- out = sign | ((uint16_t )e << 10 ) | (uint16_t )(m >> 13 );
3289- }
3290- dst [i ] = out ;
3291- }
3285+ f32_to_fp16_vec (VB + (size_t )n * kv_dim , dst , kv_dim );
32923286 } else {
32933287 if (dbg ) fprintf (stderr , "[batch] bail: no FP32/FP16 V cache\n" );
32943288 free (X ); free (Xres ); free (XBN ); free (QB ); free (KB ); free (VB );
@@ -3319,7 +3313,19 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
33193313 for (int t = 0 ; t <= pos ; t ++ ) {
33203314 float * kh = K_layer + (size_t )t * kv_dim + kvh * head_dim ;
33213315 float score = 0.0f ;
3316+ #ifdef __ARM_NEON
3317+ float32x4_t vsum = vdupq_n_f32 (0.0f );
3318+ int d = 0 ;
3319+ for (; d + 3 < head_dim ; d += 4 ) {
3320+ float32x4_t vq = vld1q_f32 (qh + d );
3321+ float32x4_t vk = vld1q_f32 (kh + d );
3322+ vsum = vfmaq_f32 (vsum , vq , vk );
3323+ }
3324+ score = vaddvq_f32 (vsum );
3325+ for (; d < head_dim ; d ++ ) score += qh [d ] * kh [d ];
3326+ #else
33223327 for (int i = 0 ; i < head_dim ; i ++ ) score += qh [i ] * kh [i ];
3328+ #endif
33233329 att [t ] = score * scale ;
33243330 }
33253331 tq_softmax (att , pos + 1 );
@@ -3332,38 +3338,45 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
33323338 for (int i = 0 ; i < head_dim ; i ++ ) oh [i ] += w * vh [i ];
33333339 }
33343340 } else {
3335- /* FP16 V cache: dequant per element via shift. */
3341+ /* FP16 V cache: use NEON vcvt_f32_f16 to exactly match the
3342+ * per-token attention path. Inline IEEE-754 conversion gave
3343+ * subtly different rounding (1 ULP) which compounded across
3344+ * 16 layers into garbage output. */
33363345 for (int t = 0 ; t <= pos ; t ++ ) {
33373346 uint16_t * vh = V_layer_fp16 + (size_t )t * kv_dim + kvh * head_dim ;
33383347 float w = att [t ];
3348+ if (w == 0.0f ) continue ;
3349+ #ifdef __ARM_NEON
3350+ float32x4_t va = vdupq_n_f32 (w );
3351+ int i = 0 ;
3352+ for (; i + 3 < head_dim ; i += 4 ) {
3353+ uint16x4_t vh4 = vld1_u16 (vh + i );
3354+ float32x4_t vf = vcvt_f32_f16 (vreinterpret_f16_u16 (vh4 ));
3355+ float32x4_t vx = vld1q_f32 (oh + i );
3356+ vst1q_f32 (oh + i , vfmaq_f32 (vx , va , vf ));
3357+ }
3358+ for (; i < head_dim ; i ++ ) {
3359+ uint16_t h16 = vh [i ];
3360+ __fp16 hf = * (const __fp16 * )& h16 ;
3361+ oh [i ] += w * (float )hf ;
3362+ }
3363+ #else
33393364 for (int i = 0 ; i < head_dim ; i ++ ) {
33403365 uint16_t h16 = vh [i ];
3341- uint32_t sign = (uint32_t )(h16 >> 15 ) << 31 ;
3342- uint32_t exp = (h16 >> 10 ) & 0x1f ;
3343- uint32_t mant = h16 & 0x3ff ;
3344- uint32_t bits ;
3345- if (exp == 0 ) {
3346- if (mant == 0 ) bits = sign ;
3347- else {
3348- /* subnormal */
3349- while (!(mant & 0x400 )) { mant <<= 1 ; exp -- ; }
3350- mant &= 0x3ff ;
3351- bits = sign | ((exp + 127 - 15 + 1 ) << 23 ) | (mant << 13 );
3352- }
3353- } else if (exp == 31 ) {
3354- bits = sign | 0x7f800000u | (mant << 13 );
3355- } else {
3356- bits = sign | ((exp + 127 - 15 ) << 23 ) | (mant << 13 );
3357- }
3358- float vf ;
3359- memcpy (& vf , & bits , 4 );
3360- oh [i ] += w * vf ;
3366+ __fp16 hf = * (const __fp16 * )& h16 ;
3367+ oh [i ] += w * (float )hf ;
33613368 }
3369+ #endif
33623370 }
33633371 }
33643372 }
33653373 }
33663374
3375+ if (l == 0 && dbg ) {
3376+ fprintf (stderr , "[batch] L0 OB (post-attn) tok0 [0:8] = " );
3377+ for (int i = 0 ; i < 8 ; i ++ ) fprintf (stderr , "%.4f " , OB [i ]);
3378+ fprintf (stderr , "\n" );
3379+ }
33673380 /* 5. O matmul batched + Q2 residual */
33683381 tq_batched_matmul_q4 (X , layer -> wo_q4 , layer -> wo_q4s , OB , dim , q_dim , N , NULL );
33693382 if (layer -> wo_q2 ) {
@@ -3378,9 +3391,26 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
33783391 free (tmp );
33793392 }
33803393
3394+ if (l == 0 && dbg ) {
3395+ fprintf (stderr , "[batch] L0 X (after wo matmul) tok0 [0:8] = " );
3396+ for (int i = 0 ; i < 8 ; i ++ ) fprintf (stderr , "%.4f " , X [i ]);
3397+ fprintf (stderr , "\n" );
3398+ }
33813399 /* 6. Residual: Xres += X */
33823400 for (size_t i = 0 ; i < (size_t )N * dim ; i ++ ) Xres [i ] += X [i ];
33833401
3402+ if (l == 0 && dbg ) {
3403+ fprintf (stderr , "[batch] L0 after-attn-residual Xres[tok0,0:8] = " );
3404+ for (int i = 0 ; i < 8 ; i ++ ) fprintf (stderr , "%.4f " , Xres [i ]);
3405+ fprintf (stderr , "\n" );
3406+ fprintf (stderr , "[batch] L0 after-attn-residual QB[tok0,0:8] = " );
3407+ for (int i = 0 ; i < 8 ; i ++ ) fprintf (stderr , "%.4f " , QB [i ]);
3408+ fprintf (stderr , "\n" );
3409+ fprintf (stderr , "[batch] L0 after-attn-residual KB[tok0,0:8] = " );
3410+ for (int i = 0 ; i < 8 ; i ++ ) fprintf (stderr , "%.4f " , KB [i ]);
3411+ fprintf (stderr , "\n" );
3412+ }
3413+
33843414 /* 7. ffn_norm */
33853415 for (int n = 0 ; n < N ; n ++ ) {
33863416 tq_rmsnorm (XBN + (size_t )n * dim , Xres + (size_t )n * dim ,
@@ -3431,6 +3461,12 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
34313461
34323462 /* 11. Residual: Xres += X */
34333463 for (size_t i = 0 ; i < (size_t )N * dim ; i ++ ) Xres [i ] += X [i ];
3464+
3465+ if (l == 0 && dbg ) {
3466+ fprintf (stderr , "[batch] L0 final Xres tok0 [0:8] = " );
3467+ for (int i = 0 ; i < 8 ; i ++ ) fprintf (stderr , "%.4f " , Xres [i ]);
3468+ fprintf (stderr , "\n" );
3469+ }
34343470 }
34353471
34363472 free (X ); free (XBN ); free (QB ); free (KB ); free (VB ); free (OB ); free (GB ); free (UB );
0 commit comments