@@ -1003,9 +1003,15 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
10031003 tq_rope (s -> q , s -> k , pos , head_dim , n_heads , n_kv_heads , rope_base );
10041004 }
10051005
1006- /* Store K,V in cache */
1006+ /* Store K,V in cache.
1007+ * When quantized KV is active, skip the FP32 key store — the quantized
1008+ * cache is the single source of truth for keys. This avoids writing a
1009+ * duplicate FP32 copy. NOTE(review): memory is only saved if the FP32 key_cache allocation is dropped as well — confirm. */
1010+ int use_quant_kv = (s -> kv_quant_type < TQ_TYPE_COUNT && s -> quant_key_cache != NULL );
10071011 float * key_cache_layer = s -> key_cache + l * kv_layer_stride ;
1008- memcpy (key_cache_layer + (size_t )pos * kv_dim , s -> k , kv_dim * sizeof (float ));
1012+ if (!use_quant_kv ) {
1013+ memcpy (key_cache_layer + (size_t )pos * kv_dim , s -> k , kv_dim * sizeof (float ));
1014+ }
10091015
10101016 /* KV profiling: accumulate pre/post-RHT statistics for this layer's keys */
10111017 if (s -> profile_kv && s -> profile_accum ) {
@@ -1077,7 +1083,7 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
10771083 * Note: 1-bit/2b/3b sign-based quantization now expands sketch_dim to
10781084 * at least 128 bits for small head_dim (QJL paper: m/d >= 2), so no
10791085 * fallback is needed. */
1080- int use_int_attn = ( s -> kv_quant_type < TQ_TYPE_COUNT && s -> quant_key_cache != NULL ) ;
1086+ int use_int_attn = use_quant_kv ;
10811087 if (use_int_attn ) {
10821088 const tq_type_traits_t * traits = & TQ_TRAITS [s -> kv_quant_type ];
10831089 for (int kh = 0 ; kh < n_kv_heads ; kh ++ ) {
@@ -1161,8 +1167,48 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
11611167 for (int t = 0 ; t < attn_start ; t ++ ) {
11621168 atth [t ] = -1e30f ;
11631169 }
1170+ } else if (use_quant_kv ) {
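1171+ /* Dequant attention: for each position in the window, dequantize this
1172+ * KV head's key from the quantized key cache into dequant_buf, then compute
1173+ * the FP32 dot product with the query. This is the path meant to deliver the
1174+ * real memory savings — no FP32 key is read for previous positions. NOTE(review): dequant_buf holds 256 floats; head_dim > 256 would overflow it — confirm the bound. */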
1175+ const tq_type_traits_t * traits = & TQ_TRAITS [s -> kv_quant_type ];
1176+ float inv_scale = 1.0f / sqrtf (attn_scale_dim );
1177+ float dequant_buf [256 ]; /* temp buffer for one head's dequantized key */
1178+
1179+ for (int t = 0 ; t < attn_start ; t ++ ) atth [t ] = -1e30f ;
1180+
1181+ for (int t = attn_start ; t < seq_len ; t ++ ) {
1182+ const uint8_t * quant_src = (const uint8_t * )s -> quant_key_cache
1183+ + (size_t )l * s -> quant_kv_stride
1184+ + (size_t )t * n_kv_heads * s -> quant_head_stride
1185+ + (size_t )kv_h * s -> quant_head_stride ;
1186+
1187+ traits -> dequantize (quant_src , dequant_buf , head_dim );
1188+
1189+ float score = 0.0f ;
1190+ #ifdef __ARM_NEON
1191+ /* NEON-optimized dot product */
1192+ float32x4_t vsum = vdupq_n_f32 (0.0f );
1193+ int d = 0 ;
1194+ for (; d + 4 <= head_dim ; d += 4 ) {
1195+ float32x4_t vq = vld1q_f32 (qh + d );
1196+ float32x4_t vk = vld1q_f32 (dequant_buf + d );
1197+ vsum = vfmaq_f32 (vsum , vq , vk );
1198+ }
1199+ score = vaddvq_f32 (vsum );
1200+ for (; d < head_dim ; d ++ ) {
1201+ score += qh [d ] * dequant_buf [d ];
1202+ }
1203+ #else
1204+ for (int d = 0 ; d < head_dim ; d ++ ) {
1205+ score += qh [d ] * dequant_buf [d ];
1206+ }
1207+ #endif
1208+ atth [t ] = score * inv_scale ;
1209+ }
11641210 } else {
1165- /* FP32 attention scores (short sequences or no quantization) */
1211+ /* FP32 attention scores (no quantization) */
11661212 float inv_scale = 1.0f / sqrtf (attn_scale_dim );
11671213 /* Set positions outside sliding window to -inf */
11681214 for (int t = 0 ; t < attn_start ; t ++ ) {
0 commit comments