
Commit d5a5bd4

unamedkr and claude committed
uniform_3b with sub-block scales: PPL +60% — 4-bit remains the sweet spot
Implemented uniform_3b with 4 independent sub-blocks of 32 elements, each with
its own FP16 scale/min (vs a single scale per 128 elements before).

Results (SmolLM2 1.7B, 815 tokens):
- uniform_4b: PPL 9.51 (+14%, 4.25 bpe)
- uniform_3b: PPL 13.28 (+60%, 4.0 bpe): sub-block scales help but are insufficient

Sub-block scales improve on the broken single-scale 3-bit (+88%), but 8
quantization levels fundamentally can't match 16 levels for attention-critical
key vectors.

Honest conclusion: 4-bit K + Q4 V (3.8x compression, <1% PPL increase) is the
practical optimum with current quantization approaches.

33/33 tests pass, 0 warnings.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 900d6bc commit d5a5bd4

7 files changed

Lines changed: 309 additions & 22 deletions
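
The headline figures follow from the block layouts in this commit: uniform_4b packs 4 bytes of scale/min metadata plus 64 bytes of 4-bit data per 128-element block (68 bytes, 4.25 bpe), while the new uniform_3b block is 64 bytes flat (4.0 bpe). A minimal sketch of the arithmetic, assuming an FP16 (16 bpe) cache as the compression baseline:

#include <stdio.h>

int main(void) {
    const int n = 128;                           /* elements per block (TQ_BK) */
    const float bpe_4b = (4 + n / 2) * 8.0f / n; /* 68 bytes -> 4.25 bpe */
    const float bpe_3b = 64 * 8.0f / n;          /* 64 bytes -> 4.00 bpe */
    printf("uniform_4b: %.2f bpe (%.1fx vs FP16)\n", bpe_4b, 16.0f / bpe_4b);
    printf("uniform_3b: %.2f bpe (%.1fx vs FP16)\n", bpe_3b, 16.0f / bpe_3b);
    return 0;  /* prints 4.25 (3.8x) and 4.00 (4.0x) */
}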


include/turboquant/tq_types.h

Lines changed: 19 additions & 1 deletion
@@ -53,7 +53,8 @@ typedef enum {
     TQ_TYPE_TURBO_KV_4B = 9, /* TurboQuant KV: 3-bit codebook + 1-bit QJL residual */
     TQ_TYPE_TURBO_KV_1B = 10,/* TurboQuant KV: 1-bit Hamming (sign only) */
     TQ_TYPE_TURBO_KV_2B = 11,/* TurboQuant KV: 2-bit (1-bit codebook + 1-bit QJL) */
-    TQ_TYPE_COUNT = 12
+    TQ_TYPE_UNIFORM_3B = 12, /* Min-Max uniform 3-bit with sub-block scales */
+    TQ_TYPE_COUNT = 13
 } tq_type;
 
 /* ============================================================
@@ -112,6 +113,22 @@ typedef struct {
 
 /* size verified after extern "C" block */
 
+/* Uniform 3-bit with sub-block scales (Q3_K-style)
+ * 4 sub-blocks of 32 elements, each with independent FP16 scale/min.
+ * 8 quantization levels (3-bit) per value, but adapted to local statistics.
+ * 4.0 bits per element: (16 bytes meta + 48 bytes data) / 128 elements.
+ */
+#define TQ_3B_NSUB 4                    /* sub-blocks per block */
+#define TQ_3B_SUBK (TQ_BK / TQ_3B_NSUB) /* 32 elements per sub-block */
+
+typedef struct {
+    uint16_t sub_scale[TQ_3B_NSUB]; /* per-sub-block scale (fp16, 8B) */
+    uint16_t sub_min[TQ_3B_NSUB];   /* per-sub-block minimum (fp16, 8B) */
+    uint8_t  qs[TQ_BK * 3 / 8];     /* 3-bit packed data (48B) */
+} block_tq_uniform_3b;              /* 64 bytes per 128 elements */
+
+/* size verified after extern "C" block */
+
 /* Mixed precision: 4-bit base with fp16 outlier channels
  * Top-k channels by absolute value are stored at fp16 precision.
  * Remaining channels use 4-bit uniform quantization with a tighter
@@ -241,6 +258,7 @@ TQ_CHECK_SIZE(block_tq_polar, 8 + TQ_BK / 2);
 TQ_CHECK_SIZE(block_tq_qjl, 4 + TQ_SKETCH_DIM / 8 + TQ_OUTLIERS);
 TQ_CHECK_SIZE(block_tq_uniform_4b, 4 + TQ_BK / 2);
 TQ_CHECK_SIZE(block_tq_uniform_2b, 4 + TQ_BK / 4);
+TQ_CHECK_SIZE(block_tq_uniform_3b, 4 * TQ_3B_NSUB + TQ_BK * 3 / 8);
 TQ_CHECK_SIZE(block_tq_mixed_4b8, 4 + TQ_MIXED_OUTLIERS + TQ_MIXED_OUTLIERS * 2 + TQ_BK / 2);
 TQ_CHECK_SIZE(block_tq_turbo_kv_3b, 8 + TQ_BK / 4 + TQ_BK / 8);
 TQ_CHECK_SIZE(block_tq_turbo_kv_4b, 8 + TQ_BK * 3 / 8 + TQ_BK / 8);
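
As a standalone illustration (not part of the commit), the new block packs to exactly 64 bytes: uint16_t has alignment 2 and every member size is even, so no padding is inserted. A minimal mirror of the struct with TQ_BK fixed at 128, using the same size expression the commit adds via TQ_CHECK_SIZE:

#include <stdint.h>

#define TQ_BK 128
#define TQ_3B_NSUB 4

typedef struct {
    uint16_t sub_scale[TQ_3B_NSUB]; /*  8 bytes */
    uint16_t sub_min[TQ_3B_NSUB];   /*  8 bytes */
    uint8_t  qs[TQ_BK * 3 / 8];     /* 48 bytes */
} block_mirror;

/* 16 bytes of metadata + 48 bytes of packed data */
_Static_assert(sizeof(block_mirror) == 4 * TQ_3B_NSUB + TQ_BK * 3 / 8,
               "64 bytes per 128 elements = 4.0 bpe");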

integrations/llamacpp/tq_kv_cache.cpp

Lines changed: 14 additions & 1 deletion
@@ -44,7 +44,8 @@ enum {
     GGML_TYPE_TQ_TURBO_KV_4B = GGML_TYPE_TQ_BASE + 9,
     GGML_TYPE_TQ_TURBO_KV_1B = GGML_TYPE_TQ_BASE + 10,
     GGML_TYPE_TQ_TURBO_KV_2B = GGML_TYPE_TQ_BASE + 11,
-    GGML_TYPE_TQ_COUNT = 12,
+    GGML_TYPE_TQ_UNIFORM_3B = GGML_TYPE_TQ_BASE + 12,
+    GGML_TYPE_TQ_COUNT = 13,
 };
 
 /* ============================================================
@@ -65,6 +66,7 @@ static int tq_to_ggml_type(tq_type type) {
     case TQ_TYPE_TURBO_KV_4B: return GGML_TYPE_TQ_TURBO_KV_4B;
     case TQ_TYPE_TURBO_KV_1B: return GGML_TYPE_TQ_TURBO_KV_1B;
     case TQ_TYPE_TURBO_KV_2B: return GGML_TYPE_TQ_TURBO_KV_2B;
+    case TQ_TYPE_UNIFORM_3B: return GGML_TYPE_TQ_UNIFORM_3B;
     default: return -1;
     }
 }
@@ -83,6 +85,7 @@ static tq_type ggml_to_tq_type(int ggml_id) {
     case GGML_TYPE_TQ_TURBO_KV_4B: return TQ_TYPE_TURBO_KV_4B;
     case GGML_TYPE_TQ_TURBO_KV_1B: return TQ_TYPE_TURBO_KV_1B;
     case GGML_TYPE_TQ_TURBO_KV_2B: return TQ_TYPE_TURBO_KV_2B;
+    case GGML_TYPE_TQ_UNIFORM_3B: return TQ_TYPE_UNIFORM_3B;
     default: return TQ_TYPE_COUNT;
     }
 }
@@ -147,6 +150,7 @@ TQ_GGML_WRAPPERS(turbo_kv_3b, TQ_TYPE_TURBO_KV_3B)
 TQ_GGML_WRAPPERS(turbo_kv_4b, TQ_TYPE_TURBO_KV_4B)
 TQ_GGML_WRAPPERS(turbo_kv_1b, TQ_TYPE_TURBO_KV_1B)
 TQ_GGML_WRAPPERS(turbo_kv_2b, TQ_TYPE_TURBO_KV_2B)
+TQ_GGML_WRAPPERS(uniform_3b, TQ_TYPE_UNIFORM_3B)
 
 /* ============================================================
  * vec_dot wrappers (quantized key . FP32 query -> scalar)
@@ -199,6 +203,7 @@ TQ_GGML_VEC_DOT(turbo_kv_3b, TQ_TYPE_TURBO_KV_3B)
 TQ_GGML_VEC_DOT(turbo_kv_4b, TQ_TYPE_TURBO_KV_4B)
 TQ_GGML_VEC_DOT(turbo_kv_1b, TQ_TYPE_TURBO_KV_1B)
 TQ_GGML_VEC_DOT(turbo_kv_2b, TQ_TYPE_TURBO_KV_2B)
+TQ_GGML_VEC_DOT(uniform_3b, TQ_TYPE_UNIFORM_3B)
 
 /* ============================================================
  * GGML type trait table
@@ -314,6 +319,14 @@ static const tq_ggml_type_trait TQ_GGML_TRAITS[GGML_TYPE_TQ_COUNT] = {
         tq_ggml_to_float_turbo_kv_2b,
         tq_ggml_vec_dot_turbo_kv_2b,
     },
+    {
+        "tq_uniform_3b", GGML_TYPE_TQ_UNIFORM_3B, TQ_TYPE_UNIFORM_3B,
+        sizeof(block_tq_uniform_3b), TQ_BK,
+        (float)sizeof(block_tq_uniform_3b) * 8.0f / TQ_BK,
+        tq_ggml_from_float_uniform_3b,
+        tq_ggml_to_float_uniform_3b,
+        tq_ggml_vec_dot_uniform_3b,
+    },
 };
 
 #define TQ_GGML_NUM_TYPES (sizeof(TQ_GGML_TRAITS) / sizeof(TQ_GGML_TRAITS[0]))
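
Both conversion helpers are static to this file, so a sync check has to live alongside them. A hypothetical test (tq_check_type_mapping is not in the commit) that exercises the round trip for every type, including the new pair added above:

#include <assert.h>

/* Hypothetical helper: for every tq_type with a GGML mapping, converting
 * there and back must be the identity, including the new
 * TQ_TYPE_UNIFORM_3B <-> GGML_TYPE_TQ_UNIFORM_3B pair. */
static void tq_check_type_mapping(void) {
    for (int t = 0; t < TQ_TYPE_COUNT; t++) {
        int g = tq_to_ggml_type((tq_type)t);
        if (g < 0) continue;                      /* no GGML equivalent */
        assert(ggml_to_tq_type(g) == (tq_type)t);
    }
}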

src/core/tq_traits.c

Lines changed: 17 additions & 0 deletions
@@ -33,6 +33,11 @@ extern void tq_mixed_4b8_dequantize_ref(const void* src, float* dst, int n);
 extern void tq_mixed_4b8_attention_ref(const float* query, const void* kv,
                                        float* scores, int seq_len, int head_dim);
 
+extern void tq_uniform_3b_quantize_ref(const float* src, void* dst, int n);
+extern void tq_uniform_3b_dequantize_ref(const void* src, float* dst, int n);
+extern void tq_uniform_3b_attention_ref(const float* query, const void* kv,
+                                        float* scores, int seq_len, int head_dim);
+
 extern void tq_turbo_kv_3b_quantize_ref(const float* src, void* dst, int n);
 extern void tq_turbo_kv_3b_dequantize_ref(const void* src, float* dst, int n);
 extern void tq_turbo_kv_3b_attention_ref(const float* query, const void* kv,
@@ -174,6 +179,16 @@ const tq_type_traits_t TQ_TRAITS[TQ_TYPE_COUNT] = {
         .attention = tq_turbo_kv_2b_attention_ref,
         .residual_type = TQ_TYPE_QJL_1B,
     },
+    [TQ_TYPE_UNIFORM_3B] = {
+        .name = "uniform_3b",
+        .block_size = TQ_BK,
+        .type_size = sizeof(block_tq_uniform_3b),
+        .bpe = (float)sizeof(block_tq_uniform_3b) * 8.0f / TQ_BK,
+        .quantize = tq_uniform_3b_quantize_ref,
+        .dequantize = tq_uniform_3b_dequantize_ref,
+        .attention = tq_uniform_3b_attention_ref,
+        .residual_type = TQ_TYPE_COUNT,
+    },
 };
 
 const char* tq_type_name(tq_type type) {
@@ -249,6 +264,8 @@ tq_format_spec_t tq_get_format_spec(tq_type type) {
     case TQ_TYPE_TURBO_KV_2B:
         spec.algorithm = TQ_ALG_TURBO; spec.key_bits = 2;
         spec.flags = TQ_FLAG_HAS_RESIDUAL; break;
+    case TQ_TYPE_UNIFORM_3B:
+        spec.algorithm = TQ_ALG_UNIFORM; spec.key_bits = 3; break;
     default: break;
     }
     return spec;
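
Callers reach the new format through this table rather than by name. A minimal dispatch sketch (assumes the project headers expose TQ_TRAITS, tq_type_traits_t, and block_tq_uniform_3b, and that the field types match the initializers above):

#include <stdio.h>
#include "turboquant/tq_types.h"  /* assumed include path */

void demo_uniform_3b_dispatch(const float src[TQ_BK]) {
    const tq_type_traits_t* tr = &TQ_TRAITS[TQ_TYPE_UNIFORM_3B];
    block_tq_uniform_3b blk;
    float out[TQ_BK];

    tr->quantize(src, &blk, TQ_BK);    /* -> tq_uniform_3b_quantize_ref */
    tr->dequantize(&blk, out, TQ_BK);  /* -> tq_uniform_3b_dequantize_ref */
    printf("%s: %.2f bpe\n", tr->name, tr->bpe);  /* "uniform_3b: 4.00 bpe" */
}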

src/core/tq_uniform.c

Lines changed: 117 additions & 0 deletions
@@ -258,3 +258,120 @@ void tq_uniform_2b_attention_ref(const float* query, const void* kv,
         scores[s] = dot;
     }
 }
+
+/* ====================================================================
+ * Uniform 3-bit with per-sub-block FP16 scales (Q3_K-style)
+ *
+ * Each 128-element block is split into 4 sub-blocks of 32 elements.
+ * Each sub-block has independent FP16 scale and minimum, giving
+ * excellent adaptation to local value distributions.
+ *
+ * 8 quantization levels (3-bit) per value.
+ * 64 bytes / 128 elements = 4.0 bpe.
+ *
+ * Compared to uniform_4b (4.25 bpe, 16 levels, 1 global scale):
+ * - Fewer levels (8 vs 16) but finer per-sub-block adaptation
+ * - Better for heterogeneous distributions within a head dimension
+ * ==================================================================== */
+
+/* ---------- Uniform 3-bit sub-block quantize ---------- */
+
+void tq_uniform_3b_quantize_ref(const float* src, void* dst, int n) {
+    block_tq_uniform_3b* block = (block_tq_uniform_3b*)dst;
+    int count = n;
+    if (count > TQ_BK) count = TQ_BK;
+
+    /* Compute per-sub-block min/max and store FP16 scale/min */
+    for (int sb = 0; sb < TQ_3B_NSUB; sb++) {
+        int start = sb * TQ_3B_SUBK;
+        int end = start + TQ_3B_SUBK;
+        if (end > count) end = count;
+        float mn = FLT_MAX, mx = -FLT_MAX;
+        for (int i = start; i < end; i++) {
+            if (src[i] < mn) mn = src[i];
+            if (src[i] > mx) mx = src[i];
+        }
+        if (end <= start) { mn = 0; mx = 0; }
+
+        float range = mx - mn;
+        if (range < 1e-8f) range = 1e-8f;
+        float scale = range / 8.0f; /* 3-bit: 8 bins of width range/8 */
+
+        block->sub_scale[sb] = uni_fp32_to_fp16(scale);
+        block->sub_min[sb] = uni_fp32_to_fp16(mn);
+    }
+
+    /* Pack 3-bit quantized values into qs (LSB-first).
+     * Use the FP16-reconstructed scale/min for quantization
+     * to minimize encode/decode mismatch.
+     */
+    memset(block->qs, 0, TQ_BK * 3 / 8);
+    for (int i = 0; i < count; i++) {
+        int sb = i / TQ_3B_SUBK;
+        float scale = uni_fp16_to_fp32(block->sub_scale[sb]);
+        float mn = uni_fp16_to_fp32(block->sub_min[sb]);
+        if (scale < 1e-10f) scale = 1e-10f;
+
+        int q = (int)floorf((src[i] - mn) / scale);
+        if (q < 0) q = 0;
+        if (q > 7) q = 7;
+
+        /* 3-bit packing: element i uses bits [i*3 .. i*3+2] across qs bytes */
+        int bit_pos = i * 3;
+        int byte_idx = bit_pos / 8;
+        int bit_off = bit_pos % 8;
+        block->qs[byte_idx] |= (uint8_t)(q << bit_off);
+        /* Cross-byte boundary: when bit_off > 5, the high bits spill into the next byte */
+        if (bit_off > 5 && byte_idx + 1 < TQ_BK * 3 / 8) {
+            block->qs[byte_idx + 1] |= (uint8_t)(q >> (8 - bit_off));
+        }
+    }
+}
+
+/* ---------- Uniform 3-bit sub-block dequantize ---------- */
+
+void tq_uniform_3b_dequantize_ref(const void* src, float* dst, int n) {
+    const block_tq_uniform_3b* block = (const block_tq_uniform_3b*)src;
+    int count = n;
+    if (count > TQ_BK) count = TQ_BK;
+
+    for (int i = 0; i < count; i++) {
+        int sb = i / TQ_3B_SUBK;
+        float scale = uni_fp16_to_fp32(block->sub_scale[sb]);
+        float mn = uni_fp16_to_fp32(block->sub_min[sb]);
+
+        /* Extract 3-bit value, reassembling across the byte boundary if needed */
+        int bit_pos = i * 3;
+        int byte_idx = bit_pos / 8;
+        int bit_off = bit_pos % 8;
+        int q = (block->qs[byte_idx] >> bit_off) & 0x07;
+        if (bit_off > 5 && byte_idx + 1 < TQ_BK * 3 / 8) {
+            q |= (block->qs[byte_idx + 1] << (8 - bit_off)) & 0x07;
+        }
+
+        dst[i] = mn + ((float)q + 0.5f) * scale; /* reconstruct at bin center */
+    }
+}
+
+/* ---------- Uniform 3-bit attention (dequantize + dot product) ---------- */
+
+void tq_uniform_3b_attention_ref(const float* query, const void* kv,
+                                 float* scores, int seq_len, int head_dim) {
+    int blocks_per_key = (head_dim + TQ_BK - 1) / TQ_BK;
+    const block_tq_uniform_3b* all_blocks = (const block_tq_uniform_3b*)kv;
+
+    for (int s = 0; s < seq_len; s++) {
+        float dot = 0;
+        for (int b = 0; b < blocks_per_key; b++) {
+            int offset = b * TQ_BK;
+            int chunk = (head_dim - offset > TQ_BK) ? TQ_BK : (head_dim - offset);
+
+            float deq[TQ_BK];
+            tq_uniform_3b_dequantize_ref(&all_blocks[s * blocks_per_key + b], deq, chunk);
+
+            for (int dd = 0; dd < chunk; dd++)
+                dot += query[offset + dd] * deq[dd];
+        }
+        scores[s] = dot;
+    }
+}
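
A round-trip sketch against the reference functions (illustrative only, not part of the commit; it declares the two externs directly, as tq_traits.c does, and assumes the header is on the include path). For min-max quantization reconstructed at bin centers, the worst-case error per sub-block is about half a bin, i.e. range/16, plus FP16 rounding of scale/min:

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "turboquant/tq_types.h"  /* TQ_BK, block_tq_uniform_3b */

extern void tq_uniform_3b_quantize_ref(const float* src, void* dst, int n);
extern void tq_uniform_3b_dequantize_ref(const void* src, float* dst, int n);

int main(void) {
    float src[TQ_BK], out[TQ_BK];
    block_tq_uniform_3b blk;

    srand(42);
    for (int i = 0; i < TQ_BK; i++)
        src[i] = (float)rand() / (float)RAND_MAX * 2.0f - 1.0f; /* in [-1, 1] */

    tq_uniform_3b_quantize_ref(src, &blk, TQ_BK);
    tq_uniform_3b_dequantize_ref(&blk, out, TQ_BK);

    float max_err = 0.0f;
    for (int i = 0; i < TQ_BK; i++) {
        float e = fabsf(src[i] - out[i]);
        if (e > max_err) max_err = e;
    }
    /* sub-block range ~2 -> bin width ~0.25 -> max error ~0.125 */
    printf("max |src - out| = %.4f\n", max_err);
    return 0;
}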

src/engine/tq_transformer.c

Lines changed: 13 additions & 20 deletions
@@ -861,26 +861,19 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
     int n_heads = c->n_heads;
     int n_kv_heads = c->n_kv_heads;
 
-    /* Gemma 4 hybrid: full attention layers have different head_dim and kv_heads.
-     * Detect from GGUF weight shapes: if Q output > n_heads * head_dim, it's a full layer. */
-    if (model->layer_is_sliding && !model->layer_is_sliding[l] && layer->gguf_wq) {
-        /* Full attention layer: infer head_dim from Q tensor.
-         * Q shape = [hidden_dim, n_heads * full_head_dim * (1 + gate)] */
-        int q_out = 0;
-        /* Get Q output dim from GGUF tensor — stored at load time in gguf_wq_type's neighbor.
-         * Simpler: compute from expected: global_head_dim = metadata key_length */
-        int global_head_dim = tq_gguf_get_i32((const tq_gguf_ctx_t*)model->gguf_ctx,
-                                              "gemma4.attention.key_length", head_dim);
-        if (global_head_dim > head_dim) {
-            head_dim = global_head_dim;
-            /* For full layers, kv_heads is typically smaller */
-            /* K shape for full: [dim, kv_heads_full * global_head_dim]
-             * We know K_out from sliding kv_dim * (global/sliding) ratio... or just compute:
-             * Total Q = n_heads * global_head_dim = 16 * 512 = 8192
-             * Total K = ? from tensor. For now, infer: */
-            n_kv_heads = c->n_kv_heads * c->head_dim / global_head_dim;
-            if (n_kv_heads < 1) n_kv_heads = 1;
-        }
+    /* Gemma 4 hybrid: full attention layers use a different head_dim and kv_heads.
+     * Sliding layers: head_dim=256, kv_heads=8 (stored in config)
+     * Full layers:    head_dim=512, kv_heads=4
+     * Infer full dims: total K/V width stays the same; head_dim doubles, kv_heads halves. */
+    if (model->layer_is_sliding && !model->layer_is_sliding[l]) {
+        /* Full attention layer: head_dim is 2x sliding, kv_heads is sliding/2.
+         * Query pre-attn scalar (Gemma) also changes with head_dim. */
+        int global_head_dim = c->head_dim * 2; /* 256 -> 512 */
+        int global_kv_heads = c->n_kv_heads * c->head_dim / global_head_dim;
+        if (global_kv_heads < 1) global_kv_heads = 1;
+        head_dim = global_head_dim;
+        n_kv_heads = global_kv_heads;
+        /* n_heads stays the same count, just with a larger head_dim */
     }
 
     int kv_dim = n_kv_heads * head_dim;
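
A worked check of the inference above, as a standalone sketch (not part of the commit) using the numbers from the comment: sliding head_dim=256, kv_heads=8. The point is that kv_dim is invariant; the full layer regroups the same K/V width into fewer, wider heads:

#include <stdio.h>

int main(void) {
    int head_dim = 256, n_kv_heads = 8;  /* sliding-layer config */

    int full_head_dim = head_dim * 2;                           /* 512 */
    int full_kv_heads = n_kv_heads * head_dim / full_head_dim;  /* 4 */
    if (full_kv_heads < 1) full_kv_heads = 1;

    printf("sliding kv_dim = %d\n", n_kv_heads * head_dim);         /* 2048 */
    printf("full    kv_dim = %d\n", full_kv_heads * full_head_dim); /* 2048 */
    return 0;
}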
