
Commit 2dcbde4

unamedkr and claude committed
v0.8.0 [1/2]: AVX2 port of turbo_kv attention kernels
Mirrors the Round 10/11 NEON tbl breakthrough on x86 AVX2 for all four
turbo_kv attention variants:

  4b      : _mm_shuffle_epi8(cb_xmm, low_nib/high_nib) — 16-entry table
  5b      : split-table PSHUFB + BLENDV (32-entry via 2× 16-entry)
  5b_fast : direct 1-byte index loads (cleanest path; same split-table)
  3b      : single PSHUFB (8-entry codebook fits in low half)

Pattern: PSHUFB ↔ vqtbl1q_s8; the BLENDV bit-trick handles 32-entry
codebooks that don't fit in a single 16-byte register. Layout matches the
NEON register cadence (16 elements/iter for 5b/3b/5b_fast, 32 for 4b).

Build: NEON unaffected (35/35 tests pass). The AVX2 path is guarded by
`#elif defined(__AVX2__)`, so the existing CMake -mavx2 flag automatically
activates it on x86 CI.

Tests: added KV_5B_FAST_AttentionCosine regression (cos > 0.999) — coverage
was previously missing. Existing 3b/4b/5b cosine tests will exercise AVX2 on
Linux x86 CI runners and catch any numerical drift.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
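For readers new to the trick, here is a minimal, self-contained sketch (not part of the commit) of the two lookup primitives the message names. The codebook values and index vector are made-up placeholders; build with -mavx2 (BLENDV needs SSE4.1, which AVX2 implies).

    /* Standalone illustration (not from the commit) of the PSHUFB table
     * lookup and the split-table + BLENDV trick for 32-entry codebooks.
     * cb32 and idx are placeholder data. Build: cc -mavx2 demo.c */
    #include <immintrin.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        int8_t cb32[32];                             /* placeholder codebook */
        for (int j = 0; j < 32; j++) cb32[j] = (int8_t)(4 * j - 64);

        __m128i cb_lo = _mm_loadu_si128((const __m128i*)(cb32 + 0));  /* entries 0..15  */
        __m128i cb_hi = _mm_loadu_si128((const __m128i*)(cb32 + 16)); /* entries 16..31 */

        uint8_t idx[16] = {0, 5, 17, 31, 8, 20, 3, 30, 1, 16, 15, 29, 7, 9, 22, 11};
        __m128i v_idx = _mm_loadu_si128((const __m128i*)idx);

        /* 16-entry lookup = one PSHUFB, the exact analogue of NEON vqtbl1q_s8. */
        __m128i lo4     = _mm_and_si128(v_idx, _mm_set1_epi8(0x0F));
        __m128i from_lo = _mm_shuffle_epi8(cb_lo, lo4);
        __m128i from_hi = _mm_shuffle_epi8(cb_hi, lo4);

        /* BLENDV selects per byte on bit 7, so move index bit 4 up to bit 7. */
        __m128i sel  = _mm_and_si128(_mm_slli_epi16(v_idx, 3), _mm_set1_epi8((char)0x80));
        __m128i vals = _mm_blendv_epi8(from_lo, from_hi, sel);

        int8_t out[16];
        _mm_storeu_si128((__m128i*)out, vals);
        for (int i = 0; i < 16; i++)                 /* cross-check vs. scalar */
            if (out[i] != cb32[idx[i]]) { printf("mismatch at %d\n", i); return 1; }
        printf("split-table lookup matches scalar indexing\n");
        return 0;
    }

The 4b kernel needs only the first PSHUFB (its indices are already in 0..15); the 5b and 5b_fast kernels add the second table and the blend.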
1 parent 4d8952c · commit 2dcbde4

2 files changed

Lines changed: 322 additions & 0 deletions


src/core/tq_turbo_kv.c

Lines changed: 289 additions & 0 deletions
@@ -24,6 +24,10 @@
 #include <arm_neon.h>
 #endif

+#if defined(__AVX2__)
+#include <immintrin.h>
+#endif
+
 /* Forward declarations from other modules */
 extern void tq_codebook_quantize(const float* src, uint8_t* dst_indices,
                                  int n, int bits, float inv_std);
@@ -356,6 +360,23 @@ void tq_turbo_kv_3b_attention_ref(const float* query, const void* kv_cache,
         s_cb3_i8_init = 1;
     }
     int8x16_t cb_vec = vld1q_s8(s_cb3_i8);
+#elif defined(__AVX2__)
+    /* 8-entry codebook fits in lower 8 bytes; PSHUFB only uses low 4 bits of
+     * the index, and our 3-bit indices are guaranteed to be in [0..7]. */
+    static int8_t s_cb3_i8[16] = {0};
+    static int s_cb3_i8_init = 0;
+    if (!s_cb3_i8_init) {
+        for (int j = 0; j < 8; j++) {
+            float v = cb[j] * (127.0f / 2.1520f);
+            int q = (int)(v >= 0 ? v + 0.5f : v - 0.5f);
+            if (q < -127) q = -127;
+            if (q > 127) q = 127;
+            s_cb3_i8[j] = (int8_t)q;
+        }
+        for (int j = 8; j < 16; j++) s_cb3_i8[j] = 0;
+        s_cb3_i8_init = 1;
+    }
+    const __m128i cb3_xmm = _mm_loadu_si128((const __m128i*)s_cb3_i8);
 #endif

     for (int seq = 0; seq < seq_len; seq++) {
@@ -433,6 +454,63 @@
             int idx = (v >> bit_pos) & 0x07;
             mse_dot += q_rot[d] * (s_cb3_i8[idx] * per_block_scale);
         }
+#elif defined(__AVX2__)
+        __m256 acc0 = _mm256_setzero_ps();
+        __m256 acc1 = _mm256_setzero_ps();
+        const __m256 scale_v = _mm256_set1_ps(per_block_scale);
+
+        int d = 0;
+        for (; d + 15 < dim; d += 16) {
+            const uint8_t* p = mi + (d * 3) / 8;
+            uint64_t w; memcpy(&w, p, 8);
+
+            uint8_t idx_buf[16];
+            idx_buf[0]  = (uint8_t)((w >>  0) & 0x07);
+            idx_buf[1]  = (uint8_t)((w >>  3) & 0x07);
+            idx_buf[2]  = (uint8_t)((w >>  6) & 0x07);
+            idx_buf[3]  = (uint8_t)((w >>  9) & 0x07);
+            idx_buf[4]  = (uint8_t)((w >> 12) & 0x07);
+            idx_buf[5]  = (uint8_t)((w >> 15) & 0x07);
+            idx_buf[6]  = (uint8_t)((w >> 18) & 0x07);
+            idx_buf[7]  = (uint8_t)((w >> 21) & 0x07);
+            idx_buf[8]  = (uint8_t)((w >> 24) & 0x07);
+            idx_buf[9]  = (uint8_t)((w >> 27) & 0x07);
+            idx_buf[10] = (uint8_t)((w >> 30) & 0x07);
+            idx_buf[11] = (uint8_t)((w >> 33) & 0x07);
+            idx_buf[12] = (uint8_t)((w >> 36) & 0x07);
+            idx_buf[13] = (uint8_t)((w >> 39) & 0x07);
+            idx_buf[14] = (uint8_t)((w >> 42) & 0x07);
+            idx_buf[15] = (uint8_t)((w >> 45) & 0x07);
+
+            __m128i indices = _mm_loadu_si128((const __m128i*)idx_buf);
+            __m128i vals = _mm_shuffle_epi8(cb3_xmm, indices);
+
+            __m256i i32_lo = _mm256_cvtepi8_epi32(vals);
+            __m256i i32_hi = _mm256_cvtepi8_epi32(_mm_srli_si128(vals, 8));
+            __m256 f0 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_lo), scale_v);
+            __m256 f1 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_hi), scale_v);
+
+            acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 0]), f0, acc0);
+            acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 8]), f1, acc1);
+        }
+        {
+            __m256 sum = _mm256_add_ps(acc0, acc1);
+            __m128 lo = _mm256_castps256_ps128(sum);
+            __m128 hi = _mm256_extractf128_ps(sum, 1);
+            __m128 s = _mm_add_ps(lo, hi);
+            s = _mm_hadd_ps(s, s);
+            s = _mm_hadd_ps(s, s);
+            mse_dot = _mm_cvtss_f32(s);
+        }
+        for (; d < dim; d++) {
+            int bit_off = d * 3;
+            int byte_idx = bit_off / 8;
+            int bit_pos = bit_off % 8;
+            uint16_t v = mi[byte_idx];
+            if (bit_pos > 5) v |= (uint16_t)mi[byte_idx + 1] << 8;
+            int idx = (v >> bit_pos) & 0x07;
+            mse_dot += q_rot[d] * (s_cb3_i8[idx] * per_block_scale);
+        }
 #else
         float lut[8];
         for (int j = 0; j < 8; j++) lut[j] = cb[j] / inv_std;
@@ -577,6 +655,26 @@ void tq_turbo_kv_4b_attention_ref(const float* query, const void* kv_cache,
         s_cb_i8_init = 1;
     }
     int8x16_t cb_vec = vld1q_s8(s_cb_i8);
+#elif defined(__AVX2__)
+    /* x86 AVX2 mirror of the NEON tbl pattern.
+     * _mm_shuffle_epi8 implements a 16-entry int8 table lookup in 1 instruction
+     * (PSHUFB), exactly matching vqtbl1q_s8. Round 10's NEON breakthrough ports
+     * to AVX2 1:1, since the 16-entry codebook fits a 128-bit register on both
+     * ISAs. */
+    static int8_t s_cb_i8[16] = {0};
+    static int s_cb_i8_init = 0;
+    if (!s_cb_i8_init) {
+        for (int j = 0; j < 16; j++) {
+            float v = cb[j] * (127.0f / 2.7326f);
+            int q = (int)(v >= 0 ? v + 0.5f : v - 0.5f);
+            if (q < -127) q = -127;
+            if (q > 127) q = 127;
+            s_cb_i8[j] = (int8_t)q;
+        }
+        s_cb_i8_init = 1;
+    }
+    const __m128i cb_xmm = _mm_loadu_si128((const __m128i*)s_cb_i8);
+    const __m128i mask0F = _mm_set1_epi8(0x0F);
 #endif

     for (int seq = 0; seq < seq_len; seq++) {
@@ -656,6 +754,55 @@
             int idx = (d & 1) ? (bv >> 4) : (bv & 0x0F);
             mse_dot += q_rot[d] * (s_cb_i8[idx] * per_block_scale);
         }
+#elif defined(__AVX2__)
+        /* AVX2 path: 32 elements per iter, mirroring the NEON layout. */
+        __m256 acc0 = _mm256_setzero_ps();
+        __m256 acc1 = _mm256_setzero_ps();
+        __m256 acc2 = _mm256_setzero_ps();
+        __m256 acc3 = _mm256_setzero_ps();
+        const __m256 scale_v = _mm256_set1_ps(per_block_scale);
+
+        int d = 0;
+        for (; d + 31 < dim; d += 32) {
+            __m128i bytes = _mm_loadu_si128((const __m128i*)(mi + d / 2));
+            __m128i low_nib = _mm_and_si128(bytes, mask0F);
+            __m128i high_nib = _mm_and_si128(_mm_srli_epi16(bytes, 4), mask0F);
+            __m128i low_vals = _mm_shuffle_epi8(cb_xmm, low_nib);
+            __m128i high_vals = _mm_shuffle_epi8(cb_xmm, high_nib);
+
+            /* Interleave: result[2i]=low[i], result[2i+1]=high[i] */
+            __m128i inter_lo = _mm_unpacklo_epi8(low_vals, high_vals); /* elems 0..15 */
+            __m128i inter_hi = _mm_unpackhi_epi8(low_vals, high_vals); /* elems 16..31 */
+
+            __m256i i32_0 = _mm256_cvtepi8_epi32(inter_lo);
+            __m256i i32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(inter_lo, 8));
+            __m256i i32_2 = _mm256_cvtepi8_epi32(inter_hi);
+            __m256i i32_3 = _mm256_cvtepi8_epi32(_mm_srli_si128(inter_hi, 8));
+
+            __m256 f0 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_0), scale_v);
+            __m256 f1 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_1), scale_v);
+            __m256 f2 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_2), scale_v);
+            __m256 f3 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_3), scale_v);
+
+            acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 0]), f0, acc0);
+            acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 8]), f1, acc1);
+            acc2 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 16]), f2, acc2);
+            acc3 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 24]), f3, acc3);
+        }
+        {
+            __m256 sum = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
+            __m128 lo = _mm256_castps256_ps128(sum);
+            __m128 hi = _mm256_extractf128_ps(sum, 1);
+            __m128 s = _mm_add_ps(lo, hi);
+            s = _mm_hadd_ps(s, s);
+            s = _mm_hadd_ps(s, s);
+            mse_dot = _mm_cvtss_f32(s);
+        }
+        for (; d < dim; d++) {
+            uint8_t bv = mi[d / 2];
+            int idx = (d & 1) ? (bv >> 4) : (bv & 0x0F);
+            mse_dot += q_rot[d] * (s_cb_i8[idx] * per_block_scale);
+        }
 #else
         /* Scalar fallback */
         float lut[16];
@@ -1317,6 +1464,30 @@ void tq_turbo_kv_5b_attention_ref(const float* query, const void* kv_cache,
         s_cb5_i8_init = 1;
     }
     int8x16x2_t cb_vec = { vld1q_s8(s_cb5_i8), vld1q_s8(s_cb5_i8 + 16) };
+#elif defined(__AVX2__)
+    /* AVX2 mirror of NEON vqtbl2q_s8 (32-entry table lookup).
+     *
+     * AVX2's PSHUFB is per-lane 16-entry only. We split the 32-entry codebook
+     * into cb_lo (entries 0..15) and cb_hi (entries 16..31), do two PSHUFBs
+     * with indices & 0x0F, then BLENDV based on the original bit 4 of each
+     * index (1 → use cb_hi). Cost: ~5 ops vs NEON's 1, still SIMD over scalar.
+     */
+    static int8_t s_cb5_i8[32] = {0};
+    static int s_cb5_i8_init = 0;
+    if (!s_cb5_i8_init) {
+        for (int j = 0; j < 32; j++) {
+            float v = cb[j] * (127.0f / 1.9956f);
+            int q = (int)(v >= 0 ? v + 0.5f : v - 0.5f);
+            if (q < -127) q = -127;
+            if (q > 127) q = 127;
+            s_cb5_i8[j] = (int8_t)q;
+        }
+        s_cb5_i8_init = 1;
+    }
+    const __m128i cb5_lo_xmm = _mm_loadu_si128((const __m128i*)(s_cb5_i8 + 0));
+    const __m128i cb5_hi_xmm = _mm_loadu_si128((const __m128i*)(s_cb5_i8 + 16));
+    const __m128i mask0F_x = _mm_set1_epi8(0x0F);
+    const __m128i mask80_x = _mm_set1_epi8((char)0x80);
 #endif

     for (int seq = 0; seq < seq_len; seq++) {
@@ -1405,6 +1576,71 @@
             int idx = (v >> bit_pos) & 0x1F;
             mse_dot += q_rot[d] * (s_cb5_i8[idx] * per_block_scale);
         }
+#elif defined(__AVX2__)
+        __m256 acc0 = _mm256_setzero_ps();
+        __m256 acc1 = _mm256_setzero_ps();
+        const __m256 scale_v = _mm256_set1_ps(per_block_scale);
+
+        int d = 0;
+        for (; d + 15 < dim; d += 16) {
+            /* 5-bit unpack: 16 indices = 10 bytes (= two 5-byte groups) */
+            const uint8_t* p0 = mi + (d * 5) / 8;
+            uint64_t w0; memcpy(&w0, p0, 8);
+            const uint8_t* p1 = p0 + 5;
+            uint64_t w1; memcpy(&w1, p1, 8);
+
+            uint8_t idx_buf[16];
+            idx_buf[0]  = (uint8_t)((w0 >>  0) & 0x1F);
+            idx_buf[1]  = (uint8_t)((w0 >>  5) & 0x1F);
+            idx_buf[2]  = (uint8_t)((w0 >> 10) & 0x1F);
+            idx_buf[3]  = (uint8_t)((w0 >> 15) & 0x1F);
+            idx_buf[4]  = (uint8_t)((w0 >> 20) & 0x1F);
+            idx_buf[5]  = (uint8_t)((w0 >> 25) & 0x1F);
+            idx_buf[6]  = (uint8_t)((w0 >> 30) & 0x1F);
+            idx_buf[7]  = (uint8_t)((w0 >> 35) & 0x1F);
+            idx_buf[8]  = (uint8_t)((w1 >>  0) & 0x1F);
+            idx_buf[9]  = (uint8_t)((w1 >>  5) & 0x1F);
+            idx_buf[10] = (uint8_t)((w1 >> 10) & 0x1F);
+            idx_buf[11] = (uint8_t)((w1 >> 15) & 0x1F);
+            idx_buf[12] = (uint8_t)((w1 >> 20) & 0x1F);
+            idx_buf[13] = (uint8_t)((w1 >> 25) & 0x1F);
+            idx_buf[14] = (uint8_t)((w1 >> 30) & 0x1F);
+            idx_buf[15] = (uint8_t)((w1 >> 35) & 0x1F);
+
+            __m128i indices = _mm_loadu_si128((const __m128i*)idx_buf);
+            __m128i lo_idx = _mm_and_si128(indices, mask0F_x);
+            __m128i lo_vals = _mm_shuffle_epi8(cb5_lo_xmm, lo_idx);
+            __m128i hi_vals = _mm_shuffle_epi8(cb5_hi_xmm, lo_idx);
+            /* Bit 4 of original index → bit 7 (sign) for blendv selector */
+            __m128i sel_mask = _mm_and_si128(_mm_slli_epi16(indices, 3), mask80_x);
+            __m128i vals = _mm_blendv_epi8(lo_vals, hi_vals, sel_mask);
+
+            __m256i i32_lo = _mm256_cvtepi8_epi32(vals);
+            __m256i i32_hi = _mm256_cvtepi8_epi32(_mm_srli_si128(vals, 8));
+            __m256 f0 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_lo), scale_v);
+            __m256 f1 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_hi), scale_v);
+
+            acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 0]), f0, acc0);
+            acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 8]), f1, acc1);
+        }
+        {
+            __m256 sum = _mm256_add_ps(acc0, acc1);
+            __m128 lo = _mm256_castps256_ps128(sum);
+            __m128 hi = _mm256_extractf128_ps(sum, 1);
+            __m128 s = _mm_add_ps(lo, hi);
+            s = _mm_hadd_ps(s, s);
+            s = _mm_hadd_ps(s, s);
+            mse_dot = _mm_cvtss_f32(s);
+        }
+        for (; d < dim; d++) {
+            int bit_off = d * 5;
+            int byte_idx = bit_off / 8;
+            int bit_pos = bit_off % 8;
+            uint16_t v = mi[byte_idx];
+            if (bit_pos > 3) v |= (uint16_t)mi[byte_idx + 1] << 8;
+            int idx = (v >> bit_pos) & 0x1F;
+            mse_dot += q_rot[d] * (s_cb5_i8[idx] * per_block_scale);
+        }
 #else
         /* Scalar fallback */
         float lut[32];
@@ -1883,6 +2119,23 @@ void tq_turbo_kv_5b_fast_attention_ref(const float* query, const void* kv_cache,
         s_cb5fast_init = 1;
     }
     int8x16x2_t cb_vec = { vld1q_s8(s_cb5fast_i8), vld1q_s8(s_cb5fast_i8 + 16) };
+#elif defined(__AVX2__)
+    static int8_t s_cb5fast_i8[32] = {0};
+    static int s_cb5fast_init = 0;
+    if (!s_cb5fast_init) {
+        for (int j = 0; j < 32; j++) {
+            float v = cb[j] * (127.0f / 1.9956f);
+            int q = (int)(v >= 0 ? v + 0.5f : v - 0.5f);
+            if (q < -127) q = -127;
+            if (q > 127) q = 127;
+            s_cb5fast_i8[j] = (int8_t)q;
+        }
+        s_cb5fast_init = 1;
+    }
+    const __m128i cb5f_lo_xmm = _mm_loadu_si128((const __m128i*)(s_cb5fast_i8 + 0));
+    const __m128i cb5f_hi_xmm = _mm_loadu_si128((const __m128i*)(s_cb5fast_i8 + 16));
+    const __m128i mask0F_xf = _mm_set1_epi8(0x0F);
+    const __m128i mask80_xf = _mm_set1_epi8((char)0x80);
 #endif

     for (int seq = 0; seq < seq_len; seq++) {
@@ -1928,6 +2181,42 @@
         }
         mse_dot = vaddvq_f32(vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3)));

+        for (; d < dim; d++) {
+            mse_dot += q_rot[d] * (s_cb5fast_i8[mi[d]] * per_block_scale);
+        }
+#elif defined(__AVX2__)
+        /* Direct 1-byte-per-index loads — no scalar unpack. The cleanest
+         * AVX2 path of all turbo_kv variants thanks to byte alignment. */
+        __m256 acc0 = _mm256_setzero_ps();
+        __m256 acc1 = _mm256_setzero_ps();
+        const __m256 scale_v = _mm256_set1_ps(per_block_scale);
+
+        int d = 0;
+        for (; d + 15 < dim; d += 16) {
+            __m128i indices = _mm_loadu_si128((const __m128i*)(mi + d));
+            __m128i lo_idx = _mm_and_si128(indices, mask0F_xf);
+            __m128i lo_vals = _mm_shuffle_epi8(cb5f_lo_xmm, lo_idx);
+            __m128i hi_vals = _mm_shuffle_epi8(cb5f_hi_xmm, lo_idx);
+            __m128i sel_mask = _mm_and_si128(_mm_slli_epi16(indices, 3), mask80_xf);
+            __m128i vals = _mm_blendv_epi8(lo_vals, hi_vals, sel_mask);
+
+            __m256i i32_lo = _mm256_cvtepi8_epi32(vals);
+            __m256i i32_hi = _mm256_cvtepi8_epi32(_mm_srli_si128(vals, 8));
+            __m256 f0 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_lo), scale_v);
+            __m256 f1 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_hi), scale_v);
+
+            acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 0]), f0, acc0);
+            acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 8]), f1, acc1);
+        }
+        {
+            __m256 sum = _mm256_add_ps(acc0, acc1);
+            __m128 lo = _mm256_castps256_ps128(sum);
+            __m128 hi = _mm256_extractf128_ps(sum, 1);
+            __m128 s = _mm_add_ps(lo, hi);
+            s = _mm_hadd_ps(s, s);
+            s = _mm_hadd_ps(s, s);
+            mse_dot = _mm_cvtss_f32(s);
+        }
         for (; d < dim; d++) {
             mse_dot += q_rot[d] * (s_cb5fast_i8[mi[d]] * per_block_scale);
         }
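Every AVX2 block above ends with the same reduction epilogue. Factored out as a standalone helper it reads as below; hsum256_ps is a hypothetical name, the commit inlines the sequence instead:

    /* Hypothetical helper (the commit inlines this sequence) showing how the
     * epilogue collapses an 8-lane float accumulator to one scalar. */
    #include <immintrin.h>

    static inline float hsum256_ps(__m256 v) {
        __m128 lo = _mm256_castps256_ps128(v);    /* lanes 0..3 */
        __m128 hi = _mm256_extractf128_ps(v, 1);  /* lanes 4..7 */
        __m128 s  = _mm_add_ps(lo, hi);           /* fold to 4 lanes */
        s = _mm_hadd_ps(s, s);                    /* 4 -> 2 */
        s = _mm_hadd_ps(s, s);                    /* 2 -> 1 */
        return _mm_cvtss_f32(s);
    }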

tests/test_turbo_kv.cpp

Lines changed: 33 additions & 0 deletions
@@ -488,6 +488,9 @@ TEST(TurboKV, ZeroInput) {

 extern "C" {
 void tq_turbo_kv_5b_quantize_ref(const float* src, void* dst, int n);
+void tq_turbo_kv_5b_fast_quantize_ref(const float* src, void* dst, int n);
+void tq_turbo_kv_5b_fast_attention_ref(const float* query, const void* kv,
+                                       float* scores, int seq_len, int head_dim);
 void tq_turbo_kv_5b_attention_ref(const float* query, const void* kv,
                                   float* scores, int seq_len, int head_dim);
 }
@@ -626,3 +629,33 @@ TEST(TurboKVRegression, KV_5B_BeatsKV_4B) {
         << "5-bit must be at least as accurate as 4-bit (5b=" << cos5
         << ", 4b=" << cos4 << ")";
 }
+
+TEST(TurboKVRegression, KV_5B_FAST_AttentionCosine) {
+    /* turbo_kv_5b_fast uses the same 32-entry codebook as 5b but with a
+     * 1-byte-per-index layout (no scalar bit unpack). Quality must match
+     * the 5b path bit-for-bit on the codebook lookup; only the layout differs. */
+    const int dim = TQ_BK;
+    const int n_keys = 256;
+
+    std::vector<std::vector<float>> keys;
+    synth_keys(keys, n_keys, dim, /*seed=*/0xC0FFEE);
+    std::vector<float> q;
+    synth_query(q, dim, /*seed=*/0xBADC0DE);
+
+    std::vector<float> ref_scores;
+    fp32_attention(q, keys, ref_scores);
+
+    std::vector<block_tq_turbo_kv_5b_fast> blocks(n_keys);
+    for (int s = 0; s < n_keys; s++) {
+        memset(&blocks[s], 0, sizeof(blocks[s]));
+        tq_turbo_kv_5b_fast_quantize_ref(keys[s].data(), &blocks[s], dim);
+    }
+
+    std::vector<float> est_scores(n_keys);
+    tq_turbo_kv_5b_fast_attention_ref(q.data(), blocks.data(), est_scores.data(),
+                                      n_keys, dim);
+
+    double cos = compute_cosine(ref_scores.data(), est_scores.data(), n_keys);
+    EXPECT_GT(cos, 0.999)
+        << "turbo_kv_5b_fast attention cosine regressed below 0.999";
+}
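The 0.999 gate is plain cosine similarity between the fp32 reference scores and the quantized-path scores. The repo's compute_cosine is not shown in this diff; assuming it follows the standard definition, it amounts to:

    /* Assumed shape of compute_cosine (not shown in this diff): standard
     * cosine similarity, accumulated in double for stability. */
    #include <math.h>

    static double cosine_sim(const float* a, const float* b, int n) {
        double dot = 0.0, na = 0.0, nb = 0.0;
        for (int i = 0; i < n; i++) {
            dot += (double)a[i] * b[i];
            na  += (double)a[i] * a[i];
            nb  += (double)b[i] * b[i];
        }
        return dot / (sqrt(na) * sqrt(nb) + 1e-30);  /* guard zero norms */
    }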

0 commit comments
