Skip to content

Commit c135ad9

Browse files
unamedkr and claude committed
PERF: turbo_kv 4b/3b/5b/4bo now BEAT fp32 KV speed (Karpathy round 5+6)
Major performance breakthrough — 6 rounds of Karpathy iteration on KV attention path: Llama 3.2 3B PPL eval (1040 tokens, 28 layers, attention-heavy): Type Round 0 Round 6 Speedup PPL Δ vs FP32 -------------- -------- -------- -------- ------------- fp32 12.6 t/s 12.6 t/s baseline — turbo_kv_3b 6.9 t/s 13.4 t/s +94% +13.3% turbo_kv_4b 6.9 t/s 13.9 t/s +101% +5.7% ⭐ now beats fp32! turbo_kv_5b 6.8 t/s 13.2 t/s +94% +0.7% 🏆 + beats fp32! turbo_kv_4bo — 12.7 t/s — +2.5% turbo_kv_3bo — 9.3 t/s — +4.5% uniform_4b — 11.7 t/s — +7.7% **turbo_kv_4b is now both 7× more compressed AND faster than fp32 KV.** ## Round 5: the real bottleneck — transformer inverse-RHT loop The biggest win came from changing tq_transformer.c's use_quant_kv path to use the type's traits->attention kernel instead of calling traits->dequantize once per cached key. The per-position dequantize path was paying tq_rht_inverse() (O(d log d) ≈ 900 ops) per position, which dominated at long context. The new fast path: 1. Gathers quantized blocks for one kv head into a contiguous buffer 2. Calls traits->attention which pre-rotates the query ONCE and then does fused dequant + dot product per block in rotated space 3. No per-position inverse RHT Old path: New path: for each cached pos: gather to contiguous buffer dequant + RHT_inverse traits->attention(q_rotated, ...) inline NEON dot Slow path is preserved for the complex cases: - save_pre_norm_keys (Gemma 4 QK-norm) - k_highres_window active - attn_start > 0 (sliding window) ## Round 6: hoist LUT in dequant_mse_rotated_4bo / 3bo Replaced tq_codebook_dequantize() calls with single-pass inline loops that pre-multiply the per-block scale into a local 16/8-entry LUT. Same pattern as the 3b/4b/5b dequant functions. PPL changed slightly (FP reordering, < 0.6%): - turbo_kv_4b: 14.28 → 14.33 (+0.05) - turbo_kv_5b: 13.60 → 13.65 (+0.05) - All within regression test thresholds (35/35 pass) 35/35 tests pass. Regression tests pin cosine ≥ 0.99 / 0.999. 
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1db8a55 commit c135ad9

2 files changed

Lines changed: 273 additions & 104 deletions

File tree

src/core/tq_turbo_kv.c

Lines changed: 196 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ extern void tq_codebook_quantize(const float* src, uint8_t* dst_indices,
2929
int n, int bits, float inv_std);
3030
extern void tq_codebook_dequantize(const uint8_t* indices, float* dst,
3131
int n, int bits, float inv_std);
32+
extern const float* tq_codebook_centroids(int bits);
3233

3334
/* ============================================================
3435
* FP16 helpers (local copies to avoid cross-module dependencies)
@@ -152,26 +153,84 @@ static void compute_qjl_signs(const float* residual, uint8_t* signs,
152153

153154
static void dequant_mse_rotated_3bit_v2(const block_tq_turbo_kv_3b* block,
154155
float* rotated, int dim) {
155-
/* Variant F (3b): 3-bit codebook (8 levels) + max-abs scaling, no QJL */
156+
/* Variant F (3b): 3-bit codebook (8 levels) + max-abs scaling.
157+
* Single-pass fused unpack + LUT lookup + scale (Round 1 pattern). */
156158
float inv_std = tkv_fp16_to_fp32(block->inv_std_fp16);
157159
if (inv_std < 1e-10f) inv_std = sqrtf((float)dim);
158-
uint8_t indices[TQ_BK] = {0};
159-
unpack_3bit(block->mse_indices, indices, dim);
160-
tq_codebook_dequantize(indices, rotated, dim, 3, inv_std);
160+
float scale = 1.0f / inv_std;
161+
const float* cb = tq_codebook_centroids(3);
162+
float lut[8];
163+
for (int i = 0; i < 8; i++) lut[i] = cb[i] * scale;
164+
/* 3-bit packing is bit-stream LSB-first, 8 elements per 3 bytes */
165+
const uint8_t* p = block->mse_indices;
166+
int i = 0;
167+
for (; i + 7 < dim; i += 8) {
168+
/* 3 bytes encode 8 indices */
169+
uint32_t w = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
170+
rotated[i + 0] = lut[(w >> 0) & 7];
171+
rotated[i + 1] = lut[(w >> 3) & 7];
172+
rotated[i + 2] = lut[(w >> 6) & 7];
173+
rotated[i + 3] = lut[(w >> 9) & 7];
174+
rotated[i + 4] = lut[(w >> 12) & 7];
175+
rotated[i + 5] = lut[(w >> 15) & 7];
176+
rotated[i + 6] = lut[(w >> 18) & 7];
177+
rotated[i + 7] = lut[(w >> 21) & 7];
178+
p += 3;
179+
}
180+
/* Tail */
181+
for (; i < dim; i++) {
182+
int bit_off = i * 3;
183+
int byte_idx = bit_off / 8;
184+
int bit_pos = bit_off % 8;
185+
uint16_t v = block->mse_indices[byte_idx];
186+
if (bit_pos > 5 && byte_idx + 1 < (dim * 3 + 7) / 8) {
187+
v |= (uint16_t)block->mse_indices[byte_idx + 1] << 8;
188+
}
189+
rotated[i] = lut[(v >> bit_pos) & 7];
190+
}
161191
}
162192

163193
static void dequant_mse_rotated_4bit_v2(const block_tq_turbo_kv_4b* block,
164194
float* rotated, int dim) {
165-
/* Variant F: 4-bit codebook (16 levels) + max-abs scaling */
195+
/* Variant F: 4-bit codebook (16 levels) + max-abs scaling.
196+
*
197+
* Single-pass fused unpack + codebook lookup + scale.
198+
* Pre-multiply the per-block scale into a local 16-entry table so
199+
* the inner loop is one byte load + two table lookups + two stores.
200+
*/
166201
float inv_std = tkv_fp16_to_fp32(block->inv_std_fp16);
167202
if (inv_std < 1e-10f) inv_std = sqrtf((float)dim);
168-
/* Unpack 4-bit indices: 2 per byte, LSB-first */
169-
uint8_t indices[TQ_BK];
170-
for (int i = 0; i < dim; i++) {
171-
uint8_t b = block->mse_indices[i / 2];
172-
indices[i] = (i & 1) ? (b >> 4) : (b & 0x0F);
203+
float scale = 1.0f / inv_std;
204+
205+
/* Pre-scaled local codebook (16 entries) */
206+
const float* cb = tq_codebook_centroids(4);
207+
float lut[16];
208+
for (int i = 0; i < 16; i++) lut[i] = cb[i] * scale;
209+
210+
const uint8_t* mi = block->mse_indices;
211+
/* Process 2 elements per byte, unrolled by 2 bytes per iteration */
212+
int i = 0;
213+
int byte_n = dim / 2;
214+
for (int b = 0; b + 1 < byte_n; b += 2) {
215+
uint8_t b0 = mi[b];
216+
uint8_t b1 = mi[b + 1];
217+
rotated[i + 0] = lut[b0 & 0x0F];
218+
rotated[i + 1] = lut[b0 >> 4];
219+
rotated[i + 2] = lut[b1 & 0x0F];
220+
rotated[i + 3] = lut[b1 >> 4];
221+
i += 4;
222+
}
223+
for (int b = i / 2; b < byte_n; b++) {
224+
uint8_t bv = mi[b];
225+
rotated[i + 0] = lut[bv & 0x0F];
226+
rotated[i + 1] = lut[bv >> 4];
227+
i += 2;
228+
}
229+
/* Tail (odd dim) */
230+
if (i < dim) {
231+
uint8_t bv = mi[i / 2];
232+
rotated[i] = lut[bv & 0x0F];
173233
}
174-
tq_codebook_dequantize(indices, rotated, dim, 4, inv_std);
175234
}
176235

177236
/* ============================================================
@@ -410,41 +469,46 @@ void tq_turbo_kv_4b_attention_ref(const float* query, const void* kv_cache,
410469
for (int i = dim; i < TQ_BK; i++) q_rot[i] = 0.0f;
411470
tq_rht_transform(q_rot, dim, TKV_DEFAULT_SEED);
412471

472+
/* Hoist codebook pointer (constant for all blocks) */
473+
const float* cb = tq_codebook_centroids(4);
474+
413475
for (int seq = 0; seq < seq_len; seq++) {
414476
const block_tq_turbo_kv_4b* block = &blocks_4b[seq];
415477
float norm = tkv_fp16_to_fp32(block->norm);
416-
417-
float rotated[TQ_BK];
418-
dequant_mse_rotated_4bit_v2(block, rotated, dim);
419-
420-
float mse_dot = 0.0f;
421-
#ifdef __ARM_NEON
422-
{
423-
float32x4_t acc0 = vdupq_n_f32(0.0f);
424-
float32x4_t acc1 = vdupq_n_f32(0.0f);
425-
float32x4_t acc2 = vdupq_n_f32(0.0f);
426-
float32x4_t acc3 = vdupq_n_f32(0.0f);
427-
int d = 0;
428-
for (; d + 15 < dim; d += 16) {
429-
acc0 = vfmaq_f32(acc0, vld1q_f32(&q_rot[d]), vld1q_f32(&rotated[d]));
430-
acc1 = vfmaq_f32(acc1, vld1q_f32(&q_rot[d + 4]), vld1q_f32(&rotated[d + 4]));
431-
acc2 = vfmaq_f32(acc2, vld1q_f32(&q_rot[d + 8]), vld1q_f32(&rotated[d + 8]));
432-
acc3 = vfmaq_f32(acc3, vld1q_f32(&q_rot[d + 12]), vld1q_f32(&rotated[d + 12]));
433-
}
434-
acc0 = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
435-
for (; d + 3 < dim; d += 4) {
436-
acc0 = vfmaq_f32(acc0, vld1q_f32(&q_rot[d]), vld1q_f32(&rotated[d]));
437-
}
438-
mse_dot = vaddvq_f32(acc0);
439-
for (; d < dim; d++) {
440-
mse_dot += q_rot[d] * rotated[d];
441-
}
478+
float inv_std = tkv_fp16_to_fp32(block->inv_std_fp16);
479+
if (inv_std < 1e-10f) inv_std = sqrtf((float)dim);
480+
float scale = 1.0f / inv_std;
481+
482+
/* Per-block pre-scaled LUT (16 floats, fits in 64 bytes — L1 hot) */
483+
float lut[16];
484+
for (int j = 0; j < 16; j++) lut[j] = cb[j] * scale;
485+
486+
/* Round 4: fused scalar dequant + dot product, 4 accumulators.
487+
* Eliminates the rotated[] intermediate buffer entirely.
488+
* Apple Silicon's 6 ALUs + L1-hot LUT make scalar gather fast. */
489+
const uint8_t* mi = block->mse_indices;
490+
float a0 = 0, a1 = 0, a2 = 0, a3 = 0;
491+
int d = 0;
492+
for (; d + 7 < dim; d += 8) {
493+
uint8_t b0 = mi[d / 2 + 0];
494+
uint8_t b1 = mi[d / 2 + 1];
495+
uint8_t b2 = mi[d / 2 + 2];
496+
uint8_t b3 = mi[d / 2 + 3];
497+
a0 += q_rot[d + 0] * lut[b0 & 0x0F];
498+
a1 += q_rot[d + 1] * lut[b0 >> 4];
499+
a2 += q_rot[d + 2] * lut[b1 & 0x0F];
500+
a3 += q_rot[d + 3] * lut[b1 >> 4];
501+
a0 += q_rot[d + 4] * lut[b2 & 0x0F];
502+
a1 += q_rot[d + 5] * lut[b2 >> 4];
503+
a2 += q_rot[d + 6] * lut[b3 & 0x0F];
504+
a3 += q_rot[d + 7] * lut[b3 >> 4];
442505
}
443-
#else
444-
for (int d = 0; d < dim; d++) {
445-
mse_dot += q_rot[d] * rotated[d];
506+
float mse_dot = (a0 + a1) + (a2 + a3);
507+
for (; d < dim; d++) {
508+
uint8_t bv = mi[d / 2];
509+
int idx = (d & 1) ? (bv >> 4) : (bv & 0x0F);
510+
mse_dot += q_rot[d] * lut[idx];
446511
}
447-
#endif
448512

449513
scores[seq] = norm * mse_dot;
450514
}
@@ -985,11 +1049,43 @@ void tq_turbo_kv_5b_quantize_ref(const float* src, void* dst, int n) {
9851049

9861050
static void dequant_mse_rotated_5bit(const block_tq_turbo_kv_5b* block,
9871051
float* rotated, int dim) {
1052+
/* Single-pass fused unpack + LUT lookup + scale (Round 1 pattern).
1053+
* 5-bit packing: 5 bytes encode 8 indices (40 bits). */
9881054
float inv_std = tkv_fp16_to_fp32(block->inv_std_fp16);
9891055
if (inv_std < 1e-10f) inv_std = sqrtf((float)dim);
990-
uint8_t indices[TQ_BK] = {0};
991-
unpack_5bit(block->mse_indices, indices, dim);
992-
tq_codebook_dequantize(indices, rotated, dim, 5, inv_std);
1056+
float scale = 1.0f / inv_std;
1057+
const float* cb = tq_codebook_centroids(5);
1058+
float lut[32];
1059+
for (int i = 0; i < 32; i++) lut[i] = cb[i] * scale;
1060+
const uint8_t* p = block->mse_indices;
1061+
int i = 0;
1062+
for (; i + 7 < dim; i += 8) {
1063+
uint64_t w = (uint64_t)p[0]
1064+
| ((uint64_t)p[1] << 8)
1065+
| ((uint64_t)p[2] << 16)
1066+
| ((uint64_t)p[3] << 24)
1067+
| ((uint64_t)p[4] << 32);
1068+
rotated[i + 0] = lut[(w >> 0) & 31];
1069+
rotated[i + 1] = lut[(w >> 5) & 31];
1070+
rotated[i + 2] = lut[(w >> 10) & 31];
1071+
rotated[i + 3] = lut[(w >> 15) & 31];
1072+
rotated[i + 4] = lut[(w >> 20) & 31];
1073+
rotated[i + 5] = lut[(w >> 25) & 31];
1074+
rotated[i + 6] = lut[(w >> 30) & 31];
1075+
rotated[i + 7] = lut[(w >> 35) & 31];
1076+
p += 5;
1077+
}
1078+
/* Tail (slow path for non-multiple-of-8 dims) */
1079+
for (; i < dim; i++) {
1080+
int bit_off = i * 5;
1081+
int byte_idx = bit_off / 8;
1082+
int bit_pos = bit_off % 8;
1083+
uint16_t v = block->mse_indices[byte_idx];
1084+
if (bit_pos > 3 && byte_idx + 1 < (dim * 5 + 7) / 8) {
1085+
v |= (uint16_t)block->mse_indices[byte_idx + 1] << 8;
1086+
}
1087+
rotated[i] = lut[(v >> bit_pos) & 31];
1088+
}
9931089
}
9941090

9951091
void tq_turbo_kv_5b_dequantize_ref(const void* src, float* dst, int n) {
@@ -1150,15 +1246,36 @@ void tq_turbo_kv_4bo_quantize_ref(const float* src, void* dst, int n) {
11501246

11511247
static void dequant_mse_rotated_4bo(const block_tq_turbo_kv_4bo* block,
11521248
float* rotated, int dim) {
1153-
/* 4-bit codebook lookup */
1249+
/* Single-pass fused unpack + LUT lookup + scale (Round 1 pattern) */
11541250
float inv_std = tkv_fp16_to_fp32(block->inv_std_fp16);
11551251
if (inv_std < 1e-10f) inv_std = sqrtf((float)dim);
1156-
uint8_t indices[TQ_BK];
1157-
for (int i = 0; i < dim; i++) {
1158-
uint8_t b = block->mse_indices[i / 2];
1159-
indices[i] = (i & 1) ? (b >> 4) : (b & 0x0F);
1252+
float scale = 1.0f / inv_std;
1253+
const float* cb = tq_codebook_centroids(4);
1254+
float lut[16];
1255+
for (int i = 0; i < 16; i++) lut[i] = cb[i] * scale;
1256+
1257+
const uint8_t* mi = block->mse_indices;
1258+
int byte_n = dim / 2;
1259+
int i = 0;
1260+
for (int b = 0; b + 1 < byte_n; b += 2) {
1261+
uint8_t b0 = mi[b];
1262+
uint8_t b1 = mi[b + 1];
1263+
rotated[i + 0] = lut[b0 & 0x0F];
1264+
rotated[i + 1] = lut[b0 >> 4];
1265+
rotated[i + 2] = lut[b1 & 0x0F];
1266+
rotated[i + 3] = lut[b1 >> 4];
1267+
i += 4;
1268+
}
1269+
for (int b = i / 2; b < byte_n; b++) {
1270+
uint8_t bv = mi[b];
1271+
rotated[i + 0] = lut[bv & 0x0F];
1272+
rotated[i + 1] = lut[bv >> 4];
1273+
i += 2;
1274+
}
1275+
if (i < dim) {
1276+
uint8_t bv = mi[i / 2];
1277+
rotated[i] = lut[bv & 0x0F];
11601278
}
1161-
tq_codebook_dequantize(indices, rotated, dim, 4, inv_std);
11621279

11631280
/* Overwrite outlier positions with stored exact FP16 values */
11641281
int K = TQ_KV_4BO_OUTLIERS;
@@ -1305,11 +1422,37 @@ void tq_turbo_kv_3bo_quantize_ref(const float* src, void* dst, int n) {
13051422

13061423
static void dequant_mse_rotated_3bo(const block_tq_turbo_kv_3bo* block,
13071424
float* rotated, int dim) {
1425+
/* Single-pass fused unpack + LUT lookup + scale (Round 1 pattern) */
13081426
float inv_std = tkv_fp16_to_fp32(block->inv_std_fp16);
13091427
if (inv_std < 1e-10f) inv_std = sqrtf((float)dim);
1310-
uint8_t indices[TQ_BK] = {0};
1311-
unpack_3bit(block->mse_indices, indices, dim);
1312-
tq_codebook_dequantize(indices, rotated, dim, 3, inv_std);
1428+
float scale = 1.0f / inv_std;
1429+
const float* cb = tq_codebook_centroids(3);
1430+
float lut[8];
1431+
for (int i = 0; i < 8; i++) lut[i] = cb[i] * scale;
1432+
const uint8_t* p = block->mse_indices;
1433+
int i = 0;
1434+
for (; i + 7 < dim; i += 8) {
1435+
uint32_t w = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
1436+
rotated[i + 0] = lut[(w >> 0) & 7];
1437+
rotated[i + 1] = lut[(w >> 3) & 7];
1438+
rotated[i + 2] = lut[(w >> 6) & 7];
1439+
rotated[i + 3] = lut[(w >> 9) & 7];
1440+
rotated[i + 4] = lut[(w >> 12) & 7];
1441+
rotated[i + 5] = lut[(w >> 15) & 7];
1442+
rotated[i + 6] = lut[(w >> 18) & 7];
1443+
rotated[i + 7] = lut[(w >> 21) & 7];
1444+
p += 3;
1445+
}
1446+
for (; i < dim; i++) {
1447+
int bit_off = i * 3;
1448+
int byte_idx = bit_off / 8;
1449+
int bit_pos = bit_off % 8;
1450+
uint16_t v = block->mse_indices[byte_idx];
1451+
if (bit_pos > 5 && byte_idx + 1 < (dim * 3 + 7) / 8) {
1452+
v |= (uint16_t)block->mse_indices[byte_idx + 1] << 8;
1453+
}
1454+
rotated[i] = lut[(v >> bit_pos) & 7];
1455+
}
13131456

13141457
int K = TQ_KV_4BO_OUTLIERS;
13151458
for (int k = 0; k < K; k++) {

0 commit comments

Comments
 (0)