@@ -1117,7 +1117,65 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
11171117 + (size_t )l * s -> quant_kv_stride
11181118 + (size_t )pos * cache_n_kv_heads * s -> quant_head_stride
11191119 + (size_t )kh * s -> quant_head_stride ;
1120- traits -> quantize (key_src , quant_dst , head_dim );
1120+
1121+ if (s -> delta_kv_enabled && pos > 0 ) {
1122+ /* Delta compression with periodic I-frames.
1123+ * I-frames store absolute keys to bound accumulated drift.
1124+ * P-frames store delta = key[t] - reconstruct(key[t-1]). */
1125+ int iframe_int = s -> delta_iframe_interval > 0 ? s -> delta_iframe_interval : 16 ;
1126+ int is_iframe = (pos % iframe_int == 0 );
1127+
1128+ if (is_iframe ) {
1129+ /* I-frame: quantize absolute key (drift reset) */
1130+ traits -> quantize (key_src , quant_dst , head_dim );
1131+ } else {
1132+ /* P-frame: quantize delta from previous position's reconstruction */
1133+ const uint8_t * prev_quant = (const uint8_t * )s -> quant_key_cache
1134+ + (size_t )l * s -> quant_kv_stride
1135+ + (size_t )(pos - 1 ) * cache_n_kv_heads * s -> quant_head_stride
1136+ + (size_t )kh * s -> quant_head_stride ;
1137+ float prev_recon [512 ];
1138+ traits -> dequantize (prev_quant , prev_recon , head_dim );
1139+
1140+ /* If previous was an I-frame, prev_recon is absolute.
1141+ * If previous was a P-frame, prev_recon is the delta.
1142+ * We need the full reconstruction of the previous key.
1143+ * Since we can't easily track this here, we reconstruct
1144+ * from the last I-frame. */
1145+ int last_iframe = (pos / iframe_int ) * iframe_int ;
1146+ if (pos - 1 > last_iframe ) {
1147+ /* Reconstruct key[pos-1] from last I-frame through deltas */
1148+ const uint8_t * iframe_src = (const uint8_t * )s -> quant_key_cache
1149+ + (size_t )l * s -> quant_kv_stride
1150+ + (size_t )last_iframe * cache_n_kv_heads * s -> quant_head_stride
1151+ + (size_t )kh * s -> quant_head_stride ;
1152+ float recon [512 ];
1153+ traits -> dequantize (iframe_src , recon , head_dim );
1154+ float tmp [512 ];
1155+ for (int ti = last_iframe + 1 ; ti <= pos - 1 ; ti ++ ) {
1156+ const uint8_t * delta_src = (const uint8_t * )s -> quant_key_cache
1157+ + (size_t )l * s -> quant_kv_stride
1158+ + (size_t )ti * cache_n_kv_heads * s -> quant_head_stride
1159+ + (size_t )kh * s -> quant_head_stride ;
1160+ traits -> dequantize (delta_src , tmp , head_dim );
1161+ for (int d = 0 ; d < head_dim ; d ++ ) {
1162+ recon [d ] += tmp [d ];
1163+ }
1164+ }
1165+ memcpy (prev_recon , recon , (size_t )head_dim * sizeof (float ));
1166+ }
1167+ /* else: pos-1 == last_iframe, prev_recon from dequant is correct */
1168+
1169+ float delta_buf [512 ];
1170+ for (int d = 0 ; d < head_dim ; d ++ ) {
1171+ delta_buf [d ] = key_src [d ] - prev_recon [d ];
1172+ }
1173+ traits -> quantize (delta_buf , quant_dst , head_dim );
1174+ }
1175+ } else {
1176+ /* First position (I-frame) or non-delta mode: quantize absolute key */
1177+ traits -> quantize (key_src , quant_dst , head_dim );
1178+ }
11211179 }
11221180 }
11231181
@@ -1195,6 +1253,81 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
11951253 for (int t = 0 ; t < attn_start ; t ++ ) {
11961254 atth [t ] = -1e30f ;
11971255 }
1256+ } else if (use_quant_kv && s -> delta_kv_enabled ) {
1257+ /* Delta KV attention with periodic I-frames.
1258+ * I-frames (pos % iframe_int == 0) store absolute keys.
1259+ * P-frames store deltas. Reconstruct by accumulating from last I-frame.
1260+ * This bounds drift to at most iframe_int steps. */
1261+ const tq_type_traits_t * traits = & TQ_TRAITS [s -> kv_quant_type ];
1262+ float inv_scale = 1.0f / sqrtf (attn_scale_dim );
1263+ int iframe_int = s -> delta_iframe_interval > 0 ? s -> delta_iframe_interval : 16 ;
1264+ float recon_key [512 ];
1265+ float dequant_buf [512 ];
1266+
1267+ for (int t = 0 ; t < attn_start ; t ++ ) atth [t ] = -1e30f ;
1268+
1269+ for (int t = attn_start ; t < seq_len ; t ++ ) {
1270+ const uint8_t * quant_src = (const uint8_t * )s -> quant_key_cache
1271+ + (size_t )l * s -> quant_kv_stride
1272+ + (size_t )t * cache_n_kv_heads * s -> quant_head_stride
1273+ + (size_t )kv_h * s -> quant_head_stride ;
1274+
1275+ if (t % iframe_int == 0 ) {
1276+ /* I-frame: dequantize directly */
1277+ traits -> dequantize (quant_src , recon_key , head_dim );
1278+ } else {
1279+ /* P-frame: need reconstruction from last I-frame */
1280+ int last_iframe = (t / iframe_int ) * iframe_int ;
1281+
1282+ /* If we're processing sequentially from last I-frame, recon_key
1283+ * already holds the previous position's reconstruction (if t-1
1284+ * was processed in this loop). Otherwise, reconstruct from scratch. */
1285+ if (t - 1 >= attn_start && t - 1 >= last_iframe ) {
1286+ /* recon_key holds recon[t-1], just add delta[t] */
1287+ traits -> dequantize (quant_src , dequant_buf , head_dim );
1288+ for (int d = 0 ; d < head_dim ; d ++ ) {
1289+ recon_key [d ] += dequant_buf [d ];
1290+ }
1291+ } else {
1292+ /* Reconstruct from last I-frame */
1293+ const uint8_t * iframe_src = (const uint8_t * )s -> quant_key_cache
1294+ + (size_t )l * s -> quant_kv_stride
1295+ + (size_t )last_iframe * cache_n_kv_heads * s -> quant_head_stride
1296+ + (size_t )kv_h * s -> quant_head_stride ;
1297+ traits -> dequantize (iframe_src , recon_key , head_dim );
1298+ for (int ti = last_iframe + 1 ; ti <= t ; ti ++ ) {
1299+ const uint8_t * delta_src = (const uint8_t * )s -> quant_key_cache
1300+ + (size_t )l * s -> quant_kv_stride
1301+ + (size_t )ti * cache_n_kv_heads * s -> quant_head_stride
1302+ + (size_t )kv_h * s -> quant_head_stride ;
1303+ traits -> dequantize (delta_src , dequant_buf , head_dim );
1304+ for (int d = 0 ; d < head_dim ; d ++ ) {
1305+ recon_key [d ] += dequant_buf [d ];
1306+ }
1307+ }
1308+ }
1309+ }
1310+
1311+ float score = 0.0f ;
1312+ #ifdef __ARM_NEON
1313+ float32x4_t vsum = vdupq_n_f32 (0.0f );
1314+ int d = 0 ;
1315+ for (; d + 4 <= head_dim ; d += 4 ) {
1316+ float32x4_t vq = vld1q_f32 (qh + d );
1317+ float32x4_t vk = vld1q_f32 (recon_key + d );
1318+ vsum = vfmaq_f32 (vsum , vq , vk );
1319+ }
1320+ score = vaddvq_f32 (vsum );
1321+ for (; d < head_dim ; d ++ ) {
1322+ score += qh [d ] * recon_key [d ];
1323+ }
1324+ #else
1325+ for (int d = 0 ; d < head_dim ; d ++ ) {
1326+ score += qh [d ] * recon_key [d ];
1327+ }
1328+ #endif
1329+ atth [t ] = score * inv_scale ;
1330+ }
11981331 } else if (use_quant_kv ) {
11991332 /* Dequant attention: read from quantized key cache, dequantize
12001333 * each position's key on the fly, then compute FP32 dot product.
0 commit comments