@@ -538,6 +538,7 @@ typedef struct {
     int model_type; /* 0=qwen35, 1=gemma3, 2=qwen2moe */
     int is_gemma4; /* 1 if Gemma 4 (STEP35): uses SwiGLU, no post-norms */
     int sliding_window; /* sliding window size (512 for gemma3, 0 for unlimited) */
+    int n_kv_shared_layers; /* Gemma 4: last N layers share KV from earlier same-type layers (0=disabled) */
     float rope_local_base_freq; /* RoPE base freq for local/sliding layers (10000.0 for gemma3) */
     int n_norms_per_block; /* 2 for qwen35, 4 for gemma3 */
     float query_pre_attn_scalar; /* attention scaling: 1/sqrt(this) instead of 1/sqrt(head_dim), 0=use head_dim */
@@ -11456,6 +11457,11 @@ tq_model_t* tq_load_gguf(const char* path) {
 
     /* Sliding window + local RoPE base */
     c->sliding_window = (int)tq_gguf_get_u32(gguf, GGUF_KEY("attention.sliding_window"), 0);
+    c->n_kv_shared_layers = (int)tq_gguf_get_u32(gguf, GGUF_KEY("attention.shared_kv_layers"), 0);
+    if (c->n_kv_shared_layers > 0) {
+        fprintf(stderr, "tq_load_gguf: KV sharing enabled — last %d layers reuse KV from earlier same-type layers\n",
+                c->n_kv_shared_layers);
+    }
     /* Local/sliding RoPE base: try Gemma4 naming first, then generic */
     c->rope_local_base_freq = tq_gguf_get_f32(gguf, GGUF_KEY("rope.freq_base_swa"),
                               tq_gguf_get_f32(gguf, GGUF_KEY("rope.local.freq_base"),
@@ -14084,6 +14090,29 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
         tq_quantize_row_q8(s->xb, s->xb_q8, s->xb_q8s, dim);
     }
 
+    /* Gemma 4 KV sharing: last n_kv_shared_layers layers skip K/V projection
+     * and reuse the KV cache from the last non-shared layer of the same
+     * attention type (sliding or full). Only Q is computed fresh. */
+    int kv_shared_skip = 0;
+    int kv_shared_ref_layer = -1;
+    /* KV sharing: disabled by default until the value_cache FP16 stride
+     * segfault is fixed. Enable with TQ_KV_SHARE=1 for testing. */
+    if (c->n_kv_shared_layers > 0 && getenv("TQ_KV_SHARE")) {
+        int shared_start = c->n_layers - c->n_kv_shared_layers;
+        if (l >= shared_start) {
+            kv_shared_skip = 1;
+            /* Find reference layer: last non-shared layer of the same type */
+            int is_sliding_l = (model->layer_is_sliding && model->layer_is_sliding[l]);
+            for (int r = shared_start - 1; r >= 0; r--) {
+                int is_sliding_r = (model->layer_is_sliding && model->layer_is_sliding[r]);
+                if (is_sliding_l == is_sliding_r) {
+                    kv_shared_ref_layer = r;
+                    break;
+                }
+            }
+        }
+    }
+
     /* QKV projections (timed as matmul) */
     TQ_PROF_START(_tp);
     /* When attn_output_gate is enabled, wq has shape [2*n_heads*head_dim, dim]
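
The reference-layer search above runs inside self_attn_forward on every token, even though the result depends only on the layer index and layer_is_sliding. Below is a minimal standalone sketch of precomputing that mapping once at load time; the helper name build_kv_share_map and its flat-parameter signature are illustrative, not part of the codebase.

/* Sketch: fill kv_share_map[l] with the reference layer for each KV-shared
 * layer, or -1 if layer l keeps its own K/V projection. Mirrors the search
 * in self_attn_forward, including the NULL check on layer_is_sliding. */
static void build_kv_share_map(int n_layers, int n_kv_shared_layers,
                               const int* layer_is_sliding, int* kv_share_map) {
    int shared_start = n_layers - n_kv_shared_layers;
    for (int l = 0; l < n_layers; l++) {
        kv_share_map[l] = -1;
        if (n_kv_shared_layers <= 0 || l < shared_start) continue;
        int sliding_l = layer_is_sliding ? (layer_is_sliding[l] != 0) : 0;
        /* last non-shared layer with the same attention type (sliding/full) */
        for (int r = shared_start - 1; r >= 0; r--) {
            int sliding_r = layer_is_sliding ? (layer_is_sliding[r] != 0) : 0;
            if (sliding_l == sliding_r) { kv_share_map[l] = r; break; }
        }
    }
}

With such a map in place, the forward pass would only need kv_shared_ref_layer = kv_share_map[l], and the shared_start / type-matching logic would disappear from the hot path.
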
@@ -14153,7 +14182,38 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
             tq_matmul(s->q, s->xb, layer->wq, n_heads * head_dim, dim);
         }
     }
-    if (has_fused_qkv_layer) {
+    if (kv_shared_skip && kv_shared_ref_layer >= 0) {
+        /* KV sharing: skip the K/V projections for this layer and instead
+         * read K and V for the current position from the reference layer's
+         * cache.
+         * NOTE: we copy into s->k/s->v so the KV cache write below still
+         * stores them into THIS layer's cache slot, and the attention code
+         * after it runs unchanged.
+         *
+         * IMPORTANT: strides are in elements, not bytes.
+         * key_cache is FP32, so kv_layer_stride indexes floats.
+         * value_cache is also FP32 (and uses kv_layer_stride) unless
+         * use_fp16_values is set; in that case it holds 16-bit values
+         * (uint16_t) and is indexed below with a separate per-layer
+         * stride counted in 16-bit elements, after casting the pointer.
+         */
+        float* ref_key_layer = s->key_cache + (size_t)kv_shared_ref_layer * kv_layer_stride;
+        memcpy(s->k, ref_key_layer + (size_t)pos * cache_kv_dim, (size_t)kv_dim * sizeof(float));
+
+        if (s->use_fp16_values) {
+            /* 16-bit value cache: per-layer stride counted in uint16 elements. */
+            size_t v_stride = (size_t)c->max_seq_len * cache_kv_dim; /* in uint16 elements */
+            const uint16_t* v16_cache = (const uint16_t*)s->value_cache;
+            const uint16_t* ref_v = v16_cache + (size_t)kv_shared_ref_layer * v_stride + (size_t)pos * cache_kv_dim;
+            for (int i = 0; i < kv_dim; i++) {
+                uint32_t bits = ((uint32_t)ref_v[i]) << 16; /* widen: high 16 bits of an FP32 */
+                memcpy(&s->v[i], &bits, sizeof(float));
+            }
+        } else {
+            float* ref_val_layer = s->value_cache + (size_t)kv_shared_ref_layer * kv_layer_stride;
+            memcpy(s->v, ref_val_layer + (size_t)pos * cache_kv_dim, (size_t)kv_dim * sizeof(float));
+        }
+    } else if (has_fused_qkv_layer) {
         /* Already populated s->q/s->k/s->v above — skip the standard
          * K and V projection blocks. */
     } else if (layer->wk_q2) {
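
One detail worth calling out from the FP16 branch above: the per-element widening (shifting the stored 16 bits into the high half of an FP32 word) decodes the cached values bfloat16-style rather than as IEEE binary16. Which decoding is correct depends on how the value cache is written, which this diff does not show. Below is a standalone sketch of both decodings; the names decode_bf16_bits and decode_ieee_half are illustrative only.

#include <stdint.h>
#include <string.h>

/* Decoding used in the diff: the 16 stored bits become the high half of the
 * FP32 bit pattern (bfloat16-style truncation). */
static inline float decode_bf16_bits(uint16_t h) {
    uint32_t bits = ((uint32_t)h) << 16;
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

/* If the cache instead stores IEEE binary16, each element needs a real
 * half-to-float conversion (normals, subnormals, zeros, inf/NaN). */
static inline float decode_ieee_half(uint16_t h) {
    uint32_t sign = (uint32_t)(h >> 15) << 31;
    uint32_t exp  = (h >> 10) & 0x1F;
    uint32_t mant = h & 0x3FF;
    uint32_t bits;
    if (exp == 0x1F) {                          /* inf / NaN */
        bits = sign | 0x7F800000u | (mant << 13);
    } else if (exp != 0) {                      /* normal: rebias 15 -> 127 */
        bits = sign | ((exp + 112u) << 23) | (mant << 13);
    } else if (mant != 0) {                     /* subnormal: renormalize */
        int e = 1;
        while (!(mant & 0x400)) { mant <<= 1; e--; }
        bits = sign | ((uint32_t)(e + 112) << 23) | ((mant & 0x3FF) << 13);
    } else {                                    /* signed zero */
        bits = sign;
    }
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

If the write path truncates FP32 values to their top 16 bits, the shift in the diff is exactly decode_bf16_bits and nothing needs to change; if it stores genuine binary16, the copy loop should use something like decode_ieee_half instead.
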