@@ -143,6 +143,12 @@ tq_state_t* tq_create_state_ex(const tq_model_config_t* config, tq_type kv_type,
     int max_seq = config->max_seq_len;
     int n_layers = config->n_layers;
 
+    /* For hybrid attention (Gemma 4), full layers have a larger kv_dim.
+     * Allocate K/V buffers and KV cache with the MAX of the sliding and full kv_dim. */
+    int full_kv_dim = (config->full_n_kv_heads > 0 && config->full_head_dim > 0)
+        ? config->full_n_kv_heads * config->full_head_dim : kv_dim;
+    int max_kv_dim = (full_kv_dim > kv_dim) ? full_kv_dim : kv_dim;
+
     tq_state_t* s = (tq_state_t*)calloc(1, sizeof(tq_state_t));
     if (!s) return NULL;
 
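For sizing intuition, the over-allocation this buys is bounded by the ratio of full to sliding kv_dim. A worked example with hypothetical config values (not taken from this model):

```c
/* Hypothetical hybrid config, for cache-sizing intuition only:
 *   sliding layers: 2 KV heads x 128 head_dim -> kv_dim      = 256
 *   full layers:    4 KV heads x 128 head_dim -> full_kv_dim = 512
 * With max_seq = 4096 and n_layers = 30, the FP32 key cache is
 *   30 * 4096 * 512 floats * 4 B = 240 MiB,
 * versus 120 MiB if sized for sliding layers alone; the unused half of
 * each sliding layer's rows is the price of a uniform stride. */
size_t key_cache_floats = (size_t)30 * 4096 * 512;  /* n_layers * max_seq * max_kv_dim */
```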
@@ -171,15 +177,15 @@ tq_state_t* tq_create_state_ex(const tq_model_config_t* config, tq_type kv_type,
     s->xb = (float*)calloc((size_t)max_dim, sizeof(float));
     s->xb2 = (float*)calloc((size_t)max_dim, sizeof(float));
     s->q = (float*)calloc((size_t)max_q_dim, sizeof(float));
-    s->k = (float*)calloc((size_t)kv_dim, sizeof(float));
-    s->v = (float*)calloc((size_t)kv_dim, sizeof(float));
+    s->k = (float*)calloc((size_t)max_kv_dim, sizeof(float));
+    s->v = (float*)calloc((size_t)max_kv_dim, sizeof(float));
     s->att = (float*)calloc((size_t)n_heads * max_seq, sizeof(float));
     s->hb = (float*)calloc((size_t)inter_dim, sizeof(float));
     s->hb2 = (float*)calloc((size_t)inter_dim, sizeof(float));
     s->logits = (float*)calloc((size_t)config->vocab_size, sizeof(float));
 
-    /* KV cache for self_attn layers */
-    size_t kv_layer_size = (size_t)max_seq * kv_dim;
+    /* KV cache for self_attn layers — use max_kv_dim for hybrid attention compatibility */
+    size_t kv_layer_size = (size_t)max_seq * max_kv_dim;
     s->key_cache = (float*)calloc((size_t)n_layers * kv_layer_size, sizeof(float));
 
     /* Value cache quantization: Q4 or Q2 for aggressive V compression.
@@ -188,8 +194,8 @@ tq_state_t* tq_create_state_ex(const tq_model_config_t* config, tq_type kv_type,
      * Q2: 8 packed bytes + 1 float scale per block of 32 = 12 bytes/32 values */
     s->value_quant_bits = value_quant_bits;
     if (value_quant_bits == 4 || value_quant_bits == 2) {
-        /* Quantized V cache */
-        int n_blocks_per_pos = (kv_dim + 31) / 32;  /* blocks per position (all heads) */
+        /* Quantized V cache — use max_kv_dim for hybrid attention compatibility */
+        int n_blocks_per_pos = (max_kv_dim + 31) / 32;  /* blocks per position (all heads) */
         size_t packed_per_block = (value_quant_bits == 4) ? 16 : 8;
         s->value_stride_qs = (size_t)n_blocks_per_pos * packed_per_block;
         s->value_stride_scales = (size_t)n_blocks_per_pos;
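Spelling out the byte math from the comment above as a standalone sketch; `v_cache_bytes_per_pos` is a hypothetical helper, not a function in this file:

```c
#include <stddef.h>

/* Per-position V-cache footprint under the block-of-32 scheme:
 * Q4: 16 packed bytes + 4-byte scale = 20 B per 32 values
 * Q2:  8 packed bytes + 4-byte scale = 12 B per 32 values */
static size_t v_cache_bytes_per_pos(int max_kv_dim, int value_quant_bits) {
    int n_blocks = (max_kv_dim + 31) / 32;              /* ceil(max_kv_dim / 32) */
    size_t packed = (value_quant_bits == 4) ? 16 : 8;   /* packed bytes per block */
    return (size_t)n_blocks * (packed + sizeof(float)); /* + one scale per block */
}
```

For max_kv_dim = 512 that is 16 blocks, so 320 B per position at Q4 and 192 B at Q2, against 2048 B for an unquantized FP32 V row.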
@@ -883,8 +889,12 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
 
     int kv_dim = n_kv_heads * head_dim;
     int kv_mul = n_heads / n_kv_heads;
-    /* KV cache stride uses the global (sliding) config for uniform allocation */
-    int cache_kv_dim = c->n_kv_heads * c->head_dim;
+    /* KV cache stride uses the MAX of sliding and full kv_dim for uniform allocation.
+     * This ensures full attention layers (with larger kv_dim) don't overflow the cache. */
+    int sliding_kv_dim = c->n_kv_heads * c->head_dim;
+    int full_kv_dim_cache = (c->full_n_kv_heads > 0 && c->full_head_dim > 0)
+        ? c->full_n_kv_heads * c->full_head_dim : sliding_kv_dim;
+    int cache_kv_dim = (full_kv_dim_cache > sliding_kv_dim) ? full_kv_dim_cache : sliding_kv_dim;
     size_t kv_layer_stride = (size_t)c->max_seq_len * cache_kv_dim;
 
     /* Pre-quantize activation to Q8 once for all Q2/Q4 projections in this layer.
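The uniform stride matters because every layer indexes the shared cache with the same `kv_layer_stride`; if sliding layers used their own smaller row width, a full layer's larger rows would run into the next layer's region. A sketch of the indexing this implies, assuming the usual layer-major, position-major layout (the helper is illustrative, not from this file):

```c
/* Assumed layout: key_cache is [n_layers][max_seq][cache_kv_dim].
 * Each layer writes only its own kv_dim floats per row; on sliding
 * layers the tail of the row simply stays unused. */
static float* kv_cache_row(float* key_cache, size_t kv_layer_stride,
                           int cache_kv_dim, int l, int pos) {
    return key_cache + (size_t)l * kv_layer_stride + (size_t)pos * cache_kv_dim;
}
```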
@@ -1222,8 +1232,10 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
      * Others: scale = 1/sqrt(head_dim) */
     float attn_scale_dim = (float)head_dim;
     if (c->use_qk_norm && c->model_type == 1 && c->full_head_dim > 0 && !c->is_moe) {
-        /* Gemma 4 dense (E2B): attention_scale = 1.0 (QK-norm handles scaling) */
-        attn_scale_dim = 1.0f;  /* will compute 1/sqrt(1) = 1.0 */
+        /* Gemma 4: QK-norm normalizes Q and K per head, but we still need the
+         * 1/sqrt(head_dim) scaling. After the norm weights, ||Q|| = ||K|| ~ sqrt(head_dim),
+         * so without explicit scaling the dot product still grows as head_dim. */
+        attn_scale_dim = (float)head_dim;
     } else if (c->query_pre_attn_scalar > 0.0f) {
         attn_scale_dim = c->query_pre_attn_scalar;
         if (c->full_head_dim > 0 && model->layer_is_sliding && !model->layer_is_sliding[l]) {
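For reference, `attn_scale_dim` enters the score as `1/sqrt(attn_scale_dim)`, per the "scale = 1/sqrt(head_dim)" comment above. A minimal sketch of the downstream use (`attn_score` is hypothetical; the real loop lives later in this function):

```c
#include <math.h>

/* Scaled dot-product score for one query head against one cached key.
 * attn_scale_dim is head_dim on the default and Gemma 4 QK-norm paths,
 * or query_pre_attn_scalar when the config supplies one. */
static float attn_score(const float* qh, const float* kh,
                        int head_dim, float attn_scale_dim) {
    float dot = 0.0f;
    for (int i = 0; i < head_dim; i++) dot += qh[i] * kh[i];
    return dot / sqrtf(attn_scale_dim);
}
```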
@@ -1439,6 +1451,15 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
             }
         }
 
+        /* Attention logit soft-capping (Gemma 2/3/4): cap * tanh(score / cap) */
+        if (c->attn_logit_softcap > 0.0f) {
+            float cap = c->attn_logit_softcap;
+            float inv_cap = 1.0f / cap;
+            for (int t = attn_start; t < seq_len; t++) {
+                atth[t] = cap * tanhf(atth[t] * inv_cap);
+            }
+        }
+
         /* Softmax */
         tq_softmax(atth, seq_len);
 
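The capping is a smooth squash into (-cap, cap): tanh is approximately the identity near zero, so typical scores pass through nearly unchanged while outliers saturate instead of dominating the softmax. As a standalone sketch (Gemma 2, for instance, ships `attn_logit_softcapping = 50.0`; here the value comes from `c->attn_logit_softcap`):

```c
#include <math.h>

/* cap * tanh(x / cap): ~x for |x| << cap, saturating toward ±cap. */
static inline float softcap(float x, float cap) {
    return cap * tanhf(x / cap);
}
```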
@@ -1789,7 +1810,7 @@ float* tq_forward(tq_model_t* model, tq_state_t* s, int token, int pos) {
      * 1. per_layer_token_embd[token] (dequant from Q5_K) → reshape [n_layers, ple_dim]
      * 2. per_layer_model_proj @ embed_raw (FP32 matmul) → reshape [n_layers, ple_dim]
      * 3. Combine with RMS-norm and averaging. */
-    if (model->ple_dim > 0 && model->ple_embedding && model->ple_proj) {
+    if (model->ple_dim > 0 && model->ple_embedding && model->ple_proj && !getenv("TQ_NO_PLE")) {
         int ple_dim = model->ple_dim;
         int n_layers = c->n_layers;
         int total_ple = n_layers * ple_dim;  /* e.g., 35 * 256 = 8960 */
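One caveat on the new guard: `getenv` tests presence, not value, so even `TQ_NO_PLE=0` disables PLE. If value sensitivity were ever wanted, a stricter check might look like this (hypothetical; the committed code keeps the simpler presence test):

```c
#include <stdlib.h>
#include <string.h>

/* Hypothetical value-sensitive variant of the TQ_NO_PLE check. */
static int ple_disabled(void) {
    const char* v = getenv("TQ_NO_PLE");
    return v != NULL && strcmp(v, "0") != 0;
}
```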
@@ -2033,12 +2054,13 @@ float* tq_forward(tq_model_t* model, tq_state_t* s, int token, int pos) {
     }
 
     /* Gemma 4 PLE: apply per-layer embedding after FFN, before layer_output_scale.
+     * Can be disabled with TQ_NO_PLE=1 for debugging.
      * 1. gate_out = gelu(inp_gate @ hidden_state) → [ple_dim]
      * 2. mixed = gate_out * ple_input[l] → elementwise [ple_dim]
      * 3. proj_out = proj @ mixed → [hidden_dim]
      * 4. normed = rms_norm(proj_out, post_norm) → [hidden_dim]
      * 5. hidden_state = hidden_state + normed */
-    if (model->ple_dim > 0 && s->ple_buf && layer->ple_gate && layer->ple_proj && layer->ple_norm) {
+    if (model->ple_dim > 0 && s->ple_buf && layer->ple_gate && layer->ple_proj && layer->ple_norm && !getenv("TQ_NO_PLE")) {
         int ple_dim = model->ple_dim;
         float ple_gate_out[256];  /* ple_dim <= 256 */
         float ple_mixed[256];
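Steps 1-5 in the comment map onto two small dense matmuls plus a norm and a residual add. A self-contained sketch under stated assumptions: weight shapes as listed above, a tanh-approximation GELU, and a plain RMS-norm with elementwise weights (the real kernels, and whether the norm weight is stored as w or 1 + w, may differ):

```c
#include <math.h>
#include <stdlib.h>

static float gelu_tanh(float x) {  /* tanh approximation of GELU */
    return 0.5f * x * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x)));
}

/* Hypothetical standalone form of steps 1-5.
 *   W_gate:      [ple_dim x hidden_dim]   (inp_gate)
 *   W_proj:      [hidden_dim x ple_dim]   (proj)
 *   ple_input_l: s->ple_buf + l * ple_dim
 *   norm_w:      RMS-norm weights (post_norm), [hidden_dim] */
static void ple_apply(float* hidden, const float* ple_input_l,
                      const float* W_gate, const float* W_proj,
                      const float* norm_w, int hidden_dim, int ple_dim) {
    float mixed[256];                                    /* ple_dim <= 256 */
    for (int i = 0; i < ple_dim; i++) {                  /* 1 + 2: gate and mix */
        float acc = 0.0f;
        for (int j = 0; j < hidden_dim; j++)
            acc += W_gate[(size_t)i * hidden_dim + j] * hidden[j];
        mixed[i] = gelu_tanh(acc) * ple_input_l[i];
    }
    float* proj_out = (float*)malloc((size_t)hidden_dim * sizeof(float));
    if (!proj_out) return;
    float ss = 0.0f;
    for (int i = 0; i < hidden_dim; i++) {               /* 3: project back */
        float acc = 0.0f;
        for (int j = 0; j < ple_dim; j++)
            acc += W_proj[(size_t)i * ple_dim + j] * mixed[j];
        proj_out[i] = acc;
        ss += acc * acc;
    }
    float inv_rms = 1.0f / sqrtf(ss / (float)hidden_dim + 1e-6f);
    for (int i = 0; i < hidden_dim; i++)                 /* 4 + 5: norm + residual */
        hidden[i] += proj_out[i] * inv_rms * norm_w[i];
    free(proj_out);
}
```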