Skip to content

Commit fe6df17

Browse files
unamedkr authored and claude committed
turbo_kv: Karpathy loop round 1 — empirical per-block std (Variant A)
Repurposed the always-constant rht_seed field as inv_std_fp16 (block size unchanged, alignment preserved via _pad). Quantize now computes the empirical std of the rotated values per block and stores its inverse for later dequant lookup, instead of the theoretical sqrt(dim).

Llama 3.2 3B PPL on bench/data/ppl_1k.txt:
  turbo_kv_4b: 16.03 → 15.87 (Δ -0.16)
  turbo_kv_3b: 25.84 → 25.07 (Δ -0.77)

Marginal improvement — confirms the variance mismatch was real but small. The dominant bottleneck remains outlier clipping in the Lloyd-Max codebook. Round 2 will try max-abs based scaling.

Target: turbo_kv_4b ≤ 14.5 (matching uniform_4b at the same bit budget).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2520758 commit fe6df17

2 files changed

Lines changed: 44 additions & 29 deletions

File tree

include/turboquant/tq_types.h

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -207,26 +207,28 @@ typedef struct {
207207
/* TurboQuant KV cache block: RHT + Lloyd-Max codebook + QJL residual
208208
* 3-bit variant: 2-bit codebook (4 levels) + 1-bit QJL sign hash
209209
* Block covers TQ_BK elements (128).
210-
* Layout: norm(2) + residual_norm(2) + rht_seed(4) + mse_2bit(32) + qjl_signs(16) = 56 bytes
210+
* Layout: norm(2) + residual_norm(2) + inv_std(2) + _pad(2) + mse_2bit(32) + qjl_signs(16) = 56 bytes
211211
*/
212212
typedef struct {
213-
uint16_t norm; /* L2 norm of original vector (fp16) */
214-
uint16_t residual_norm; /* L2 norm of residual after MSE (fp16) */
215-
uint32_t rht_seed; /* RHT random seed for this block */
216-
uint8_t mse_indices[TQ_BK / 4]; /* 2-bit packed codebook indices (32B) */
217-
uint8_t qjl_signs[TQ_BK / 8]; /* 1-bit QJL sign hash on residual (16B) */
213+
uint16_t norm; /* L2 norm of original vector (fp16) */
214+
uint16_t residual_norm; /* L2 norm of residual after MSE (fp16) */
215+
uint16_t inv_std_fp16; /* per-block 1/std for codebook lookup */
216+
uint16_t _pad; /* alignment padding (was rht_seed upper) */
217+
uint8_t mse_indices[TQ_BK / 4]; /* 2-bit packed codebook indices (32B) */
218+
uint8_t qjl_signs[TQ_BK / 8]; /* 1-bit QJL sign hash on residual (16B) */
218219
} block_tq_turbo_kv_3b;
219220

220221
/* TurboQuant KV cache block: 4-bit variant
221222
* 3-bit codebook (8 levels) + 1-bit QJL sign hash
222-
* Layout: norm(2) + residual_norm(2) + rht_seed(4) + mse_3bit(48) + qjl_signs(16) = 72 bytes
223+
* Layout: norm(2) + residual_norm(2) + inv_std(2) + _pad(2) + mse_3bit(48) + qjl_signs(16) = 72 bytes
223224
*/
224225
typedef struct {
225-
uint16_t norm; /* L2 norm of original vector (fp16) */
226-
uint16_t residual_norm; /* L2 norm of residual after MSE (fp16) */
227-
uint32_t rht_seed; /* RHT random seed for this block */
228-
uint8_t mse_indices[TQ_BK * 3 / 8]; /* 3-bit packed codebook indices (48B) */
229-
uint8_t qjl_signs[TQ_BK / 8]; /* 1-bit QJL sign hash on residual (16B) */
226+
uint16_t norm; /* L2 norm of original vector (fp16) */
227+
uint16_t residual_norm; /* L2 norm of residual after MSE (fp16) */
228+
uint16_t inv_std_fp16; /* per-block 1/std for codebook lookup */
229+
uint16_t _pad; /* alignment padding */
230+
uint8_t mse_indices[TQ_BK * 3 / 8]; /* 3-bit packed codebook indices (48B) */
231+
uint8_t qjl_signs[TQ_BK / 8]; /* 1-bit QJL sign hash on residual (16B) */
230232
} block_tq_turbo_kv_4b;
231233

232234
/* TurboQuant KV cache block: 1-bit Hamming attention

src/core/tq_turbo_kv.c

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,17 @@ static void compute_qjl_signs(const float* residual, uint8_t* signs,
152152

153153
static void dequant_mse_rotated_2bit(const block_tq_turbo_kv_3b* block,
154154
float* rotated, int dim) {
155-
float inv_std = sqrtf((float)dim);
155+
float inv_std = tkv_fp16_to_fp32(block->inv_std_fp16);
156+
if (inv_std < 1e-10f) inv_std = sqrtf((float)dim); /* fallback */
156157
uint8_t indices[TQ_BK] = {0};
157158
unpack_2bit(block->mse_indices, indices, dim);
158159
tq_codebook_dequantize(indices, rotated, dim, 2, inv_std);
159160
}
160161

161162
static void dequant_mse_rotated_3bit(const block_tq_turbo_kv_4b* block,
162163
float* rotated, int dim) {
163-
float inv_std = sqrtf((float)dim);
164+
float inv_std = tkv_fp16_to_fp32(block->inv_std_fp16);
165+
if (inv_std < 1e-10f) inv_std = sqrtf((float)dim); /* fallback */
164166
uint8_t indices[TQ_BK] = {0};
165167
unpack_3bit(block->mse_indices, indices, dim);
166168
tq_codebook_dequantize(indices, rotated, dim, 3, inv_std);
@@ -195,14 +197,20 @@ void tq_turbo_kv_3b_quantize_ref(const float* src, void* dst, int n) {
195197
}
196198

197199
/* Step 3: Apply RHT (in-place on rotated) */
198-
uint32_t seed = TKV_DEFAULT_SEED;
199-
block->rht_seed = seed;
200-
tq_rht_transform(rotated, dim, seed);
201-
202-
/* Step 4: Scalar quantize with 2-bit codebook
203-
* After RHT, coordinates are approximately N(0, 1/sqrt(dim)).
204-
* inv_std = sqrt(dim) to normalize to N(0,1). */
205-
float inv_std = sqrtf((float)dim);
200+
tq_rht_transform(rotated, dim, TKV_DEFAULT_SEED);
201+
202+
/* Step 4: Compute per-block empirical std and quantize with 2-bit codebook.
203+
* Theoretical analysis says rotated coords ~ N(0, 1/dim), but real key
204+
* vectors after a single Hadamard rotation often have heavier tails or
205+
* different variance per block. Using the empirical std adapts the
206+
* codebook to the actual block distribution. */
207+
float var_emp = 0.0f;
208+
for (int i = 0; i < dim; i++) var_emp += rotated[i] * rotated[i];
209+
var_emp /= (float)dim;
210+
float std_emp = sqrtf(var_emp);
211+
if (std_emp < 1e-10f) std_emp = 1.0f / sqrtf((float)dim);
212+
float inv_std = 1.0f / std_emp;
213+
block->inv_std_fp16 = tkv_fp32_to_fp16(inv_std);
206214
uint8_t indices[TQ_BK];
207215
tq_codebook_quantize(rotated, indices, dim, 2, inv_std);
208216

@@ -248,7 +256,7 @@ void tq_turbo_kv_3b_dequantize_ref(const void* src, float* dst, int n) {
248256
if (dim > TQ_BK) dim = TQ_BK;
249257

250258
float norm = tkv_fp16_to_fp32(block->norm);
251-
uint32_t seed = block->rht_seed;
259+
uint32_t seed = TKV_DEFAULT_SEED;
252260

253261
/* MSE-only dequantize in rotated space */
254262
float rotated[TQ_BK];
@@ -432,11 +440,16 @@ void tq_turbo_kv_4b_quantize_ref(const float* src, void* dst, int n) {
432440
rotated[i] = 0.0f;
433441
}
434442

435-
uint32_t seed = TKV_DEFAULT_SEED;
436-
block->rht_seed = seed;
437-
tq_rht_transform(rotated, dim, seed);
443+
tq_rht_transform(rotated, dim, TKV_DEFAULT_SEED);
438444

439-
float inv_std = sqrtf((float)dim);
445+
/* Per-block empirical std (see 3-bit variant for rationale) */
446+
float var_emp = 0.0f;
447+
for (int i = 0; i < dim; i++) var_emp += rotated[i] * rotated[i];
448+
var_emp /= (float)dim;
449+
float std_emp = sqrtf(var_emp);
450+
if (std_emp < 1e-10f) std_emp = 1.0f / sqrtf((float)dim);
451+
float inv_std = 1.0f / std_emp;
452+
block->inv_std_fp16 = tkv_fp32_to_fp16(inv_std);
440453
uint8_t indices[TQ_BK];
441454
tq_codebook_quantize(rotated, indices, dim, 3, inv_std);
442455
pack_3bit(indices, block->mse_indices, dim);
@@ -471,7 +484,7 @@ void tq_turbo_kv_4b_dequantize_ref(const void* src, float* dst, int n) {
471484
if (dim > TQ_BK) dim = TQ_BK;
472485

473486
float norm = tkv_fp16_to_fp32(block->norm);
474-
uint32_t seed = block->rht_seed;
487+
uint32_t seed = TKV_DEFAULT_SEED;
475488

476489
float rotated[TQ_BK];
477490
dequant_mse_rotated_3bit(block, rotated, dim);
@@ -694,7 +707,7 @@ void tq_turbo_kv_1b_dequantize_ref(const void* src, float* dst, int n) {
694707
if (sketch_dim < TQ_BK) sketch_dim = TQ_BK;
695708

696709
float norm = tkv_fp16_to_fp32(block->norm);
697-
uint32_t seed = block->rht_seed;
710+
uint32_t seed = TKV_DEFAULT_SEED;
698711

699712
/* Reconstruct sign vector in rotated space.
700713
* After RHT, coordinates are ~N(0, 1/sqrt(dim)).

0 commit comments

Comments
 (0)