Skip to content

Commit b7fe468

Browse files
unamedkr authored and claude committed
Multi-hash 1-bit prototype: rejected — uniform wins at every bit budget
Tested Permute+RHT multi-hash sign dequantization (dim=64, 128):

- K=2 (2.75 bpe): cosine 0.882 — uniform_2b (2.25 bpe, 0.931) wins
- K=3 (4.12 bpe): cosine 0.916 — uniform_4b (4.25 bpe, 0.995) wins
- K=8 (11 bpe): cosine 0.965 — still can't match uniform_4b

Root cause: sign quantization is an unbiased inner-product ESTIMATOR but an
inefficient vector RECONSTRUCTOR. Averaging K sign reconstructions improves
only as 1/K, while uniform quantization starts much higher.

Conclusion: 1-bit sign-based KV cache is fundamentally incompatible with the
dequant+FP32 attention path. The 4-bit sweet spot (PPL +0.0%, 3.8x compression)
is the practical optimum.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 18eeed1 commit b7fe468

8 files changed

Lines changed: 782 additions & 53 deletions

File tree

bench/test_multihash_dequant.c

Lines changed: 626 additions & 0 deletions
Large diffs are not rendered by default.

include/turboquant/tq_engine.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,13 @@ typedef struct {
5050
float rope_local_base_freq; /* RoPE base freq for local/sliding layers (10000.0 for gemma3) */
5151
int n_norms_per_block; /* 2 for qwen35, 4 for gemma3 */
5252
float query_pre_attn_scalar; /* attention scaling: 1/sqrt(this) instead of 1/sqrt(head_dim), 0=use head_dim */
53+
54+
/* Gemma 4 hybrid attention: full layers have different head_dim/kv_heads than sliding.
55+
* head_dim/n_heads/n_kv_heads store sliding layer values (majority).
56+
* These store full layer values (0 = no hybrid, use sliding values). */
57+
int full_head_dim; /* head_dim for full attention layers (e.g., 512 vs sliding 256) */
58+
int full_n_heads; /* n_heads for full layers (e.g., 8 vs sliding 16) */
59+
int full_n_kv_heads; /* n_kv_heads for full layers (e.g., 2 vs sliding 8) */
5360
} tq_model_config_t;
5461

5562
/* ============================================================

src/engine/tq_model.c

Lines changed: 58 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1915,6 +1915,8 @@ static size_t calc_q4_buffer_size(const tq_model_t* model) {
19151915
int dim = c->hidden_dim;
19161916
int q_dim = c->n_heads * c->head_dim;
19171917
int kv_dim = c->n_kv_heads * c->head_dim;
1918+
int full_kv_dim = (c->full_n_kv_heads > 0 && c->full_head_dim > 0)
1919+
? c->full_n_kv_heads * c->full_head_dim : kv_dim;
19181920
int inter = c->intermediate_dim;
19191921
int qg_dim = c->attn_output_gate ? q_dim * 2 : q_dim;
19201922

@@ -1926,6 +1928,7 @@ static size_t calc_q4_buffer_size(const tq_model_t* model) {
19261928

19271929
for (int l = 0; l < c->n_layers; l++) {
19281930
const tq_layer_weights_t* layer = &model->layers[l];
1931+
int lkv = (model->layer_is_sliding && !model->layer_is_sliding[l]) ? full_kv_dim : kv_dim;
19291932

19301933
/* Self-attention weights */
19311934
if (layer->wq) {
@@ -1935,13 +1938,13 @@ static size_t calc_q4_buffer_size(const tq_model_t* model) {
19351938
}
19361939
if (layer->wk) {
19371940
int nb = (dim + 31) / 32;
1938-
total += (size_t)kv_dim * nb * 16;
1939-
total += (size_t)kv_dim * nb * 4;
1941+
total += (size_t)lkv * nb * 16;
1942+
total += (size_t)lkv * nb * 4;
19401943
}
19411944
if (layer->wv) {
19421945
int nb = (dim + 31) / 32;
1943-
total += (size_t)kv_dim * nb * 16;
1944-
total += (size_t)kv_dim * nb * 4;
1946+
total += (size_t)lkv * nb * 16;
1947+
total += (size_t)lkv * nb * 4;
19451948
}
19461949
if (layer->wo) {
19471950
int nb = (q_dim + 31) / 32;
@@ -2003,6 +2006,8 @@ void tq_quantize_weights_q4(tq_model_t* model) {
20032006
int dim = c->hidden_dim;
20042007
int q_dim = c->n_heads * c->head_dim;
20052008
int kv_dim = c->n_kv_heads * c->head_dim;
2009+
int full_kv_dim = (c->full_n_kv_heads > 0 && c->full_head_dim > 0)
2010+
? c->full_n_kv_heads * c->full_head_dim : kv_dim;
20062011
int inter = c->intermediate_dim;
20072012
int qg_dim = c->attn_output_gate ? q_dim * 2 : q_dim;
20082013

@@ -2023,17 +2028,18 @@ void tq_quantize_weights_q4(tq_model_t* model) {
20232028

20242029
for (int l = 0; l < c->n_layers; l++) {
20252030
tq_layer_weights_t* layer = &model->layers[l];
2031+
int lkv = (model->layer_is_sliding && !model->layer_is_sliding[l]) ? full_kv_dim : kv_dim;
20262032

20272033
/* Self-attention */
20282034
quantize_matrix_q4(layer->wq, qg_dim, dim,
20292035
&layer->wq_q4, &layer->wq_q4s, &buf, &used);
20302036
if (layer->wq_q4) layer->wq = NULL;
20312037

2032-
quantize_matrix_q4(layer->wk, kv_dim, dim,
2038+
quantize_matrix_q4(layer->wk, lkv, dim,
20332039
&layer->wk_q4, &layer->wk_q4s, &buf, &used);
20342040
if (layer->wk_q4) layer->wk = NULL;
20352041

2036-
quantize_matrix_q4(layer->wv, kv_dim, dim,
2042+
quantize_matrix_q4(layer->wv, lkv, dim,
20372043
&layer->wv_q4, &layer->wv_q4s, &buf, &used);
20382044
if (layer->wv_q4) layer->wv = NULL;
20392045

@@ -3246,34 +3252,59 @@ tq_model_t* tq_load_gguf(const char* path) {
32463252
}
32473253

32483254
/* Set up layer_is_sliding for Gemma hybrid attention.
3249-
* Detect from Q tensor shape: sliding layers have smaller Q output dim. */
3255+
* Detect from K tensor shape: sliding layers have LARGER kv_dim (more kv_heads),
3256+
* full layers have SMALLER kv_dim (fewer kv_heads, larger head_dim).
3257+
* Q tensor shapes can be identical for both types, so K is the reliable signal. */
32503258
if (c->sliding_window > 0 && c->model_type == 1) {
32513259
model->layer_is_sliding = (int*)calloc((size_t)c->n_layers, sizeof(int));
32523260
if (model->layer_is_sliding) {
3253-
/* Find the smallest Q output dim (sliding) */
3254-
int min_q = 999999;
3261+
/* Find the largest K output dim (sliding layers have more kv_heads) */
3262+
int max_k = 0;
32553263
for (int l = 0; l < c->n_layers; l++) {
32563264
char tname[128];
3257-
snprintf(tname, sizeof(tname), "blk.%d.attn_q.weight", l);
3258-
const tq_gguf_tensor_t* qt = tq_gguf_find_tensor(gguf, tname);
3259-
if (qt && (int)qt->shape[1] < min_q) min_q = (int)qt->shape[1];
3265+
snprintf(tname, sizeof(tname), "blk.%d.attn_k.weight", l);
3266+
const tq_gguf_tensor_t* kt = tq_gguf_find_tensor(gguf, tname);
3267+
if (kt && (int)kt->shape[1] > max_k) max_k = (int)kt->shape[1];
32603268
}
32613269
int n_sliding = 0, n_full = 0;
3270+
int full_kv_dim = 0;
32623271
for (int l = 0; l < c->n_layers; l++) {
32633272
char tname[128];
3264-
snprintf(tname, sizeof(tname), "blk.%d.attn_q.weight", l);
3265-
const tq_gguf_tensor_t* qt = tq_gguf_find_tensor(gguf, tname);
3266-
if (qt && (int)qt->shape[1] == min_q) {
3273+
snprintf(tname, sizeof(tname), "blk.%d.attn_k.weight", l);
3274+
const tq_gguf_tensor_t* kt = tq_gguf_find_tensor(gguf, tname);
3275+
int k_out = kt ? (int)kt->shape[1] : max_k;
3276+
if (k_out == max_k) {
32673277
model->layer_is_sliding[l] = 1;
32683278
n_sliding++;
32693279
} else {
32703280
model->layer_is_sliding[l] = 0;
32713281
n_full++;
3282+
full_kv_dim = k_out;
32723283
}
32733284
}
3274-
if (n_full > 0) {
3275-
fprintf(stderr, "tq_load_gguf: Gemma hybrid — %d sliding + %d full attention layers\n",
3276-
n_sliding, n_full);
3285+
if (n_full > 0 && full_kv_dim > 0) {
3286+
/* Compute full layer dimensions from kv_dim and Q/K shapes */
3287+
c->full_head_dim = c->head_dim * 2; /* Gemma 4: 256 → 512 */
3288+
/* Verify by checking if full_kv_dim / full_head_dim is integer */
3289+
if (full_kv_dim % c->full_head_dim == 0) {
3290+
c->full_n_kv_heads = full_kv_dim / c->full_head_dim;
3291+
} else {
3292+
/* Try the metadata key_length as full head_dim */
3293+
int meta_hd = tq_gguf_get_i32(gguf, GGUF_KEY("attention.key_length"), 0);
3294+
if (meta_hd > 0 && full_kv_dim % meta_hd == 0) {
3295+
c->full_head_dim = meta_hd;
3296+
c->full_n_kv_heads = full_kv_dim / meta_hd;
3297+
} else {
3298+
c->full_head_dim = c->head_dim;
3299+
c->full_n_kv_heads = c->n_kv_heads;
3300+
}
3301+
}
3302+
/* Q dim is n_heads * head_dim (NOT hidden_dim). It's constant across layers. */
3303+
c->full_n_heads = (c->n_heads * c->head_dim) / c->full_head_dim;
3304+
fprintf(stderr, "tq_load_gguf: Gemma hybrid — %d sliding (hd=%d, kv=%d) + "
3305+
"%d full (hd=%d, kv=%d, heads=%d) attention layers\n",
3306+
n_sliding, c->head_dim, c->n_kv_heads,
3307+
n_full, c->full_head_dim, c->full_n_kv_heads, c->full_n_heads);
32773308
}
32783309
}
32793310
}
@@ -3335,6 +3366,8 @@ tq_model_t* tq_load_gguf(const char* path) {
33353366
int dim = c->hidden_dim;
33363367
int q_dim = c->n_heads * c->head_dim;
33373368
int kv_dim = c->n_kv_heads * c->head_dim;
3369+
int full_kv_dim = (c->full_n_kv_heads > 0 && c->full_head_dim > 0)
3370+
? c->full_n_kv_heads * c->full_head_dim : kv_dim;
33383371
int inter = c->intermediate_dim;
33393372
int qg_dim = c->attn_output_gate ? q_dim * 2 : q_dim;
33403373
int delta_nkv = c->delta_n_kv_heads > 0 ? c->delta_n_kv_heads : c->delta_n_heads;
@@ -3346,12 +3379,13 @@ tq_model_t* tq_load_gguf(const char* path) {
33463379
size_t est_fp32 = 0;
33473380
for (int l = 0; l < c->n_layers; l++) {
33483381
const tq_layer_weights_t* layer = &model->layers[l];
3382+
int lkv = (model->layer_is_sliding && !model->layer_is_sliding[l]) ? full_kv_dim : kv_dim;
33493383
if (layer->gguf_wq)
33503384
est_fp32 += (size_t)qg_dim * dim * sizeof(float);
33513385
if (layer->gguf_wk)
3352-
est_fp32 += (size_t)kv_dim * dim * sizeof(float);
3386+
est_fp32 += (size_t)lkv * dim * sizeof(float);
33533387
if (layer->gguf_wv)
3354-
est_fp32 += (size_t)kv_dim * dim * sizeof(float);
3388+
est_fp32 += (size_t)lkv * dim * sizeof(float);
33553389
if (layer->gguf_wo)
33563390
est_fp32 += (size_t)dim * q_dim * sizeof(float);
33573391
/* Dense FFN weights (not present in MoE layers) */
@@ -3409,7 +3443,8 @@ tq_model_t* tq_load_gguf(const char* path) {
34093443
}
34103444
}
34113445
if (layer->gguf_wk) {
3412-
int n = kv_dim * dim;
3446+
int lkv = (model->layer_is_sliding && !model->layer_is_sliding[l]) ? full_kv_dim : kv_dim;
3447+
int n = lkv * dim;
34133448
float* fp = (float*)malloc((size_t)n * sizeof(float));
34143449
if (fp) {
34153450
tq_dequant_row_gguf(layer->gguf_wk_type, layer->gguf_wk, fp, n);
@@ -3419,7 +3454,8 @@ tq_model_t* tq_load_gguf(const char* path) {
34193454
}
34203455
}
34213456
if (layer->gguf_wv) {
3422-
int n = kv_dim * dim;
3457+
int lkv = (model->layer_is_sliding && !model->layer_is_sliding[l]) ? full_kv_dim : kv_dim;
3458+
int n = lkv * dim;
34233459
float* fp = (float*)malloc((size_t)n * sizeof(float));
34243460
if (fp) {
34253461
tq_dequant_row_gguf(layer->gguf_wv_type, layer->gguf_wv, fp, n);

0 commit comments

Comments
 (0)