24 | 24 | #include <arm_neon.h> |
25 | 25 | #endif |
26 | 26 |
| 27 | +#if defined(__AVX2__) |
| 28 | +#include <immintrin.h> |
| 29 | +#endif |
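| | +/* Note: the AVX2 paths below also use FMA intrinsics (_mm256_fmadd_ps); on GCC/Clang
| | + * they need -mfma (or an -march that implies it, e.g. -march=haswell) in addition to -mavx2. */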
| 30 | + |
27 | 31 | /* Forward declarations from other modules */ |
28 | 32 | extern void tq_codebook_quantize(const float* src, uint8_t* dst_indices, |
29 | 33 | int n, int bits, float inv_std); |
@@ -356,6 +360,23 @@ void tq_turbo_kv_3b_attention_ref(const float* query, const void* kv_cache, |
356 | 360 | s_cb3_i8_init = 1; |
357 | 361 | } |
358 | 362 | int8x16_t cb_vec = vld1q_s8(s_cb3_i8); |
| 363 | +#elif defined(__AVX2__) |
| 364 | +    /* The 8-entry codebook sits in the low 8 bytes (entries 8..15 stay zero); PSHUFB
| 365 | +     * indexes with the low 4 bits of each byte, and our 3-bit indices are always in [0..7]. */
| 366 | + static int8_t s_cb3_i8[16] = {0}; |
| 367 | + static int s_cb3_i8_init = 0; |
| 368 | + if (!s_cb3_i8_init) { |
| 369 | + for (int j = 0; j < 8; j++) { |
| 370 | + float v = cb[j] * (127.0f / 2.1520f); |
| 371 | + int q = (int)(v >= 0 ? v + 0.5f : v - 0.5f); |
| 372 | + if (q < -127) q = -127; |
| 373 | + if (q > 127) q = 127; |
| 374 | + s_cb3_i8[j] = (int8_t)q; |
| 375 | + } |
| 376 | + for (int j = 8; j < 16; j++) s_cb3_i8[j] = 0; |
| 377 | + s_cb3_i8_init = 1; |
| 378 | + } |
| 379 | + const __m128i cb3_xmm = _mm_loadu_si128((const __m128i*)s_cb3_i8); |
359 | 380 | #endif |
360 | 381 |
361 | 382 | for (int seq = 0; seq < seq_len; seq++) { |
@@ -433,6 +454,63 @@ void tq_turbo_kv_3b_attention_ref(const float* query, const void* kv_cache, |
433 | 454 | int idx = (v >> bit_pos) & 0x07; |
434 | 455 | mse_dot += q_rot[d] * (s_cb3_i8[idx] * per_block_scale); |
435 | 456 | } |
| 457 | +#elif defined(__AVX2__) |
| 458 | + __m256 acc0 = _mm256_setzero_ps(); |
| 459 | + __m256 acc1 = _mm256_setzero_ps(); |
| 460 | + const __m256 scale_v = _mm256_set1_ps(per_block_scale); |
| 461 | + |
| 462 | + int d = 0; |
| 463 | + for (; d + 15 < dim; d += 16) { |
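| | +            /* 16 x 3-bit indices = 48 bits, so only the low 6 bytes of w are used; the
| | + * 8-byte memcpy over-reads by 2 bytes and assumes the packed block has that slack. */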
| 464 | + const uint8_t* p = mi + (d * 3) / 8; |
| 465 | + uint64_t w; memcpy(&w, p, 8); |
| 466 | + |
| 467 | + uint8_t idx_buf[16]; |
| 468 | + idx_buf[0] = (uint8_t)((w >> 0) & 0x07); |
| 469 | + idx_buf[1] = (uint8_t)((w >> 3) & 0x07); |
| 470 | + idx_buf[2] = (uint8_t)((w >> 6) & 0x07); |
| 471 | + idx_buf[3] = (uint8_t)((w >> 9) & 0x07); |
| 472 | + idx_buf[4] = (uint8_t)((w >> 12) & 0x07); |
| 473 | + idx_buf[5] = (uint8_t)((w >> 15) & 0x07); |
| 474 | + idx_buf[6] = (uint8_t)((w >> 18) & 0x07); |
| 475 | + idx_buf[7] = (uint8_t)((w >> 21) & 0x07); |
| 476 | + idx_buf[8] = (uint8_t)((w >> 24) & 0x07); |
| 477 | + idx_buf[9] = (uint8_t)((w >> 27) & 0x07); |
| 478 | + idx_buf[10] = (uint8_t)((w >> 30) & 0x07); |
| 479 | + idx_buf[11] = (uint8_t)((w >> 33) & 0x07); |
| 480 | + idx_buf[12] = (uint8_t)((w >> 36) & 0x07); |
| 481 | + idx_buf[13] = (uint8_t)((w >> 39) & 0x07); |
| 482 | + idx_buf[14] = (uint8_t)((w >> 42) & 0x07); |
| 483 | + idx_buf[15] = (uint8_t)((w >> 45) & 0x07); |
| 484 | + |
| 485 | + __m128i indices = _mm_loadu_si128((const __m128i*)idx_buf); |
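| | +            /* PSHUFB table lookup: each index byte selects its int8 codebook entry. */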
| 486 | + __m128i vals = _mm_shuffle_epi8(cb3_xmm, indices); |
| 487 | + |
| 488 | + __m256i i32_lo = _mm256_cvtepi8_epi32(vals); |
| 489 | + __m256i i32_hi = _mm256_cvtepi8_epi32(_mm_srli_si128(vals, 8)); |
| 490 | + __m256 f0 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_lo), scale_v); |
| 491 | + __m256 f1 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_hi), scale_v); |
| 492 | + |
| 493 | + acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 0]), f0, acc0); |
| 494 | + acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 8]), f1, acc1); |
| 495 | + } |
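| | +        /* Horizontal reduction: collapse the two 8-lane accumulators into a single float. */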
| 496 | + { |
| 497 | + __m256 sum = _mm256_add_ps(acc0, acc1); |
| 498 | + __m128 lo = _mm256_castps256_ps128(sum); |
| 499 | + __m128 hi = _mm256_extractf128_ps(sum, 1); |
| 500 | + __m128 s = _mm_add_ps(lo, hi); |
| 501 | + s = _mm_hadd_ps(s, s); |
| 502 | + s = _mm_hadd_ps(s, s); |
| 503 | + mse_dot = _mm_cvtss_f32(s); |
| 504 | + } |
| 505 | + for (; d < dim; d++) { |
| 506 | + int bit_off = d * 3; |
| 507 | + int byte_idx = bit_off / 8; |
| 508 | + int bit_pos = bit_off % 8; |
| 509 | + uint16_t v = mi[byte_idx]; |
| 510 | + if (bit_pos > 5) v |= (uint16_t)mi[byte_idx + 1] << 8; |
| 511 | + int idx = (v >> bit_pos) & 0x07; |
| 512 | + mse_dot += q_rot[d] * (s_cb3_i8[idx] * per_block_scale); |
| 513 | + } |
436 | 514 | #else |
437 | 515 | float lut[8]; |
438 | 516 | for (int j = 0; j < 8; j++) lut[j] = cb[j] / inv_std; |
@@ -577,6 +655,26 @@ void tq_turbo_kv_4b_attention_ref(const float* query, const void* kv_cache, |
577 | 655 | s_cb_i8_init = 1; |
578 | 656 | } |
579 | 657 | int8x16_t cb_vec = vld1q_s8(s_cb_i8); |
| 658 | +#elif defined(__AVX2__) |
| 659 | +    /* x86 AVX2 mirror of the NEON tbl pattern.
| 660 | +     * _mm_shuffle_epi8 implements a 16-entry int8 table lookup in a single instruction
| 661 | +     * (PSHUFB), exactly matching vqtbl1q_s8. Round 10's NEON breakthrough ports to
| 662 | +     * AVX2 1:1, since the 16-entry codebook fits in a 128-bit register on both ISAs.
| 663 | +     */
| 664 | + static int8_t s_cb_i8[16] = {0}; |
| 665 | + static int s_cb_i8_init = 0; |
| 666 | + if (!s_cb_i8_init) { |
| 667 | + for (int j = 0; j < 16; j++) { |
| 668 | + float v = cb[j] * (127.0f / 2.7326f); |
| 669 | + int q = (int)(v >= 0 ? v + 0.5f : v - 0.5f); |
| 670 | + if (q < -127) q = -127; |
| 671 | + if (q > 127) q = 127; |
| 672 | + s_cb_i8[j] = (int8_t)q; |
| 673 | + } |
| 674 | + s_cb_i8_init = 1; |
| 675 | + } |
| 676 | + const __m128i cb_xmm = _mm_loadu_si128((const __m128i*)s_cb_i8); |
| 677 | + const __m128i mask0F = _mm_set1_epi8(0x0F); |
580 | 678 | #endif |
581 | 679 |
582 | 680 | for (int seq = 0; seq < seq_len; seq++) { |
@@ -656,6 +754,55 @@ void tq_turbo_kv_4b_attention_ref(const float* query, const void* kv_cache, |
656 | 754 | int idx = (d & 1) ? (bv >> 4) : (bv & 0x0F); |
657 | 755 | mse_dot += q_rot[d] * (s_cb_i8[idx] * per_block_scale); |
658 | 756 | } |
| 757 | +#elif defined(__AVX2__) |
| 758 | + /* AVX2 path: 32 elements per iter, mirroring the NEON layout. */ |
| 759 | + __m256 acc0 = _mm256_setzero_ps(); |
| 760 | + __m256 acc1 = _mm256_setzero_ps(); |
| 761 | + __m256 acc2 = _mm256_setzero_ps(); |
| 762 | + __m256 acc3 = _mm256_setzero_ps(); |
| 763 | + const __m256 scale_v = _mm256_set1_ps(per_block_scale); |
| 764 | + |
| 765 | + int d = 0; |
| 766 | + for (; d + 31 < dim; d += 32) { |
| 767 | + __m128i bytes = _mm_loadu_si128((const __m128i*)(mi + d / 2)); |
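| | +            /* Byte i of mi packs element 2i in its low nibble and element 2i+1 in its high
| | + * nibble (the layout the scalar path decodes). AVX2 has no per-byte shift, so the high
| | + * nibbles are extracted with a 16-bit shift followed by the 0x0F byte mask. */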
| 768 | + __m128i low_nib = _mm_and_si128(bytes, mask0F); |
| 769 | + __m128i high_nib = _mm_and_si128(_mm_srli_epi16(bytes, 4), mask0F); |
| 770 | + __m128i low_vals = _mm_shuffle_epi8(cb_xmm, low_nib); |
| 771 | + __m128i high_vals = _mm_shuffle_epi8(cb_xmm, high_nib); |
| 772 | + |
| 773 | + /* Interleave: result[2i]=low[i], result[2i+1]=high[i] */ |
| 774 | + __m128i inter_lo = _mm_unpacklo_epi8(low_vals, high_vals); /* elems 0..15 */ |
| 775 | + __m128i inter_hi = _mm_unpackhi_epi8(low_vals, high_vals); /* elems 16..31 */ |
| 776 | + |
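| | +            /* Sign-extend the 32 int8 codebook values to int32 in groups of 8, then convert to float. */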
| 777 | + __m256i i32_0 = _mm256_cvtepi8_epi32(inter_lo); |
| 778 | + __m256i i32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(inter_lo, 8)); |
| 779 | + __m256i i32_2 = _mm256_cvtepi8_epi32(inter_hi); |
| 780 | + __m256i i32_3 = _mm256_cvtepi8_epi32(_mm_srli_si128(inter_hi, 8)); |
| 781 | + |
| 782 | + __m256 f0 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_0), scale_v); |
| 783 | + __m256 f1 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_1), scale_v); |
| 784 | + __m256 f2 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_2), scale_v); |
| 785 | + __m256 f3 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_3), scale_v); |
| 786 | + |
| 787 | + acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 0]), f0, acc0); |
| 788 | + acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 8]), f1, acc1); |
| 789 | + acc2 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 16]), f2, acc2); |
| 790 | + acc3 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 24]), f3, acc3); |
| 791 | + } |
| 792 | + { |
| 793 | + __m256 sum = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3)); |
| 794 | + __m128 lo = _mm256_castps256_ps128(sum); |
| 795 | + __m128 hi = _mm256_extractf128_ps(sum, 1); |
| 796 | + __m128 s = _mm_add_ps(lo, hi); |
| 797 | + s = _mm_hadd_ps(s, s); |
| 798 | + s = _mm_hadd_ps(s, s); |
| 799 | + mse_dot = _mm_cvtss_f32(s); |
| 800 | + } |
| 801 | + for (; d < dim; d++) { |
| 802 | + uint8_t bv = mi[d / 2]; |
| 803 | + int idx = (d & 1) ? (bv >> 4) : (bv & 0x0F); |
| 804 | + mse_dot += q_rot[d] * (s_cb_i8[idx] * per_block_scale); |
| 805 | + } |
659 | 806 | #else |
660 | 807 | /* Scalar fallback */ |
661 | 808 | float lut[16]; |
@@ -1317,6 +1464,30 @@ void tq_turbo_kv_5b_attention_ref(const float* query, const void* kv_cache, |
1317 | 1464 | s_cb5_i8_init = 1; |
1318 | 1465 | } |
1319 | 1466 | int8x16x2_t cb_vec = { vld1q_s8(s_cb5_i8), vld1q_s8(s_cb5_i8 + 16) }; |
| 1467 | +#elif defined(__AVX2__) |
| 1468 | +    /* AVX2 mirror of NEON vqtbl2q_s8 (32-entry table lookup).
| 1469 | +     *
| 1470 | +     * A single PSHUFB can only index a 16-entry table, so we split the 32-entry codebook
| 1471 | +     * into cb_lo (entries 0..15) and cb_hi (entries 16..31), do two PSHUFBs with
| 1472 | +     * indices & 0x0F, then BLENDV on the original bit 4 of each index (1 → use cb_hi).
| 1473 | +     * Cost: ~5 ops vs NEON's single vqtbl2q_s8, still well ahead of scalar.
| 1474 | +     */
| 1475 | + static int8_t s_cb5_i8[32] = {0}; |
| 1476 | + static int s_cb5_i8_init = 0; |
| 1477 | + if (!s_cb5_i8_init) { |
| 1478 | + for (int j = 0; j < 32; j++) { |
| 1479 | + float v = cb[j] * (127.0f / 1.9956f); |
| 1480 | + int q = (int)(v >= 0 ? v + 0.5f : v - 0.5f); |
| 1481 | + if (q < -127) q = -127; |
| 1482 | + if (q > 127) q = 127; |
| 1483 | + s_cb5_i8[j] = (int8_t)q; |
| 1484 | + } |
| 1485 | + s_cb5_i8_init = 1; |
| 1486 | + } |
| 1487 | + const __m128i cb5_lo_xmm = _mm_loadu_si128((const __m128i*)(s_cb5_i8 + 0)); |
| 1488 | + const __m128i cb5_hi_xmm = _mm_loadu_si128((const __m128i*)(s_cb5_i8 + 16)); |
| 1489 | + const __m128i mask0F_x = _mm_set1_epi8(0x0F); |
| 1490 | + const __m128i mask80_x = _mm_set1_epi8((char)0x80); |
1320 | 1491 | #endif |
1321 | 1492 |
1322 | 1493 | for (int seq = 0; seq < seq_len; seq++) { |
@@ -1405,6 +1576,71 @@ void tq_turbo_kv_5b_attention_ref(const float* query, const void* kv_cache, |
1405 | 1576 | int idx = (v >> bit_pos) & 0x1F; |
1406 | 1577 | mse_dot += q_rot[d] * (s_cb5_i8[idx] * per_block_scale); |
1407 | 1578 | } |
| 1579 | +#elif defined(__AVX2__) |
| 1580 | + __m256 acc0 = _mm256_setzero_ps(); |
| 1581 | + __m256 acc1 = _mm256_setzero_ps(); |
| 1582 | + const __m256 scale_v = _mm256_set1_ps(per_block_scale); |
| 1583 | + |
| 1584 | + int d = 0; |
| 1585 | + for (; d + 15 < dim; d += 16) { |
| 1586 | + /* 5-bit unpack: 16 indices = 10 bytes (= two 5-byte groups) */ |
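| | +            /* Each 8-byte load below is wider than the 5 bytes actually consumed, so the
| | + * tail of the packed block is assumed to be readable past the 10 bytes we need. */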
| 1587 | + const uint8_t* p0 = mi + (d * 5) / 8; |
| 1588 | + uint64_t w0; memcpy(&w0, p0, 8); |
| 1589 | + const uint8_t* p1 = p0 + 5; |
| 1590 | + uint64_t w1; memcpy(&w1, p1, 8); |
| 1591 | + |
| 1592 | + uint8_t idx_buf[16]; |
| 1593 | + idx_buf[0] = (uint8_t)((w0 >> 0) & 0x1F); |
| 1594 | + idx_buf[1] = (uint8_t)((w0 >> 5) & 0x1F); |
| 1595 | + idx_buf[2] = (uint8_t)((w0 >> 10) & 0x1F); |
| 1596 | + idx_buf[3] = (uint8_t)((w0 >> 15) & 0x1F); |
| 1597 | + idx_buf[4] = (uint8_t)((w0 >> 20) & 0x1F); |
| 1598 | + idx_buf[5] = (uint8_t)((w0 >> 25) & 0x1F); |
| 1599 | + idx_buf[6] = (uint8_t)((w0 >> 30) & 0x1F); |
| 1600 | + idx_buf[7] = (uint8_t)((w0 >> 35) & 0x1F); |
| 1601 | + idx_buf[8] = (uint8_t)((w1 >> 0) & 0x1F); |
| 1602 | + idx_buf[9] = (uint8_t)((w1 >> 5) & 0x1F); |
| 1603 | + idx_buf[10] = (uint8_t)((w1 >> 10) & 0x1F); |
| 1604 | + idx_buf[11] = (uint8_t)((w1 >> 15) & 0x1F); |
| 1605 | + idx_buf[12] = (uint8_t)((w1 >> 20) & 0x1F); |
| 1606 | + idx_buf[13] = (uint8_t)((w1 >> 25) & 0x1F); |
| 1607 | + idx_buf[14] = (uint8_t)((w1 >> 30) & 0x1F); |
| 1608 | + idx_buf[15] = (uint8_t)((w1 >> 35) & 0x1F); |
| 1609 | + |
| 1610 | + __m128i indices = _mm_loadu_si128((const __m128i*)idx_buf); |
| 1611 | + __m128i lo_idx = _mm_and_si128(indices, mask0F_x); |
| 1612 | + __m128i lo_vals = _mm_shuffle_epi8(cb5_lo_xmm, lo_idx); |
| 1613 | + __m128i hi_vals = _mm_shuffle_epi8(cb5_hi_xmm, lo_idx); |
| 1614 | + /* Bit 4 of original index → bit 7 (sign) for blendv selector */ |
| 1615 | + __m128i sel_mask = _mm_and_si128(_mm_slli_epi16(indices, 3), mask80_x); |
| 1616 | + __m128i vals = _mm_blendv_epi8(lo_vals, hi_vals, sel_mask); |
| 1617 | + |
| 1618 | + __m256i i32_lo = _mm256_cvtepi8_epi32(vals); |
| 1619 | + __m256i i32_hi = _mm256_cvtepi8_epi32(_mm_srli_si128(vals, 8)); |
| 1620 | + __m256 f0 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_lo), scale_v); |
| 1621 | + __m256 f1 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_hi), scale_v); |
| 1622 | + |
| 1623 | + acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 0]), f0, acc0); |
| 1624 | + acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 8]), f1, acc1); |
| 1625 | + } |
| 1626 | + { |
| 1627 | + __m256 sum = _mm256_add_ps(acc0, acc1); |
| 1628 | + __m128 lo = _mm256_castps256_ps128(sum); |
| 1629 | + __m128 hi = _mm256_extractf128_ps(sum, 1); |
| 1630 | + __m128 s = _mm_add_ps(lo, hi); |
| 1631 | + s = _mm_hadd_ps(s, s); |
| 1632 | + s = _mm_hadd_ps(s, s); |
| 1633 | + mse_dot = _mm_cvtss_f32(s); |
| 1634 | + } |
| 1635 | + for (; d < dim; d++) { |
| 1636 | + int bit_off = d * 5; |
| 1637 | + int byte_idx = bit_off / 8; |
| 1638 | + int bit_pos = bit_off % 8; |
| 1639 | + uint16_t v = mi[byte_idx]; |
| 1640 | + if (bit_pos > 3) v |= (uint16_t)mi[byte_idx + 1] << 8; |
| 1641 | + int idx = (v >> bit_pos) & 0x1F; |
| 1642 | + mse_dot += q_rot[d] * (s_cb5_i8[idx] * per_block_scale); |
| 1643 | + } |
1408 | 1644 | #else |
1409 | 1645 | /* Scalar fallback */ |
1410 | 1646 | float lut[32]; |
@@ -1883,6 +2119,23 @@ void tq_turbo_kv_5b_fast_attention_ref(const float* query, const void* kv_cache, |
1883 | 2119 | s_cb5fast_init = 1; |
1884 | 2120 | } |
1885 | 2121 | int8x16x2_t cb_vec = { vld1q_s8(s_cb5fast_i8), vld1q_s8(s_cb5fast_i8 + 16) }; |
| 2122 | +#elif defined(__AVX2__) |
| 2123 | + static int8_t s_cb5fast_i8[32] = {0}; |
| 2124 | + static int s_cb5fast_init = 0; |
| 2125 | + if (!s_cb5fast_init) { |
| 2126 | + for (int j = 0; j < 32; j++) { |
| 2127 | + float v = cb[j] * (127.0f / 1.9956f); |
| 2128 | + int q = (int)(v >= 0 ? v + 0.5f : v - 0.5f); |
| 2129 | + if (q < -127) q = -127; |
| 2130 | + if (q > 127) q = 127; |
| 2131 | + s_cb5fast_i8[j] = (int8_t)q; |
| 2132 | + } |
| 2133 | + s_cb5fast_init = 1; |
| 2134 | + } |
| 2135 | + const __m128i cb5f_lo_xmm = _mm_loadu_si128((const __m128i*)(s_cb5fast_i8 + 0)); |
| 2136 | + const __m128i cb5f_hi_xmm = _mm_loadu_si128((const __m128i*)(s_cb5fast_i8 + 16)); |
| 2137 | + const __m128i mask0F_xf = _mm_set1_epi8(0x0F); |
| 2138 | + const __m128i mask80_xf = _mm_set1_epi8((char)0x80); |
1886 | 2139 | #endif |
1887 | 2140 |
1888 | 2141 | for (int seq = 0; seq < seq_len; seq++) { |
@@ -1928,6 +2181,42 @@ void tq_turbo_kv_5b_fast_attention_ref(const float* query, const void* kv_cache, |
1928 | 2181 | } |
1929 | 2182 | mse_dot = vaddvq_f32(vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3))); |
1930 | 2183 |
| 2184 | + for (; d < dim; d++) { |
| 2185 | + mse_dot += q_rot[d] * (s_cb5fast_i8[mi[d]] * per_block_scale); |
| 2186 | + } |
| 2187 | +#elif defined(__AVX2__) |
| 2188 | + /* Direct 1-byte-per-index loads — no scalar unpack. The cleanest |
| 2189 | + * AVX2 path of all turbo_kv variants thanks to byte alignment. */ |
| 2190 | + __m256 acc0 = _mm256_setzero_ps(); |
| 2191 | + __m256 acc1 = _mm256_setzero_ps(); |
| 2192 | + const __m256 scale_v = _mm256_set1_ps(per_block_scale); |
| 2193 | + |
| 2194 | + int d = 0; |
| 2195 | + for (; d + 15 < dim; d += 16) { |
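| | +            /* 32-entry lookup, same trick as the 5-bit packed path: PSHUFB both 16-entry
| | + * halves with idx & 0x0F, then BLENDV on bit 4 of the index (shifted up to bit 7). */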
| 2196 | + __m128i indices = _mm_loadu_si128((const __m128i*)(mi + d)); |
| 2197 | + __m128i lo_idx = _mm_and_si128(indices, mask0F_xf); |
| 2198 | + __m128i lo_vals = _mm_shuffle_epi8(cb5f_lo_xmm, lo_idx); |
| 2199 | + __m128i hi_vals = _mm_shuffle_epi8(cb5f_hi_xmm, lo_idx); |
| 2200 | + __m128i sel_mask = _mm_and_si128(_mm_slli_epi16(indices, 3), mask80_xf); |
| 2201 | + __m128i vals = _mm_blendv_epi8(lo_vals, hi_vals, sel_mask); |
| 2202 | + |
| 2203 | + __m256i i32_lo = _mm256_cvtepi8_epi32(vals); |
| 2204 | + __m256i i32_hi = _mm256_cvtepi8_epi32(_mm_srli_si128(vals, 8)); |
| 2205 | + __m256 f0 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_lo), scale_v); |
| 2206 | + __m256 f1 = _mm256_mul_ps(_mm256_cvtepi32_ps(i32_hi), scale_v); |
| 2207 | + |
| 2208 | + acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 0]), f0, acc0); |
| 2209 | + acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(&q_rot[d + 8]), f1, acc1); |
| 2210 | + } |
| 2211 | + { |
| 2212 | + __m256 sum = _mm256_add_ps(acc0, acc1); |
| 2213 | + __m128 lo = _mm256_castps256_ps128(sum); |
| 2214 | + __m128 hi = _mm256_extractf128_ps(sum, 1); |
| 2215 | + __m128 s = _mm_add_ps(lo, hi); |
| 2216 | + s = _mm_hadd_ps(s, s); |
| 2217 | + s = _mm_hadd_ps(s, s); |
| 2218 | + mse_dot = _mm_cvtss_f32(s); |
| 2219 | + } |
1931 | 2220 | for (; d < dim; d++) { |
1932 | 2221 | mse_dot += q_rot[d] * (s_cb5fast_i8[mi[d]] * per_block_scale); |
1933 | 2222 | } |