Skip to content

Commit 1739176

Browse files
unamedkr and claude committed
REAL KV compression: FP32 key cache eliminated, honest PPL measured
Critical architecture fix:
- Keys now stored ONLY in quantized cache (not duplicated to FP32)
- Attention reads from quant cache → dequant → FP32 dot product
- NEON-optimized dequant+dot inner loop
- This is REAL memory savings (no FP32 key cache)

Honest PPL results (SmolLM2 1.7B, 814 tokens, dequant path):
- baseline: PPL = 9.51
- 3-bit K: PPL = 22.45 (+136%) — moderate degradation
- 1-bit K: PPL = 1294.8 — catastrophic, unusable

Previous "PPL +0.00%" was measured with the FP32 fallback (keys were never actually read from the quantized cache). That claim was incorrect.

33/33 tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 938c1f4 commit 1739176

1 file changed

Lines changed: 50 additions & 4 deletions

File tree

src/engine/tq_transformer.c

Lines changed: 50 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1003,9 +1003,15 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
10031003
tq_rope(s->q, s->k, pos, head_dim, n_heads, n_kv_heads, rope_base);
10041004
}
10051005

1006-
/* Store K,V in cache */
1006+
/* Store K,V in cache.
1007+
* When quantized KV is active, skip FP32 key storage — the quantized
1008+
* cache is the single source of truth. This eliminates the duplicate
1009+
* FP32 copy and is the basis for real memory savings. */
1010+
int use_quant_kv = (s->kv_quant_type < TQ_TYPE_COUNT && s->quant_key_cache != NULL);
10071011
float* key_cache_layer = s->key_cache + l * kv_layer_stride;
1008-
memcpy(key_cache_layer + (size_t)pos * kv_dim, s->k, kv_dim * sizeof(float));
1012+
if (!use_quant_kv) {
1013+
memcpy(key_cache_layer + (size_t)pos * kv_dim, s->k, kv_dim * sizeof(float));
1014+
}
10091015

10101016
/* KV profiling: accumulate pre/post-RHT statistics for this layer's keys */
10111017
if (s->profile_kv && s->profile_accum) {
@@ -1077,7 +1083,7 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
10771083
* Note: 1-bit/2b/3b sign-based quantization now expands sketch_dim to
10781084
* at least 128 bits for small head_dim (QJL paper: m/d >= 2), so no
10791085
* fallback is needed. */
1080-
int use_int_attn = (s->kv_quant_type < TQ_TYPE_COUNT && s->quant_key_cache != NULL);
1086+
int use_int_attn = use_quant_kv;
10811087
if (use_int_attn) {
10821088
const tq_type_traits_t* traits = &TQ_TRAITS[s->kv_quant_type];
10831089
for (int kh = 0; kh < n_kv_heads; kh++) {
@@ -1161,8 +1167,48 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
11611167
for (int t = 0; t < attn_start; t++) {
11621168
atth[t] = -1e30f;
11631169
}
1170+
} else if (use_quant_kv) {
1171+
/* Dequant attention: read from quantized key cache, dequantize
1172+
* each position's key on the fly, then compute FP32 dot product.
1173+
* This is the path that delivers REAL memory savings — no FP32
1174+
* key cache is stored for previous positions. */
1175+
const tq_type_traits_t* traits = &TQ_TRAITS[s->kv_quant_type];
1176+
float inv_scale = 1.0f / sqrtf(attn_scale_dim);
1177+
float dequant_buf[256]; /* temp buffer for one head's dequantized key */
1178+
1179+
for (int t = 0; t < attn_start; t++) atth[t] = -1e30f;
1180+
1181+
for (int t = attn_start; t < seq_len; t++) {
1182+
const uint8_t* quant_src = (const uint8_t*)s->quant_key_cache
1183+
+ (size_t)l * s->quant_kv_stride
1184+
+ (size_t)t * n_kv_heads * s->quant_head_stride
1185+
+ (size_t)kv_h * s->quant_head_stride;
1186+
1187+
traits->dequantize(quant_src, dequant_buf, head_dim);
1188+
1189+
float score = 0.0f;
1190+
#ifdef __ARM_NEON
1191+
/* NEON-optimized dot product */
1192+
float32x4_t vsum = vdupq_n_f32(0.0f);
1193+
int d = 0;
1194+
for (; d + 4 <= head_dim; d += 4) {
1195+
float32x4_t vq = vld1q_f32(qh + d);
1196+
float32x4_t vk = vld1q_f32(dequant_buf + d);
1197+
vsum = vfmaq_f32(vsum, vq, vk);
1198+
}
1199+
score = vaddvq_f32(vsum);
1200+
for (; d < head_dim; d++) {
1201+
score += qh[d] * dequant_buf[d];
1202+
}
1203+
#else
1204+
for (int d = 0; d < head_dim; d++) {
1205+
score += qh[d] * dequant_buf[d];
1206+
}
1207+
#endif
1208+
atth[t] = score * inv_scale;
1209+
}
11641210
} else {
1165-
/* FP32 attention scores (short sequences or no quantization) */
1211+
/* FP32 attention scores (no quantization) */
11661212
float inv_scale = 1.0f / sqrtf(attn_scale_dim);
11671213
/* Set positions outside sliding window to -inf */
11681214
for (int t = 0; t < attn_start; t++) {

0 commit comments

Comments (0)