@@ -3035,6 +3035,16 @@ float* tq_forward(tq_model_t* model, tq_state_t* s, int token, int pos) {
 int tq_forward_batch(tq_model_t* model, tq_state_t* s,
                      const int* tokens, int N, int pos_start) {
     if (N <= 0) return pos_start;
+    /* SANITY CHECK MODE: just call tq_forward N times. If THIS gives
+     * different results than the per-token tq_generate loop, the bug
+     * is in the orchestration outside the matmul work. Set
+     * TQ_BATCH_SANITY=1 to enable. */
+    if (getenv("TQ_BATCH_SANITY")) {
+        for (int n = 0; n < N; n++) {
+            tq_forward(model, s, tokens[n], pos_start + n);
+        }
+        return pos_start + N;
+    }
     tq_model_config_t* c = &model->config;

     /* Architectural gating: only standard Llama for now. */
@@ -3120,18 +3130,86 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
             free(OB); free(GB); free(UB);
             return -1;
         }
+        if (l == 0 && dbg) {
+            fprintf(stderr, "[batch] layer 0 q2 presence: wq=%p wk=%p wv=%p wo=%p g=%p u=%p d=%p\n",
+                    (void*)layer->wq_q2, (void*)layer->wk_q2, (void*)layer->wv_q2,
+                    (void*)layer->wo_q2, (void*)layer->w_gate_q2, (void*)layer->w_up_q2, (void*)layer->w_down_q2);
+        }

         /* 1. attn RMSNorm (per-row) */
         for (int n = 0; n < N; n++) {
             tq_rmsnorm(XBN + (size_t)n * dim, Xres + (size_t)n * dim,
                        layer->attn_norm, dim, c->rms_norm_eps);
         }

-        /* 2. Q, K, V batched matmul */
+        /* 2. Q, K, V batched matmul (Q4 main weights) */
         tq_batched_matmul_q4(QB, layer->wq_q4, layer->wq_q4s, XBN, q_dim, dim, N, NULL);
         tq_batched_matmul_q4(KB, layer->wk_q4, layer->wk_q4s, XBN, kv_dim, dim, N, NULL);
         tq_batched_matmul_q4(VB, layer->wv_q4, layer->wv_q4s, XBN, kv_dim, dim, N, NULL);

+        /* 2-r. Add Q2 residual correction per-token (matches tq_matmul_q4q2_preq).
+         * Load-time Q4 conversion stores BOTH Q4 main + Q2 residual. Skipping the
+         * Q2 part causes large numerical drift. We do the Q2 part per-token using
+         * the existing primitive — Q2 is small (2 bits), so the per-token cost is
+         * a fraction of the Q4 batched savings. */
+        if (layer->wq_q2 || layer->wk_q2 || layer->wv_q2) {
+            int n_blocks_d = dim / 32;
+            int8_t* xq = s->xb_q8;   /* reuse state's per-token Q8 buffer */
+            float* xs = s->xb_q8s;
+            float* tmp_q = (float*)malloc((size_t)q_dim * sizeof(float));
+            float* tmp_k = (float*)malloc((size_t)kv_dim * sizeof(float));
+            float* tmp_v = (float*)malloc((size_t)kv_dim * sizeof(float));
+            for (int n = 0; n < N; n++) {
+                /* Quantize this row's XBN to Q8 once. */
+                tq_quantize_row_q8(XBN + (size_t)n * dim, xq, xs, dim);
+                if (layer->wq_q2) {
+                    tq_matmul_q2_preq(tmp_q, layer->wq_q2, layer->wq_q2s, xq, xs, q_dim, dim);
+                    for (int i = 0; i < q_dim; i++) QB[(size_t)n * q_dim + i] += tmp_q[i];
+                }
+                if (layer->wk_q2) {
+                    tq_matmul_q2_preq(tmp_k, layer->wk_q2, layer->wk_q2s, xq, xs, kv_dim, dim);
+                    for (int i = 0; i < kv_dim; i++) KB[(size_t)n * kv_dim + i] += tmp_k[i];
+                }
+                if (layer->wv_q2) {
+                    tq_matmul_q2_preq(tmp_v, layer->wv_q2, layer->wv_q2s, xq, xs, kv_dim, dim);
+                    for (int i = 0; i < kv_dim; i++) VB[(size_t)n * kv_dim + i] += tmp_v[i];
+                }
+            }
+            free(tmp_q); free(tmp_k); free(tmp_v);
+            (void)n_blocks_d;
+        }
+
+        /* 2a. Apply Q/K/V biases (Qwen2/2.5/3 — NULL for Llama). */
+        if (layer->q_bias) {
+            for (int n = 0; n < N; n++)
+                for (int i = 0; i < q_dim; i++) QB[(size_t)n * q_dim + i] += layer->q_bias[i];
+        }
+        if (layer->k_bias) {
+            for (int n = 0; n < N; n++)
+                for (int i = 0; i < kv_dim; i++) KB[(size_t)n * kv_dim + i] += layer->k_bias[i];
+        }
+        if (layer->v_bias) {
+            for (int n = 0; n < N; n++)
+                for (int i = 0; i < kv_dim; i++) VB[(size_t)n * kv_dim + i] += layer->v_bias[i];
+        }
+        /* 2b. QK-norm (Qwen3 — NULL for Llama). */
+        if (layer->q_norm) {
+            for (int n = 0; n < N; n++) {
+                for (int h = 0; h < c->n_heads; h++) {
+                    float* qh = QB + (size_t)n * q_dim + h * c->head_dim;
+                    tq_rmsnorm(qh, qh, layer->q_norm, c->head_dim, c->rms_norm_eps);
+                }
+            }
+        }
+        if (layer->k_norm) {
+            for (int n = 0; n < N; n++) {
+                for (int h = 0; h < c->n_kv_heads; h++) {
+                    float* kh = KB + (size_t)n * kv_dim + h * c->head_dim;
+                    tq_rmsnorm(kh, kh, layer->k_norm, c->head_dim, c->rms_norm_eps);
+                }
+            }
+        }
+
         /* 3. RoPE + KV cache write (per-token).
          * Mirror tq_forward's RoPE selection: if model->rope_freqs is set
          * (Llama 3.x learned RoPE scaling, 64 freq factors), apply per-pair
@@ -3141,13 +3219,15 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
             float* kn = KB + (size_t)n * kv_dim;
             int pos = pos_start + n;
             if (model->rope_freqs && model->rope_freqs_len > 0) {
-                int rope_pairs = c->head_dim / 2;
+                /* Match tq_forward's rope_n_dims selection: c->rope_n_dims may
+                 * differ from head_dim (e.g., Gemma partial RoPE). */
+                int rope_n_dims = (c->rope_n_dims > 0) ? c->rope_n_dims : c->head_dim;
+                int rope_pairs = rope_n_dims / 2;
                 if (rope_pairs > model->rope_freqs_len) rope_pairs = model->rope_freqs_len;
-                /* Llama 3 uses interleaved layout (a=2i, b=2i+1) */
                 for (int h = 0; h < c->n_heads; h++) {
                     float* qh = qn + h * c->head_dim;
                     for (int i = 0; i < rope_pairs; i++) {
-                        float base = 1.0f / powf(c->rope_freq_base, 2.0f * i / (float)c->head_dim);
+                        float base = 1.0f / powf(c->rope_freq_base, 2.0f * i / (float)rope_n_dims);
                         float freq = base / model->rope_freqs[i];
                         float theta = pos * freq;
                         float ct = cosf(theta), st = sinf(theta);
@@ -3159,7 +3239,7 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
                 for (int h = 0; h < c->n_kv_heads; h++) {
                     float* kh = kn + h * c->head_dim;
                     for (int i = 0; i < rope_pairs; i++) {
-                        float base = 1.0f / powf(c->rope_freq_base, 2.0f * i / (float)c->head_dim);
+                        float base = 1.0f / powf(c->rope_freq_base, 2.0f * i / (float)rope_n_dims);
                         float freq = base / model->rope_freqs[i];
                         float theta = pos * freq;
                         float ct = cosf(theta), st = sinf(theta);
@@ -3284,8 +3364,19 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
             }
         }

-        /* 5. O matmul batched */
+        /* 5. O matmul batched + Q2 residual */
         tq_batched_matmul_q4(X, layer->wo_q4, layer->wo_q4s, OB, dim, q_dim, N, NULL);
+        if (layer->wo_q2) {
+            int8_t* xq = s->xb_q8;
+            float* xs = s->xb_q8s;
+            float* tmp = (float*)malloc((size_t)dim * sizeof(float));
+            for (int n = 0; n < N; n++) {
+                tq_quantize_row_q8(OB + (size_t)n * q_dim, xq, xs, q_dim);
+                tq_matmul_q2_preq(tmp, layer->wo_q2, layer->wo_q2s, xq, xs, dim, q_dim);
+                for (int i = 0; i < dim; i++) X[(size_t)n * dim + i] += tmp[i];
+            }
+            free(tmp);
+        }

         /* 6. Residual: Xres += X */
         for (size_t i = 0; i < (size_t)N * dim; i++) Xres[i] += X[i];
@@ -3296,9 +3387,26 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
                        layer->ffn_norm, dim, c->rms_norm_eps);
         }

-        /* 8. gate, up batched matmul */
+        /* 8. gate, up batched matmul + Q2 residuals */
         tq_batched_matmul_q4(GB, layer->w_gate_q4, layer->w_gate_q4s, XBN, inter, dim, N, NULL);
         tq_batched_matmul_q4(UB, layer->w_up_q4, layer->w_up_q4s, XBN, inter, dim, N, NULL);
+        if (layer->w_gate_q2 || layer->w_up_q2) {
+            int8_t* xq = s->xb_q8;
+            float* xs = s->xb_q8s;
+            float* tmp = (float*)malloc((size_t)inter * sizeof(float));
+            for (int n = 0; n < N; n++) {
+                tq_quantize_row_q8(XBN + (size_t)n * dim, xq, xs, dim);
+                if (layer->w_gate_q2) {
+                    tq_matmul_q2_preq(tmp, layer->w_gate_q2, layer->w_gate_q2s, xq, xs, inter, dim);
+                    for (int i = 0; i < inter; i++) GB[(size_t)n * inter + i] += tmp[i];
+                }
+                if (layer->w_up_q2) {
+                    tq_matmul_q2_preq(tmp, layer->w_up_q2, layer->w_up_q2s, xq, xs, inter, dim);
+                    for (int i = 0; i < inter; i++) UB[(size_t)n * inter + i] += tmp[i];
+                }
+            }
+            free(tmp);
+        }

         /* 9. SiLU(gate) * up (per-element) */
         for (size_t i = 0; i < (size_t)N * inter; i++) {
@@ -3307,8 +3415,19 @@ int tq_forward_batch(tq_model_t* model, tq_state_t* s,
             GB[i] = silu * UB[i];
         }

-        /* 10. down matmul batched (output back into X) */
+        /* 10. down matmul batched (output back into X) + Q2 residual */
         tq_batched_matmul_q4(X, layer->w_down_q4, layer->w_down_q4s, GB, dim, inter, N, NULL);
+        if (layer->w_down_q2) {
+            int8_t* xq = s->xb_q8;
+            float* xs = s->xb_q8s;
+            float* tmp = (float*)malloc((size_t)dim * sizeof(float));
+            for (int n = 0; n < N; n++) {
+                tq_quantize_row_q8(GB + (size_t)n * inter, xq, xs, inter);
+                tq_matmul_q2_preq(tmp, layer->w_down_q2, layer->w_down_q2s, xq, xs, dim, inter);
+                for (int i = 0; i < dim; i++) X[(size_t)n * dim + i] += tmp[i];
+            }
+            free(tmp);
+        }

         /* 11. Residual: Xres += X */
         for (size_t i = 0; i < (size_t)N * dim; i++) Xres[i] += X[i];
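
Not part of the diff above: a minimal caller-side sketch of the new TQ_BATCH_SANITY switch, using only the signatures visible in this commit (tq_forward_batch, tq_forward). The helper name and the setenv-based setup are illustrative assumptions, not project API. With the variable set, tq_forward_batch degenerates to N sequential tq_forward calls, so any prefill divergence that remains must come from orchestration outside the batched matmuls.

#include <stdlib.h>   /* getenv / setenv */

/* Hypothetical helper, not from the commit: prefill a prompt through the
 * sanity fallback so its results can be diffed against the normal batch path. */
static int prefill_with_sanity_fallback(tq_model_t* model, tq_state_t* s,
                                        const int* prompt, int n_prompt) {
    setenv("TQ_BATCH_SANITY", "1", 1);                  /* force the fallback branch */
    int pos = tq_forward_batch(model, s, prompt, n_prompt, 0);
    /* pos == n_prompt here; decoding would continue with tq_forward(model, s, tok, pos). */
    return pos;
}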