@@ -29,6 +29,7 @@ extern void tq_codebook_quantize(const float* src, uint8_t* dst_indices,
2929 int n , int bits , float inv_std );
3030extern void tq_codebook_dequantize (const uint8_t * indices , float * dst ,
3131 int n , int bits , float inv_std );
32+ extern const float * tq_codebook_centroids (int bits );
3233
3334/* ============================================================
3435 * FP16 helpers (local copies to avoid cross-module dependencies)
@@ -152,26 +153,84 @@ static void compute_qjl_signs(const float* residual, uint8_t* signs,
152153
153154static void dequant_mse_rotated_3bit_v2 (const block_tq_turbo_kv_3b * block ,
154155 float * rotated , int dim ) {
155- /* Variant F (3b): 3-bit codebook (8 levels) + max-abs scaling, no QJL */
156+ /* Variant F (3b): 3-bit codebook (8 levels) + max-abs scaling.
157+ * Single-pass fused unpack + LUT lookup + scale (Round 1 pattern). */
156158 float inv_std = tkv_fp16_to_fp32 (block -> inv_std_fp16 );
157159 if (inv_std < 1e-10f ) inv_std = sqrtf ((float )dim );
158- uint8_t indices [TQ_BK ] = {0 };
159- unpack_3bit (block -> mse_indices , indices , dim );
160- tq_codebook_dequantize (indices , rotated , dim , 3 , inv_std );
160+ float scale = 1.0f / inv_std ;
161+ const float * cb = tq_codebook_centroids (3 );
162+ float lut [8 ];
163+ for (int i = 0 ; i < 8 ; i ++ ) lut [i ] = cb [i ] * scale ;
164+ /* 3-bit packing is bit-stream LSB-first, 8 elements per 3 bytes */
165+ const uint8_t * p = block -> mse_indices ;
166+ int i = 0 ;
167+ for (; i + 7 < dim ; i += 8 ) {
168+ /* 3 bytes encode 8 indices */
169+ uint32_t w = (uint32_t )p [0 ] | ((uint32_t )p [1 ] << 8 ) | ((uint32_t )p [2 ] << 16 );
170+ rotated [i + 0 ] = lut [(w >> 0 ) & 7 ];
171+ rotated [i + 1 ] = lut [(w >> 3 ) & 7 ];
172+ rotated [i + 2 ] = lut [(w >> 6 ) & 7 ];
173+ rotated [i + 3 ] = lut [(w >> 9 ) & 7 ];
174+ rotated [i + 4 ] = lut [(w >> 12 ) & 7 ];
175+ rotated [i + 5 ] = lut [(w >> 15 ) & 7 ];
176+ rotated [i + 6 ] = lut [(w >> 18 ) & 7 ];
177+ rotated [i + 7 ] = lut [(w >> 21 ) & 7 ];
178+ p += 3 ;
179+ }
180+ /* Tail */
181+ for (; i < dim ; i ++ ) {
182+ int bit_off = i * 3 ;
183+ int byte_idx = bit_off / 8 ;
184+ int bit_pos = bit_off % 8 ;
185+ uint16_t v = block -> mse_indices [byte_idx ];
186+ if (bit_pos > 5 && byte_idx + 1 < (dim * 3 + 7 ) / 8 ) {
187+ v |= (uint16_t )block -> mse_indices [byte_idx + 1 ] << 8 ;
188+ }
189+ rotated [i ] = lut [(v >> bit_pos ) & 7 ];
190+ }
161191}
162192
163193static void dequant_mse_rotated_4bit_v2 (const block_tq_turbo_kv_4b * block ,
164194 float * rotated , int dim ) {
165- /* Variant F: 4-bit codebook (16 levels) + max-abs scaling */
195+ /* Variant F: 4-bit codebook (16 levels) + max-abs scaling.
196+ *
197+ * Single-pass fused unpack + codebook lookup + scale.
198+ * Pre-multiply the per-block scale into a local 16-entry table so
199+ * the inner loop is one byte load + two table lookups + two stores.
200+ */
166201 float inv_std = tkv_fp16_to_fp32 (block -> inv_std_fp16 );
167202 if (inv_std < 1e-10f ) inv_std = sqrtf ((float )dim );
168- /* Unpack 4-bit indices: 2 per byte, LSB-first */
169- uint8_t indices [TQ_BK ];
170- for (int i = 0 ; i < dim ; i ++ ) {
171- uint8_t b = block -> mse_indices [i / 2 ];
172- indices [i ] = (i & 1 ) ? (b >> 4 ) : (b & 0x0F );
203+ float scale = 1.0f / inv_std ;
204+
205+ /* Pre-scaled local codebook (16 entries) */
206+ const float * cb = tq_codebook_centroids (4 );
207+ float lut [16 ];
208+ for (int i = 0 ; i < 16 ; i ++ ) lut [i ] = cb [i ] * scale ;
209+
210+ const uint8_t * mi = block -> mse_indices ;
211+ /* Process 2 elements per byte, unrolled by 2 bytes per iteration */
212+ int i = 0 ;
213+ int byte_n = dim / 2 ;
214+ for (int b = 0 ; b + 1 < byte_n ; b += 2 ) {
215+ uint8_t b0 = mi [b ];
216+ uint8_t b1 = mi [b + 1 ];
217+ rotated [i + 0 ] = lut [b0 & 0x0F ];
218+ rotated [i + 1 ] = lut [b0 >> 4 ];
219+ rotated [i + 2 ] = lut [b1 & 0x0F ];
220+ rotated [i + 3 ] = lut [b1 >> 4 ];
221+ i += 4 ;
222+ }
223+ for (int b = i / 2 ; b < byte_n ; b ++ ) {
224+ uint8_t bv = mi [b ];
225+ rotated [i + 0 ] = lut [bv & 0x0F ];
226+ rotated [i + 1 ] = lut [bv >> 4 ];
227+ i += 2 ;
228+ }
229+ /* Tail (odd dim) */
230+ if (i < dim ) {
231+ uint8_t bv = mi [i / 2 ];
232+ rotated [i ] = lut [bv & 0x0F ];
173233 }
174- tq_codebook_dequantize (indices , rotated , dim , 4 , inv_std );
175234}
176235
177236/* ============================================================
@@ -410,41 +469,46 @@ void tq_turbo_kv_4b_attention_ref(const float* query, const void* kv_cache,
410469 for (int i = dim ; i < TQ_BK ; i ++ ) q_rot [i ] = 0.0f ;
411470 tq_rht_transform (q_rot , dim , TKV_DEFAULT_SEED );
412471
472+ /* Hoist codebook pointer (constant for all blocks) */
473+ const float * cb = tq_codebook_centroids (4 );
474+
413475 for (int seq = 0 ; seq < seq_len ; seq ++ ) {
414476 const block_tq_turbo_kv_4b * block = & blocks_4b [seq ];
415477 float norm = tkv_fp16_to_fp32 (block -> norm );
416-
417- float rotated [TQ_BK ];
418- dequant_mse_rotated_4bit_v2 (block , rotated , dim );
419-
420- float mse_dot = 0.0f ;
421- #ifdef __ARM_NEON
422- {
423- float32x4_t acc0 = vdupq_n_f32 (0.0f );
424- float32x4_t acc1 = vdupq_n_f32 (0.0f );
425- float32x4_t acc2 = vdupq_n_f32 (0.0f );
426- float32x4_t acc3 = vdupq_n_f32 (0.0f );
427- int d = 0 ;
428- for (; d + 15 < dim ; d += 16 ) {
429- acc0 = vfmaq_f32 (acc0 , vld1q_f32 (& q_rot [d ]), vld1q_f32 (& rotated [d ]));
430- acc1 = vfmaq_f32 (acc1 , vld1q_f32 (& q_rot [d + 4 ]), vld1q_f32 (& rotated [d + 4 ]));
431- acc2 = vfmaq_f32 (acc2 , vld1q_f32 (& q_rot [d + 8 ]), vld1q_f32 (& rotated [d + 8 ]));
432- acc3 = vfmaq_f32 (acc3 , vld1q_f32 (& q_rot [d + 12 ]), vld1q_f32 (& rotated [d + 12 ]));
433- }
434- acc0 = vaddq_f32 (vaddq_f32 (acc0 , acc1 ), vaddq_f32 (acc2 , acc3 ));
435- for (; d + 3 < dim ; d += 4 ) {
436- acc0 = vfmaq_f32 (acc0 , vld1q_f32 (& q_rot [d ]), vld1q_f32 (& rotated [d ]));
437- }
438- mse_dot = vaddvq_f32 (acc0 );
439- for (; d < dim ; d ++ ) {
440- mse_dot += q_rot [d ] * rotated [d ];
441- }
478+ float inv_std = tkv_fp16_to_fp32 (block -> inv_std_fp16 );
479+ if (inv_std < 1e-10f ) inv_std = sqrtf ((float )dim );
480+ float scale = 1.0f / inv_std ;
481+
482+ /* Per-block pre-scaled LUT (16 floats, fits in 64 bytes — L1 hot) */
483+ float lut [16 ];
484+ for (int j = 0 ; j < 16 ; j ++ ) lut [j ] = cb [j ] * scale ;
485+
486+ /* Round 4: fused scalar dequant + dot product, 4 accumulators.
487+ * Eliminates the rotated[] intermediate buffer entirely.
488+ * Apple Silicon's 6 ALUs + L1-hot LUT make scalar gather fast. */
489+ const uint8_t * mi = block -> mse_indices ;
490+ float a0 = 0 , a1 = 0 , a2 = 0 , a3 = 0 ;
491+ int d = 0 ;
492+ for (; d + 7 < dim ; d += 8 ) {
493+ uint8_t b0 = mi [d / 2 + 0 ];
494+ uint8_t b1 = mi [d / 2 + 1 ];
495+ uint8_t b2 = mi [d / 2 + 2 ];
496+ uint8_t b3 = mi [d / 2 + 3 ];
497+ a0 += q_rot [d + 0 ] * lut [b0 & 0x0F ];
498+ a1 += q_rot [d + 1 ] * lut [b0 >> 4 ];
499+ a2 += q_rot [d + 2 ] * lut [b1 & 0x0F ];
500+ a3 += q_rot [d + 3 ] * lut [b1 >> 4 ];
501+ a0 += q_rot [d + 4 ] * lut [b2 & 0x0F ];
502+ a1 += q_rot [d + 5 ] * lut [b2 >> 4 ];
503+ a2 += q_rot [d + 6 ] * lut [b3 & 0x0F ];
504+ a3 += q_rot [d + 7 ] * lut [b3 >> 4 ];
442505 }
443- #else
444- for (int d = 0 ; d < dim ; d ++ ) {
445- mse_dot += q_rot [d ] * rotated [d ];
506+ float mse_dot = (a0 + a1 ) + (a2 + a3 );
507+ for (; d < dim ; d ++ ) {
508+ uint8_t bv = mi [d / 2 ];
509+ int idx = (d & 1 ) ? (bv >> 4 ) : (bv & 0x0F );
510+ mse_dot += q_rot [d ] * lut [idx ];
446511 }
447- #endif
448512
449513 scores [seq ] = norm * mse_dot ;
450514 }
@@ -985,11 +1049,43 @@ void tq_turbo_kv_5b_quantize_ref(const float* src, void* dst, int n) {
9851049
9861050static void dequant_mse_rotated_5bit (const block_tq_turbo_kv_5b * block ,
9871051 float * rotated , int dim ) {
1052+ /* Single-pass fused unpack + LUT lookup + scale (Round 1 pattern).
1053+ * 5-bit packing: 5 bytes encode 8 indices (40 bits). */
9881054 float inv_std = tkv_fp16_to_fp32 (block -> inv_std_fp16 );
9891055 if (inv_std < 1e-10f ) inv_std = sqrtf ((float )dim );
990- uint8_t indices [TQ_BK ] = {0 };
991- unpack_5bit (block -> mse_indices , indices , dim );
992- tq_codebook_dequantize (indices , rotated , dim , 5 , inv_std );
1056+ float scale = 1.0f / inv_std ;
1057+ const float * cb = tq_codebook_centroids (5 );
1058+ float lut [32 ];
1059+ for (int i = 0 ; i < 32 ; i ++ ) lut [i ] = cb [i ] * scale ;
1060+ const uint8_t * p = block -> mse_indices ;
1061+ int i = 0 ;
1062+ for (; i + 7 < dim ; i += 8 ) {
1063+ uint64_t w = (uint64_t )p [0 ]
1064+ | ((uint64_t )p [1 ] << 8 )
1065+ | ((uint64_t )p [2 ] << 16 )
1066+ | ((uint64_t )p [3 ] << 24 )
1067+ | ((uint64_t )p [4 ] << 32 );
1068+ rotated [i + 0 ] = lut [(w >> 0 ) & 31 ];
1069+ rotated [i + 1 ] = lut [(w >> 5 ) & 31 ];
1070+ rotated [i + 2 ] = lut [(w >> 10 ) & 31 ];
1071+ rotated [i + 3 ] = lut [(w >> 15 ) & 31 ];
1072+ rotated [i + 4 ] = lut [(w >> 20 ) & 31 ];
1073+ rotated [i + 5 ] = lut [(w >> 25 ) & 31 ];
1074+ rotated [i + 6 ] = lut [(w >> 30 ) & 31 ];
1075+ rotated [i + 7 ] = lut [(w >> 35 ) & 31 ];
1076+ p += 5 ;
1077+ }
1078+ /* Tail (slow path for non-multiple-of-8 dims) */
1079+ for (; i < dim ; i ++ ) {
1080+ int bit_off = i * 5 ;
1081+ int byte_idx = bit_off / 8 ;
1082+ int bit_pos = bit_off % 8 ;
1083+ uint16_t v = block -> mse_indices [byte_idx ];
1084+ if (bit_pos > 3 && byte_idx + 1 < (dim * 5 + 7 ) / 8 ) {
1085+ v |= (uint16_t )block -> mse_indices [byte_idx + 1 ] << 8 ;
1086+ }
1087+ rotated [i ] = lut [(v >> bit_pos ) & 31 ];
1088+ }
9931089}
9941090
9951091void tq_turbo_kv_5b_dequantize_ref (const void * src , float * dst , int n ) {
@@ -1150,15 +1246,36 @@ void tq_turbo_kv_4bo_quantize_ref(const float* src, void* dst, int n) {
11501246
11511247static void dequant_mse_rotated_4bo (const block_tq_turbo_kv_4bo * block ,
11521248 float * rotated , int dim ) {
1153- /* 4-bit codebook lookup */
1249+ /* Single-pass fused unpack + LUT lookup + scale (Round 1 pattern) */
11541250 float inv_std = tkv_fp16_to_fp32 (block -> inv_std_fp16 );
11551251 if (inv_std < 1e-10f ) inv_std = sqrtf ((float )dim );
1156- uint8_t indices [TQ_BK ];
1157- for (int i = 0 ; i < dim ; i ++ ) {
1158- uint8_t b = block -> mse_indices [i / 2 ];
1159- indices [i ] = (i & 1 ) ? (b >> 4 ) : (b & 0x0F );
1252+ float scale = 1.0f / inv_std ;
1253+ const float * cb = tq_codebook_centroids (4 );
1254+ float lut [16 ];
1255+ for (int i = 0 ; i < 16 ; i ++ ) lut [i ] = cb [i ] * scale ;
1256+
1257+ const uint8_t * mi = block -> mse_indices ;
1258+ int byte_n = dim / 2 ;
1259+ int i = 0 ;
1260+ for (int b = 0 ; b + 1 < byte_n ; b += 2 ) {
1261+ uint8_t b0 = mi [b ];
1262+ uint8_t b1 = mi [b + 1 ];
1263+ rotated [i + 0 ] = lut [b0 & 0x0F ];
1264+ rotated [i + 1 ] = lut [b0 >> 4 ];
1265+ rotated [i + 2 ] = lut [b1 & 0x0F ];
1266+ rotated [i + 3 ] = lut [b1 >> 4 ];
1267+ i += 4 ;
1268+ }
1269+ for (int b = i / 2 ; b < byte_n ; b ++ ) {
1270+ uint8_t bv = mi [b ];
1271+ rotated [i + 0 ] = lut [bv & 0x0F ];
1272+ rotated [i + 1 ] = lut [bv >> 4 ];
1273+ i += 2 ;
1274+ }
1275+ if (i < dim ) {
1276+ uint8_t bv = mi [i / 2 ];
1277+ rotated [i ] = lut [bv & 0x0F ];
11601278 }
1161- tq_codebook_dequantize (indices , rotated , dim , 4 , inv_std );
11621279
11631280 /* Overwrite outlier positions with stored exact FP16 values */
11641281 int K = TQ_KV_4BO_OUTLIERS ;
@@ -1305,11 +1422,37 @@ void tq_turbo_kv_3bo_quantize_ref(const float* src, void* dst, int n) {
13051422
13061423static void dequant_mse_rotated_3bo (const block_tq_turbo_kv_3bo * block ,
13071424 float * rotated , int dim ) {
1425+ /* Single-pass fused unpack + LUT lookup + scale (Round 1 pattern) */
13081426 float inv_std = tkv_fp16_to_fp32 (block -> inv_std_fp16 );
13091427 if (inv_std < 1e-10f ) inv_std = sqrtf ((float )dim );
1310- uint8_t indices [TQ_BK ] = {0 };
1311- unpack_3bit (block -> mse_indices , indices , dim );
1312- tq_codebook_dequantize (indices , rotated , dim , 3 , inv_std );
1428+ float scale = 1.0f / inv_std ;
1429+ const float * cb = tq_codebook_centroids (3 );
1430+ float lut [8 ];
1431+ for (int i = 0 ; i < 8 ; i ++ ) lut [i ] = cb [i ] * scale ;
1432+ const uint8_t * p = block -> mse_indices ;
1433+ int i = 0 ;
1434+ for (; i + 7 < dim ; i += 8 ) {
1435+ uint32_t w = (uint32_t )p [0 ] | ((uint32_t )p [1 ] << 8 ) | ((uint32_t )p [2 ] << 16 );
1436+ rotated [i + 0 ] = lut [(w >> 0 ) & 7 ];
1437+ rotated [i + 1 ] = lut [(w >> 3 ) & 7 ];
1438+ rotated [i + 2 ] = lut [(w >> 6 ) & 7 ];
1439+ rotated [i + 3 ] = lut [(w >> 9 ) & 7 ];
1440+ rotated [i + 4 ] = lut [(w >> 12 ) & 7 ];
1441+ rotated [i + 5 ] = lut [(w >> 15 ) & 7 ];
1442+ rotated [i + 6 ] = lut [(w >> 18 ) & 7 ];
1443+ rotated [i + 7 ] = lut [(w >> 21 ) & 7 ];
1444+ p += 3 ;
1445+ }
1446+ for (; i < dim ; i ++ ) {
1447+ int bit_off = i * 3 ;
1448+ int byte_idx = bit_off / 8 ;
1449+ int bit_pos = bit_off % 8 ;
1450+ uint16_t v = block -> mse_indices [byte_idx ];
1451+ if (bit_pos > 5 && byte_idx + 1 < (dim * 3 + 7 ) / 8 ) {
1452+ v |= (uint16_t )block -> mse_indices [byte_idx + 1 ] << 8 ;
1453+ }
1454+ rotated [i ] = lut [(v >> bit_pos ) & 7 ];
1455+ }
13131456
13141457 int K = TQ_KV_4BO_OUTLIERS ;
13151458 for (int k = 0 ; k < K ; k ++ ) {
0 commit comments