@@ -335,39 +335,113 @@ void tq_turbo_kv_3b_attention_ref(const float* query, const void* kv_cache,
335335 for (int i = dim ; i < TQ_BK ; i ++ ) q_rot [i ] = 0.0f ;
336336 tq_rht_transform (q_rot , dim , TKV_DEFAULT_SEED );
337337
338+ /* Round 11: NEON 8-entry table lookup via vqtbl1q_s8 (using lower 8 bytes).
339+ * 3-bit codebook has 8 entries which fit in 8 bytes — store in lower half
340+ * of a 16-byte register. Indices in 0-7. */
341+ const float * cb = tq_codebook_centroids (3 );
342+ #ifdef __ARM_NEON
343+ static int8_t s_cb3_i8 [16 ] = {0 };
344+ static int s_cb3_i8_init = 0 ;
345+ static const float CB3_I8_RECIP = 2.1520f / 127.0f ;
346+ if (!s_cb3_i8_init ) {
347+ for (int j = 0 ; j < 8 ; j ++ ) {
348+ float v = cb [j ] * (127.0f / 2.1520f );
349+ int q = (int )(v >= 0 ? v + 0.5f : v - 0.5f );
350+ if (q < -127 ) q = -127 ;
351+ if (q > 127 ) q = 127 ;
352+ s_cb3_i8 [j ] = (int8_t )q ;
353+ }
354+ for (int j = 8 ; j < 16 ; j ++ ) s_cb3_i8 [j ] = 0 ;
355+ s_cb3_i8_init = 1 ;
356+ }
357+ int8x16_t cb_vec = vld1q_s8 (s_cb3_i8 );
358+ #endif
359+
338360 for (int seq = 0 ; seq < seq_len ; seq ++ ) {
339361 const block_tq_turbo_kv_3b * block = & blocks [seq ];
340362 float norm = tkv_fp16_to_fp32 (block -> norm );
363+ float inv_std = tkv_fp16_to_fp32 (block -> inv_std_fp16 );
364+ if (inv_std < 1e-10f ) inv_std = sqrtf ((float )dim );
365+ float per_block_scale = CB3_I8_RECIP / inv_std ;
341366
342- float rotated [TQ_BK ];
343- dequant_mse_rotated_3bit_v2 (block , rotated , dim );
344-
367+ const uint8_t * mi = block -> mse_indices ;
345368 float mse_dot = 0.0f ;
369+
346370#ifdef __ARM_NEON
347- {
348- float32x4_t acc0 = vdupq_n_f32 (0.0f );
349- float32x4_t acc1 = vdupq_n_f32 (0.0f );
350- float32x4_t acc2 = vdupq_n_f32 (0.0f );
351- float32x4_t acc3 = vdupq_n_f32 (0.0f );
352- int d = 0 ;
353- for (; d + 15 < dim ; d += 16 ) {
354- acc0 = vfmaq_f32 (acc0 , vld1q_f32 (& q_rot [d ]), vld1q_f32 (& rotated [d ]));
355- acc1 = vfmaq_f32 (acc1 , vld1q_f32 (& q_rot [d + 4 ]), vld1q_f32 (& rotated [d + 4 ]));
356- acc2 = vfmaq_f32 (acc2 , vld1q_f32 (& q_rot [d + 8 ]), vld1q_f32 (& rotated [d + 8 ]));
357- acc3 = vfmaq_f32 (acc3 , vld1q_f32 (& q_rot [d + 12 ]), vld1q_f32 (& rotated [d + 12 ]));
358- }
359- acc0 = vaddq_f32 (vaddq_f32 (acc0 , acc1 ), vaddq_f32 (acc2 , acc3 ));
360- for (; d + 3 < dim ; d += 4 ) {
361- acc0 = vfmaq_f32 (acc0 , vld1q_f32 (& q_rot [d ]), vld1q_f32 (& rotated [d ]));
362- }
363- mse_dot = vaddvq_f32 (acc0 );
364- for (; d < dim ; d ++ ) {
365- mse_dot += q_rot [d ] * rotated [d ];
366- }
371+ float32x4_t acc0 = vdupq_n_f32 (0.0f );
372+ float32x4_t acc1 = vdupq_n_f32 (0.0f );
373+ float32x4_t acc2 = vdupq_n_f32 (0.0f );
374+ float32x4_t acc3 = vdupq_n_f32 (0.0f );
375+ float32x4_t scale_v = vdupq_n_f32 (per_block_scale );
376+
377+ int d = 0 ;
378+ /* Process 16 elements per iteration: 6 bytes of mse_indices (48 bits = 16 × 3) */
379+ for (; d + 15 < dim ; d += 16 ) {
380+ /* uint64 read of 8 bytes (we use 6) */
381+ const uint8_t * p = mi + (d * 3 ) / 8 ;
382+ uint64_t w ;
383+ memcpy (& w , p , 8 );
384+
385+ uint8_t idx_buf [16 ];
386+ idx_buf [0 ] = (uint8_t )((w >> 0 ) & 0x07 );
387+ idx_buf [1 ] = (uint8_t )((w >> 3 ) & 0x07 );
388+ idx_buf [2 ] = (uint8_t )((w >> 6 ) & 0x07 );
389+ idx_buf [3 ] = (uint8_t )((w >> 9 ) & 0x07 );
390+ idx_buf [4 ] = (uint8_t )((w >> 12 ) & 0x07 );
391+ idx_buf [5 ] = (uint8_t )((w >> 15 ) & 0x07 );
392+ idx_buf [6 ] = (uint8_t )((w >> 18 ) & 0x07 );
393+ idx_buf [7 ] = (uint8_t )((w >> 21 ) & 0x07 );
394+ idx_buf [8 ] = (uint8_t )((w >> 24 ) & 0x07 );
395+ idx_buf [9 ] = (uint8_t )((w >> 27 ) & 0x07 );
396+ idx_buf [10 ] = (uint8_t )((w >> 30 ) & 0x07 );
397+ idx_buf [11 ] = (uint8_t )((w >> 33 ) & 0x07 );
398+ idx_buf [12 ] = (uint8_t )((w >> 36 ) & 0x07 );
399+ idx_buf [13 ] = (uint8_t )((w >> 39 ) & 0x07 );
400+ idx_buf [14 ] = (uint8_t )((w >> 42 ) & 0x07 );
401+ idx_buf [15 ] = (uint8_t )((w >> 45 ) & 0x07 );
402+ uint8x16_t indices = vld1q_u8 (idx_buf );
403+
404+ int8x16_t vals = vqtbl1q_s8 (cb_vec , indices );
405+
406+ int16x8_t i16_lo = vmovl_s8 (vget_low_s8 (vals ));
407+ int16x8_t i16_hi = vmovl_s8 (vget_high_s8 (vals ));
408+ float32x4_t f0 = vcvtq_f32_s32 (vmovl_s16 (vget_low_s16 (i16_lo )));
409+ float32x4_t f1 = vcvtq_f32_s32 (vmovl_s16 (vget_high_s16 (i16_lo )));
410+ float32x4_t f2 = vcvtq_f32_s32 (vmovl_s16 (vget_low_s16 (i16_hi )));
411+ float32x4_t f3 = vcvtq_f32_s32 (vmovl_s16 (vget_high_s16 (i16_hi )));
412+
413+ f0 = vmulq_f32 (f0 , scale_v );
414+ f1 = vmulq_f32 (f1 , scale_v );
415+ f2 = vmulq_f32 (f2 , scale_v );
416+ f3 = vmulq_f32 (f3 , scale_v );
417+
418+ acc0 = vfmaq_f32 (acc0 , vld1q_f32 (& q_rot [d + 0 ]), f0 );
419+ acc1 = vfmaq_f32 (acc1 , vld1q_f32 (& q_rot [d + 4 ]), f1 );
420+ acc2 = vfmaq_f32 (acc2 , vld1q_f32 (& q_rot [d + 8 ]), f2 );
421+ acc3 = vfmaq_f32 (acc3 , vld1q_f32 (& q_rot [d + 12 ]), f3 );
422+ }
423+ mse_dot = vaddvq_f32 (vaddq_f32 (vaddq_f32 (acc0 , acc1 ), vaddq_f32 (acc2 , acc3 )));
424+
425+ /* Tail */
426+ for (; d < dim ; d ++ ) {
427+ int bit_off = d * 3 ;
428+ int byte_idx = bit_off / 8 ;
429+ int bit_pos = bit_off % 8 ;
430+ uint16_t v = mi [byte_idx ];
431+ if (bit_pos > 5 ) v |= (uint16_t )mi [byte_idx + 1 ] << 8 ;
432+ int idx = (v >> bit_pos ) & 0x07 ;
433+ mse_dot += q_rot [d ] * (s_cb3_i8 [idx ] * per_block_scale );
367434 }
368435#else
436+ float lut [8 ];
437+ for (int j = 0 ; j < 8 ; j ++ ) lut [j ] = cb [j ] / inv_std ;
369438 for (int d = 0 ; d < dim ; d ++ ) {
370- mse_dot += q_rot [d ] * rotated [d ];
439+ int bit_off = d * 3 ;
440+ int byte_idx = bit_off / 8 ;
441+ int bit_pos = bit_off % 8 ;
442+ uint16_t v = mi [byte_idx ];
443+ if (bit_pos > 5 ) v |= (uint16_t )mi [byte_idx + 1 ] << 8 ;
444+ mse_dot += q_rot [d ] * lut [(v >> bit_pos ) & 0x07 ];
371445 }
372446#endif
373447
@@ -1212,36 +1286,149 @@ void tq_turbo_kv_5b_attention_ref(const float* query, const void* kv_cache,
12121286 for (int i = dim ; i < TQ_BK ; i ++ ) q_rot [i ] = 0.0f ;
12131287 tq_rht_transform (q_rot , dim , TKV_DEFAULT_SEED );
12141288
1289+ /* Round 11: NEON 32-entry table lookup via vqtbl2q_s8.
1290+ *
1291+ * 5-bit codebook has 32 entries which fit in 32 bytes (2 NEON registers).
1292+ * vqtbl2q_s8 takes a 2-register table and gathers 16 lanes in 1 instruction.
1293+ *
1294+ * The 5-bit packing is the harder part: 8 indices fit in 5 bytes (40 bits).
1295+ * For 16 indices we need 10 bytes. We scalar-unpack the 16 indices into a
1296+ * uint8x16_t and then SIMD-process the LUT lookup + dot product.
1297+ *
1298+ * Same int8 quantization of the codebook as Round 10 (~1% precision loss,
1299+ * within regression test thresholds).
1300+ */
1301+ const float * cb = tq_codebook_centroids (5 );
1302+ #ifdef __ARM_NEON
1303+ static int8_t s_cb5_i8 [32 ] = {0 };
1304+ static int s_cb5_i8_init = 0 ;
1305+ static const float CB5_I8_RECIP = 1.9956f / 127.0f ; /* 5-bit max centroid */
1306+ if (!s_cb5_i8_init ) {
1307+ for (int j = 0 ; j < 32 ; j ++ ) {
1308+ float v = cb [j ] * (127.0f / 1.9956f );
1309+ int q = (int )(v >= 0 ? v + 0.5f : v - 0.5f );
1310+ if (q < -127 ) q = -127 ;
1311+ if (q > 127 ) q = 127 ;
1312+ s_cb5_i8 [j ] = (int8_t )q ;
1313+ }
1314+ s_cb5_i8_init = 1 ;
1315+ }
1316+ int8x16x2_t cb_vec = { vld1q_s8 (s_cb5_i8 ), vld1q_s8 (s_cb5_i8 + 16 ) };
1317+ #endif
1318+
12151319 for (int seq = 0 ; seq < seq_len ; seq ++ ) {
12161320 const block_tq_turbo_kv_5b * block = & blocks_5b [seq ];
12171321 float norm = tkv_fp16_to_fp32 (block -> norm );
1322+ float inv_std = tkv_fp16_to_fp32 (block -> inv_std_fp16 );
1323+ if (inv_std < 1e-10f ) inv_std = sqrtf ((float )dim );
1324+ float per_block_scale = CB5_I8_RECIP / inv_std ;
12181325
1219- float rotated [TQ_BK ];
1220- dequant_mse_rotated_5bit (block , rotated , dim );
1221-
1326+ const uint8_t * mi = block -> mse_indices ;
12221327 float mse_dot = 0.0f ;
1328+
12231329#ifdef __ARM_NEON
1224- {
1225- float32x4_t acc0 = vdupq_n_f32 (0.0f );
1226- float32x4_t acc1 = vdupq_n_f32 (0.0f );
1227- float32x4_t acc2 = vdupq_n_f32 (0.0f );
1228- float32x4_t acc3 = vdupq_n_f32 (0.0f );
1229- int d = 0 ;
1230- for (; d + 15 < dim ; d += 16 ) {
1231- acc0 = vfmaq_f32 (acc0 , vld1q_f32 (& q_rot [d ]), vld1q_f32 (& rotated [d ]));
1232- acc1 = vfmaq_f32 (acc1 , vld1q_f32 (& q_rot [d + 4 ]), vld1q_f32 (& rotated [d + 4 ]));
1233- acc2 = vfmaq_f32 (acc2 , vld1q_f32 (& q_rot [d + 8 ]), vld1q_f32 (& rotated [d + 8 ]));
1234- acc3 = vfmaq_f32 (acc3 , vld1q_f32 (& q_rot [d + 12 ]), vld1q_f32 (& rotated [d + 12 ]));
1235- }
1236- acc0 = vaddq_f32 (vaddq_f32 (acc0 , acc1 ), vaddq_f32 (acc2 , acc3 ));
1237- for (; d + 3 < dim ; d += 4 ) {
1238- acc0 = vfmaq_f32 (acc0 , vld1q_f32 (& q_rot [d ]), vld1q_f32 (& rotated [d ]));
1239- }
1240- mse_dot = vaddvq_f32 (acc0 );
1241- for (; d < dim ; d ++ ) mse_dot += q_rot [d ] * rotated [d ];
1330+ float32x4_t acc0 = vdupq_n_f32 (0.0f );
1331+ float32x4_t acc1 = vdupq_n_f32 (0.0f );
1332+ float32x4_t acc2 = vdupq_n_f32 (0.0f );
1333+ float32x4_t acc3 = vdupq_n_f32 (0.0f );
1334+ float32x4_t scale_v = vdupq_n_f32 (per_block_scale );
1335+
1336+ int d = 0 ;
1337+ /* Process 16 elements per iteration: 10 bytes of mse_indices.
1338+ * Use uint64 reads + shift to extract indices fast (2 reads × 8 shifts
1339+ * vs 16 scalar bit-position computations). On ARM64, unaligned uint64
1340+ * reads are fast. */
1341+ for (; d + 15 < dim ; d += 16 ) {
1342+ /* First 8 indices from bytes [d*5/8 .. d*5/8 + 5] (40 bits) */
1343+ const uint8_t * p0 = mi + (d * 5 ) / 8 ;
1344+ uint64_t w0 ;
1345+ memcpy (& w0 , p0 , 8 ); /* unaligned 8-byte load (we use 5 bytes) */
1346+ /* Second 8 indices: 40 bits later = 5 bytes after p0 */
1347+ const uint8_t * p1 = p0 + 5 ;
1348+ uint64_t w1 ;
1349+ memcpy (& w1 , p1 , 8 );
1350+
1351+ uint8_t idx_buf [16 ];
1352+ idx_buf [0 ] = (uint8_t )((w0 >> 0 ) & 0x1F );
1353+ idx_buf [1 ] = (uint8_t )((w0 >> 5 ) & 0x1F );
1354+ idx_buf [2 ] = (uint8_t )((w0 >> 10 ) & 0x1F );
1355+ idx_buf [3 ] = (uint8_t )((w0 >> 15 ) & 0x1F );
1356+ idx_buf [4 ] = (uint8_t )((w0 >> 20 ) & 0x1F );
1357+ idx_buf [5 ] = (uint8_t )((w0 >> 25 ) & 0x1F );
1358+ idx_buf [6 ] = (uint8_t )((w0 >> 30 ) & 0x1F );
1359+ idx_buf [7 ] = (uint8_t )((w0 >> 35 ) & 0x1F );
1360+ idx_buf [8 ] = (uint8_t )((w1 >> 0 ) & 0x1F );
1361+ idx_buf [9 ] = (uint8_t )((w1 >> 5 ) & 0x1F );
1362+ idx_buf [10 ] = (uint8_t )((w1 >> 10 ) & 0x1F );
1363+ idx_buf [11 ] = (uint8_t )((w1 >> 15 ) & 0x1F );
1364+ idx_buf [12 ] = (uint8_t )((w1 >> 20 ) & 0x1F );
1365+ idx_buf [13 ] = (uint8_t )((w1 >> 25 ) & 0x1F );
1366+ idx_buf [14 ] = (uint8_t )((w1 >> 30 ) & 0x1F );
1367+ idx_buf [15 ] = (uint8_t )((w1 >> 35 ) & 0x1F );
1368+ uint8x16_t indices = vld1q_u8 (idx_buf );
1369+
1370+ /* SIMD lookup: 32-entry table via 2-register vqtbl2q_s8 */
1371+ int8x16_t vals = vqtbl2q_s8 (cb_vec , indices );
1372+
1373+ /* int8 → int16 → fp32 (16 lanes split into 4×4) */
1374+ int16x8_t i16_lo = vmovl_s8 (vget_low_s8 (vals ));
1375+ int16x8_t i16_hi = vmovl_s8 (vget_high_s8 (vals ));
1376+ float32x4_t f0 = vcvtq_f32_s32 (vmovl_s16 (vget_low_s16 (i16_lo )));
1377+ float32x4_t f1 = vcvtq_f32_s32 (vmovl_s16 (vget_high_s16 (i16_lo )));
1378+ float32x4_t f2 = vcvtq_f32_s32 (vmovl_s16 (vget_low_s16 (i16_hi )));
1379+ float32x4_t f3 = vcvtq_f32_s32 (vmovl_s16 (vget_high_s16 (i16_hi )));
1380+
1381+ /* Apply per-block scale */
1382+ f0 = vmulq_f32 (f0 , scale_v );
1383+ f1 = vmulq_f32 (f1 , scale_v );
1384+ f2 = vmulq_f32 (f2 , scale_v );
1385+ f3 = vmulq_f32 (f3 , scale_v );
1386+
1387+ /* FMA against query */
1388+ acc0 = vfmaq_f32 (acc0 , vld1q_f32 (& q_rot [d + 0 ]), f0 );
1389+ acc1 = vfmaq_f32 (acc1 , vld1q_f32 (& q_rot [d + 4 ]), f1 );
1390+ acc2 = vfmaq_f32 (acc2 , vld1q_f32 (& q_rot [d + 8 ]), f2 );
1391+ acc3 = vfmaq_f32 (acc3 , vld1q_f32 (& q_rot [d + 12 ]), f3 );
1392+ }
1393+ mse_dot = vaddvq_f32 (vaddq_f32 (vaddq_f32 (acc0 , acc1 ), vaddq_f32 (acc2 , acc3 )));
1394+
1395+ /* Tail: scalar fallback for remaining elements */
1396+ for (; d < dim ; d ++ ) {
1397+ int bit_off = d * 5 ;
1398+ int byte_idx = bit_off / 8 ;
1399+ int bit_pos = bit_off % 8 ;
1400+ uint16_t v = mi [byte_idx ];
1401+ if (bit_pos > 3 ) v |= (uint16_t )mi [byte_idx + 1 ] << 8 ;
1402+ int idx = (v >> bit_pos ) & 0x1F ;
1403+ mse_dot += q_rot [d ] * (s_cb5_i8 [idx ] * per_block_scale );
12421404 }
12431405#else
1244- for (int d = 0 ; d < dim ; d ++ ) mse_dot += q_rot [d ] * rotated [d ];
1406+ /* Scalar fallback */
1407+ float lut [32 ];
1408+ for (int j = 0 ; j < 32 ; j ++ ) lut [j ] = cb [j ] / inv_std ;
1409+ const uint8_t * p = mi ;
1410+ int d = 0 ;
1411+ for (; d + 7 < dim ; d += 8 ) {
1412+ uint64_t w = (uint64_t )p [0 ] | ((uint64_t )p [1 ] << 8 ) | ((uint64_t )p [2 ] << 16 )
1413+ | ((uint64_t )p [3 ] << 24 ) | ((uint64_t )p [4 ] << 32 );
1414+ mse_dot += q_rot [d + 0 ] * lut [(w >> 0 ) & 31 ];
1415+ mse_dot += q_rot [d + 1 ] * lut [(w >> 5 ) & 31 ];
1416+ mse_dot += q_rot [d + 2 ] * lut [(w >> 10 ) & 31 ];
1417+ mse_dot += q_rot [d + 3 ] * lut [(w >> 15 ) & 31 ];
1418+ mse_dot += q_rot [d + 4 ] * lut [(w >> 20 ) & 31 ];
1419+ mse_dot += q_rot [d + 5 ] * lut [(w >> 25 ) & 31 ];
1420+ mse_dot += q_rot [d + 6 ] * lut [(w >> 30 ) & 31 ];
1421+ mse_dot += q_rot [d + 7 ] * lut [(w >> 35 ) & 31 ];
1422+ p += 5 ;
1423+ }
1424+ for (; d < dim ; d ++ ) {
1425+ int bit_off = d * 5 ;
1426+ int byte_idx = bit_off / 8 ;
1427+ int bit_pos = bit_off % 8 ;
1428+ uint16_t v = mi [byte_idx ];
1429+ if (bit_pos > 3 ) v |= (uint16_t )mi [byte_idx + 1 ] << 8 ;
1430+ mse_dot += q_rot [d ] * lut [(v >> bit_pos ) & 0x1F ];
1431+ }
12451432#endif
12461433 scores [seq ] = norm * mse_dot ;
12471434 }