@@ -258,3 +258,120 @@ void tq_uniform_2b_attention_ref(const float* query, const void* kv,
258258 scores [s ] = dot ;
259259 }
260260}
261+
262+ /* ====================================================================
263+ * Uniform 3-bit with per-sub-block FP16 scales (Q3_K-style)
264+ *
265+ * Each 128-element block is split into 4 sub-blocks of 32 elements.
266+ * Each sub-block has independent FP16 scale and minimum, giving
267+ * excellent adaptation to local value distributions.
268+ *
269+ * 8 quantization levels (3-bit) per value.
270+ * 64 bytes / 128 elements = 4.0 bpe.
271+ *
272+ * Compared to uniform_4b (4.0 bpe, 16 levels, 1 global scale):
273+ * - Fewer levels (8 vs 16) but finer per-sub-block adaptation
274+ * - Better for heterogeneous distributions within a head dimension
275+ * ==================================================================== */
276+
277+ /* ---------- Uniform 3-bit sub-block quantize ---------- */
278+
279+ void tq_uniform_3b_quantize_ref (const float * src , void * dst , int n ) {
280+ block_tq_uniform_3b * block = (block_tq_uniform_3b * )dst ;
281+ int count = n ;
282+ if (count > TQ_BK ) count = TQ_BK ;
283+
284+ /* Compute per-sub-block min/max and store FP16 scale/min */
285+ for (int sb = 0 ; sb < TQ_3B_NSUB ; sb ++ ) {
286+ int start = sb * TQ_3B_SUBK ;
287+ int end = start + TQ_3B_SUBK ;
288+ if (end > count ) end = count ;
289+ float mn = FLT_MAX , mx = - FLT_MAX ;
290+ for (int i = start ; i < end ; i ++ ) {
291+ if (src [i ] < mn ) mn = src [i ];
292+ if (src [i ] > mx ) mx = src [i ];
293+ }
294+ if (end <= start ) { mn = 0 ; mx = 0 ; }
295+
296+ float range = mx - mn ;
297+ if (range < 1e-8f ) range = 1e-8f ;
298+ float scale = range / 8.0f ; /* 3-bit: 8 bins of width range/8 */
299+
300+ block -> sub_scale [sb ] = uni_fp32_to_fp16 (scale );
301+ block -> sub_min [sb ] = uni_fp32_to_fp16 (mn );
302+ }
303+
304+ /* Pack 3-bit quantized values into qs (LSB-first).
305+ * Use the FP16-reconstructed scale/min for quantization
306+ * to minimize encode/decode mismatch.
307+ */
308+ memset (block -> qs , 0 , TQ_BK * 3 / 8 );
309+ for (int i = 0 ; i < count ; i ++ ) {
310+ int sb = i / TQ_3B_SUBK ;
311+ float scale = uni_fp16_to_fp32 (block -> sub_scale [sb ]);
312+ float mn = uni_fp16_to_fp32 (block -> sub_min [sb ]);
313+ if (scale < 1e-10f ) scale = 1e-10f ;
314+
315+ int q = (int )floorf ((src [i ] - mn ) / scale );
316+ if (q < 0 ) q = 0 ;
317+ if (q > 7 ) q = 7 ;
318+
319+ /* 3-bit packing: element i uses bits [i*3 .. i*3+2] across qs bytes */
320+ int bit_pos = i * 3 ;
321+ int byte_idx = bit_pos / 8 ;
322+ int bit_off = bit_pos % 8 ;
323+ block -> qs [byte_idx ] |= (uint8_t )(q << bit_off );
324+ /* Handle cross-byte boundary (when bit_off > 5, bits spill into next byte) */
325+ if (bit_off > 5 && byte_idx + 1 < TQ_BK * 3 / 8 ) {
326+ block -> qs [byte_idx + 1 ] |= (uint8_t )(q >> (8 - bit_off ));
327+ }
328+ }
329+ }
330+
331+ /* ---------- Uniform 3-bit sub-block dequantize ---------- */
332+
333+ void tq_uniform_3b_dequantize_ref (const void * src , float * dst , int n ) {
334+ const block_tq_uniform_3b * block = (const block_tq_uniform_3b * )src ;
335+ int count = n ;
336+ if (count > TQ_BK ) count = TQ_BK ;
337+
338+ for (int i = 0 ; i < count ; i ++ ) {
339+ int sb = i / TQ_3B_SUBK ;
340+ float scale = uni_fp16_to_fp32 (block -> sub_scale [sb ]);
341+ float mn = uni_fp16_to_fp32 (block -> sub_min [sb ]);
342+
343+ /* Extract 3-bit value */
344+ int bit_pos = i * 3 ;
345+ int byte_idx = bit_pos / 8 ;
346+ int bit_off = bit_pos % 8 ;
347+ int q = (block -> qs [byte_idx ] >> bit_off ) & 0x07 ;
348+ if (bit_off > 5 && byte_idx + 1 < TQ_BK * 3 / 8 ) {
349+ q |= (block -> qs [byte_idx + 1 ] << (8 - bit_off )) & 0x07 ;
350+ }
351+
352+ dst [i ] = mn + ((float )q + 0.5f ) * scale ;
353+ }
354+ }
355+
356+ /* ---------- Uniform 3-bit attention (dequantize + dot product) ---------- */
357+
358+ void tq_uniform_3b_attention_ref (const float * query , const void * kv ,
359+ float * scores , int seq_len , int head_dim ) {
360+ int blocks_per_key = (head_dim + TQ_BK - 1 ) / TQ_BK ;
361+ const block_tq_uniform_3b * all_blocks = (const block_tq_uniform_3b * )kv ;
362+
363+ for (int s = 0 ; s < seq_len ; s ++ ) {
364+ float dot = 0 ;
365+ for (int b = 0 ; b < blocks_per_key ; b ++ ) {
366+ int offset = b * TQ_BK ;
367+ int chunk = (head_dim - offset > TQ_BK ) ? TQ_BK : (head_dim - offset );
368+
369+ float deq [TQ_BK ];
370+ tq_uniform_3b_dequantize_ref (& all_blocks [s * blocks_per_key + b ], deq , chunk );
371+
372+ for (int dd = 0 ; dd < chunk ; dd ++ )
373+ dot += query [offset + dd ] * deq [dd ];
374+ }
375+ scores [s ] = dot ;
376+ }
377+ }
0 commit comments