@@ -1815,6 +1815,105 @@ static float fused_dot_iq3_xxs(const void* row, const float* x, int n) {
18151815 * First 32 elements: d1 * (q[l] & 0xF) - m1 (low nibble, scale pair[0])
18161816 * Next 32 elements: d2 * (q[l] >> 4) - m2 (high nibble, scale pair[1])
18171817 */
1818+ /* Fused Q2_K dot product: 84 bytes per 256 elements.
1819+ * 2-bit values packed 4-per-byte in qs[64]. scales[16] holds:
1820+ * low 4 bits = sub-block scale (× d)
1821+ * high 4 bits = sub-block min (× dmin)
1822+ * 16 sub-blocks of 16 elements each. Formula per sub-block:
1823+ * sum += dl * dot(q2_values, x) - ml * sum(x)
1824+ * where q2_values are unsigned 2-bit values in [0, 3]. */
1825+ static float fused_dot_q2_k (const void * row , const float * x , int n ) {
1826+ const int nb = n / 256 ;
1827+ const block_q2_K * blk = (const block_q2_K * )row ;
1828+ float sum = 0.0f ;
1829+
1830+ for (int b = 0 ; b < nb ; b ++ ) {
1831+ const float d = fp16_to_fp32 (blk [b ].d );
1832+ const float dmin = fp16_to_fp32 (blk [b ].dmin );
1833+ const uint8_t * q = blk [b ].qs ;
1834+ const float * xp = x + b * 256 ;
1835+
1836+ int is = 0 ;
1837+ /* 2 halves, 4 shifts per half, 2 sub-blocks per shift */
1838+ for (int half = 0 ; half < 2 ; half ++ ) {
1839+ int shift = 0 ;
1840+ for (int j = 0 ; j < 4 ; ++ j ) {
1841+ /* Sub-block 0: q[0..15] >> shift & 3 */
1842+ uint8_t sc0 = blk [b ].scales [is ++ ];
1843+ float dl0 = d * (sc0 & 0x0F );
1844+ float ml0 = dmin * (sc0 >> 4 );
1845+
1846+ /* Sub-block 1: q[16..31] >> shift & 3 */
1847+ uint8_t sc1 = blk [b ].scales [is ++ ];
1848+ float dl1 = d * (sc1 & 0x0F );
1849+ float ml1 = dmin * (sc1 >> 4 );
1850+
1851+ #if TQ_HAS_NEON
1852+ /* Load 16 packed bytes, extract 2-bit values, dot with x.
1853+ * Mask is uint8x16_t of 0x03; shift applied via vshrq_n_u8. */
1854+ uint8x16_t qv0 = vld1q_u8 (q ); /* sub-block 0 bytes */
1855+ uint8x16_t qv1 = vld1q_u8 (q + 16 ); /* sub-block 1 bytes */
1856+ uint8x16_t m03 = vdupq_n_u8 (0x03 );
1857+ uint8x16_t v0 , v1 ;
1858+ switch (shift ) {
1859+ case 0 : v0 = vandq_u8 (qv0 , m03 ); v1 = vandq_u8 (qv1 , m03 ); break ;
1860+ case 2 : v0 = vandq_u8 (vshrq_n_u8 (qv0 , 2 ), m03 ); v1 = vandq_u8 (vshrq_n_u8 (qv1 , 2 ), m03 ); break ;
1861+ case 4 : v0 = vandq_u8 (vshrq_n_u8 (qv0 , 4 ), m03 ); v1 = vandq_u8 (vshrq_n_u8 (qv1 , 4 ), m03 ); break ;
1862+ default : v0 = vshrq_n_u8 (qv0 , 6 ); v1 = vshrq_n_u8 (qv1 , 6 ); break ;
1863+ }
1864+ /* Expand u8 → float32, accumulate dot and sum_x */
1865+ #define TQ_Q2K_ACC (nv , off , dl_v , ml_v ) \
1866+ do { \
1867+ uint16x8_t w16_l = vmovl_u8(vget_low_u8(nv)); \
1868+ uint16x8_t w16_h = vmovl_u8(vget_high_u8(nv)); \
1869+ float32x4_t wf0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(w16_l))); \
1870+ float32x4_t wf1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(w16_l))); \
1871+ float32x4_t wf2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(w16_h))); \
1872+ float32x4_t wf3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(w16_h))); \
1873+ float32x4_t x0 = vld1q_f32(xp + (off)); \
1874+ float32x4_t x1 = vld1q_f32(xp + (off) + 4); \
1875+ float32x4_t x2 = vld1q_f32(xp + (off) + 8); \
1876+ float32x4_t x3 = vld1q_f32(xp + (off) + 12); \
1877+ float32x4_t vd = vdupq_n_f32(0.0f); \
1878+ vd = vfmaq_f32(vd, wf0, x0); \
1879+ vd = vfmaq_f32(vd, wf1, x1); \
1880+ vd = vfmaq_f32(vd, wf2, x2); \
1881+ vd = vfmaq_f32(vd, wf3, x3); \
1882+ float dot_s = vaddvq_f32(vd); \
1883+ float sum_s = vaddvq_f32(vaddq_f32(vaddq_f32(x0, x1), vaddq_f32(x2, x3))); \
1884+ sum += (dl_v) * dot_s - (ml_v) * sum_s; \
1885+ } while (0)
1886+
1887+ int yi = half * 128 + j * 32 ;
1888+ TQ_Q2K_ACC (v0 , yi , dl0 , ml0 );
1889+ TQ_Q2K_ACC (v1 , yi + 16 , dl1 , ml1 );
1890+ #undef TQ_Q2K_ACC
1891+ #else
1892+ int yi = half * 128 + j * 32 ;
1893+ float dot0 = 0 , sumx0 = 0 ;
1894+ for (int l = 0 ; l < 16 ; l ++ ) {
1895+ float xv = xp [yi + l ];
1896+ dot0 += (float )((q [l ] >> shift ) & 3 ) * xv ;
1897+ sumx0 += xv ;
1898+ }
1899+ sum += dl0 * dot0 - ml0 * sumx0 ;
1900+
1901+ float dot1 = 0 , sumx1 = 0 ;
1902+ for (int l = 0 ; l < 16 ; l ++ ) {
1903+ float xv = xp [yi + 16 + l ];
1904+ dot1 += (float )((q [l + 16 ] >> shift ) & 3 ) * xv ;
1905+ sumx1 += xv ;
1906+ }
1907+ sum += dl1 * dot1 - ml1 * sumx1 ;
1908+ #endif
1909+ shift += 2 ;
1910+ }
1911+ q += 32 ;
1912+ }
1913+ }
1914+ return sum ;
1915+ }
1916+
18181917static float fused_dot_q4_k (const void * row , const float * x , int n ) {
18191918 const int nb = n / 256 ;
18201919 const block_q4_K * blk = (const block_q4_K * )row ;
@@ -1846,22 +1945,79 @@ static float fused_dot_q4_k(const void* row, const float* x, int n) {
18461945 const float * xp = x + b * 256 ;
18471946 int is = 0 ;
18481947
1849- /* 4 groups of 64 elements */
1948+ #if TQ_HAS_NEON
1949+ /* 4 groups of 64 elements, NEON-accelerated */
1950+ const uint8x16_t mask_lo = vdupq_n_u8 (0x0F );
1951+ for (int j = 0 ; j < 256 ; j += 64 ) {
1952+ const float d1 = d * sc [is + 0 ];
1953+ const float m1 = dmin * mn [is + 0 ];
1954+ const float d2 = d * sc [is + 1 ];
1955+ const float m2 = dmin * mn [is + 1 ];
1956+
1957+ /* Load 32 bytes = 32 pairs of nibbles covering 64 elements */
1958+ uint8x16_t qa = vld1q_u8 (q );
1959+ uint8x16_t qb = vld1q_u8 (q + 16 );
1960+ /* Extract low nibbles (→ elements 0..31) and high nibbles (→ 32..63) */
1961+ uint8x16_t lo_a = vandq_u8 (qa , mask_lo );
1962+ uint8x16_t lo_b = vandq_u8 (qb , mask_lo );
1963+ uint8x16_t hi_a = vshrq_n_u8 (qa , 4 );
1964+ uint8x16_t hi_b = vshrq_n_u8 (qb , 4 );
1965+
1966+ /* Convert u8 nibbles [0..15] to float32 and dot with xp[...] */
1967+ float32x4_t vdot1 = vdupq_n_f32 (0.0f );
1968+ float32x4_t vsum1 = vdupq_n_f32 (0.0f );
1969+ float32x4_t vdot2 = vdupq_n_f32 (0.0f );
1970+ float32x4_t vsum2 = vdupq_n_f32 (0.0f );
1971+ /* Helper: process 16 elements of x[off:off+16] with nibble vector `nv` */
1972+ #define TQ_Q4K_ACC (nv , off , vdot , vsum ) \
1973+ do { \
1974+ uint16x8_t w16_l = vmovl_u8(vget_low_u8(nv)); \
1975+ uint16x8_t w16_h = vmovl_u8(vget_high_u8(nv)); \
1976+ float32x4_t wf0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(w16_l))); \
1977+ float32x4_t wf1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(w16_l))); \
1978+ float32x4_t wf2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(w16_h))); \
1979+ float32x4_t wf3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(w16_h))); \
1980+ float32x4_t x0 = vld1q_f32(xp + (off)); \
1981+ float32x4_t x1 = vld1q_f32(xp + (off) + 4); \
1982+ float32x4_t x2 = vld1q_f32(xp + (off) + 8); \
1983+ float32x4_t x3 = vld1q_f32(xp + (off) + 12); \
1984+ (vdot) = vfmaq_f32((vdot), wf0, x0); \
1985+ (vdot) = vfmaq_f32((vdot), wf1, x1); \
1986+ (vdot) = vfmaq_f32((vdot), wf2, x2); \
1987+ (vdot) = vfmaq_f32((vdot), wf3, x3); \
1988+ (vsum) = vaddq_f32((vsum), vaddq_f32(vaddq_f32(x0, x1), vaddq_f32(x2, x3))); \
1989+ } while (0)
1990+ TQ_Q4K_ACC (lo_a , j + 0 , vdot1 , vsum1 );
1991+ TQ_Q4K_ACC (lo_b , j + 16 , vdot1 , vsum1 );
1992+ TQ_Q4K_ACC (hi_a , j + 32 , vdot2 , vsum2 );
1993+ TQ_Q4K_ACC (hi_b , j + 48 , vdot2 , vsum2 );
1994+ #undef TQ_Q4K_ACC
1995+
1996+ float dot1_s = vaddvq_f32 (vdot1 );
1997+ float sum1_s = vaddvq_f32 (vsum1 );
1998+ float dot2_s = vaddvq_f32 (vdot2 );
1999+ float sum2_s = vaddvq_f32 (vsum2 );
2000+ sum += d1 * dot1_s - m1 * sum1_s ;
2001+ sum += d2 * dot2_s - m2 * sum2_s ;
2002+
2003+ q += 32 ;
2004+ is += 2 ;
2005+ }
2006+ #else
2007+ /* Scalar fallback (non-ARM) */
18502008 for (int j = 0 ; j < 256 ; j += 64 ) {
18512009 const float d1 = d * sc [is + 0 ];
18522010 const float m1 = dmin * mn [is + 0 ];
18532011 const float d2 = d * sc [is + 1 ];
18542012 const float m2 = dmin * mn [is + 1 ];
18552013
1856- /* First 32 elements: low nibble */
18572014 float dot1 = 0.0f , sum_x1 = 0.0f ;
18582015 for (int l = 0 ; l < 32 ; l ++ ) {
18592016 dot1 += (float )(q [l ] & 0x0F ) * xp [j + l ];
18602017 sum_x1 += xp [j + l ];
18612018 }
18622019 sum += d1 * dot1 - m1 * sum_x1 ;
18632020
1864- /* Next 32 elements: high nibble */
18652021 float dot2 = 0.0f , sum_x2 = 0.0f ;
18662022 for (int l = 0 ; l < 32 ; l ++ ) {
18672023 dot2 += (float )(q [l ] >> 4 ) * xp [j + 32 + l ];
@@ -1872,6 +2028,7 @@ static float fused_dot_q4_k(const void* row, const float* x, int n) {
18722028 q += 32 ;
18732029 is += 2 ;
18742030 }
2031+ #endif
18752032 }
18762033 return sum ;
18772034}
@@ -2416,6 +2573,9 @@ void tq_matmul_gguf(float* out, const float* x,
24162573 case TQ_GGML_TYPE_Q8_0 :
24172574 fused_dot = fused_dot_q8_0 ;
24182575 break ;
2576+ case TQ_GGML_TYPE_Q2_K :
2577+ fused_dot = fused_dot_q2_k ;
2578+ break ;
24192579 case TQ_GGML_TYPE_Q8_1 :
24202580 fused_dot = fused_dot_q8_1 ;
24212581 break ;
0 commit comments