@@ -1791,6 +1791,67 @@ static float fused_dot_q4_k(const void* row, const float* x, int n) {
17911791 return sum ;
17921792}
17931793
1794+ /* Fused Q3_K dot product: 110 bytes per 256 elements
1795+ * 3-bit = 2 low bits (qs) + 1 high bit (hmask)
1796+ * 16 sub-blocks with 6-bit scales packed into 12 bytes */
1797+ static float fused_dot_q3_k (const void * row , const float * x , int n ) {
1798+ const int nb = n / 256 ;
1799+ const block_q3_K * blk = (const block_q3_K * )row ;
1800+ float sum = 0.0f ;
1801+
1802+ const uint32_t kmask1 = 0x03030303 ;
1803+ const uint32_t kmask2 = 0x0f0f0f0f ;
1804+
1805+ uint32_t aux [4 ];
1806+ const int8_t * scales = (const int8_t * )aux ;
1807+
1808+ for (int b = 0 ; b < nb ; b ++ ) {
1809+ const float d_all = fp16_to_fp32 (blk [b ].d );
1810+
1811+ const uint8_t * q = blk [b ].qs ;
1812+ const uint8_t * hm = blk [b ].hmask ;
1813+ uint8_t m = 1 ;
1814+
1815+ /* Decode 16 x 6-bit scales (same as dequant_q3_k) */
1816+ memcpy (aux , blk [b ].scales , 12 );
1817+ uint32_t tmp = aux [2 ];
1818+ aux [2 ] = ((aux [0 ] >> 4 ) & kmask2 ) | (((tmp >> 4 ) & kmask1 ) << 4 );
1819+ aux [3 ] = ((aux [1 ] >> 4 ) & kmask2 ) | (((tmp >> 6 ) & kmask1 ) << 4 );
1820+ aux [0 ] = (aux [0 ] & kmask2 ) | (((tmp >> 0 ) & kmask1 ) << 4 );
1821+ aux [1 ] = (aux [1 ] & kmask2 ) | (((tmp >> 2 ) & kmask1 ) << 4 );
1822+
1823+ int is = 0 ;
1824+ const float * xp = x + b * 256 ;
1825+ int yi = 0 ;
1826+
1827+ for (int half = 0 ; half < 2 ; half ++ ) {
1828+ int shift = 0 ;
1829+ for (int j = 0 ; j < 4 ; ++ j ) {
1830+ float dl = d_all * (scales [is ++ ] - 32 );
1831+ float dot = 0.0f ;
1832+ for (int l = 0 ; l < 16 ; ++ l ) {
1833+ dot += xp [yi + l ] * (float )((int8_t )((q [l + 0 ] >> shift ) & 3 ) - ((hm [l + 0 ] & m ) ? 0 : 4 ));
1834+ }
1835+ sum += dl * dot ;
1836+ yi += 16 ;
1837+
1838+ dl = d_all * (scales [is ++ ] - 32 );
1839+ dot = 0.0f ;
1840+ for (int l = 0 ; l < 16 ; ++ l ) {
1841+ dot += xp [yi + l ] * (float )((int8_t )((q [l + 16 ] >> shift ) & 3 ) - ((hm [l + 16 ] & m ) ? 0 : 4 ));
1842+ }
1843+ sum += dl * dot ;
1844+ yi += 16 ;
1845+
1846+ shift += 2 ;
1847+ m <<= 1 ;
1848+ }
1849+ q += 32 ;
1850+ }
1851+ }
1852+ return sum ;
1853+ }
1854+
17941855/* Fused Q4_0 dot product: 18 bytes per 32 elements */
17951856static float fused_dot_q4_0 (const void * row , const float * x , int n ) {
17961857 const int nb = n / 32 ;
@@ -1951,6 +2012,73 @@ static void* gguf_matmul_worker(void* arg) {
19512012 return NULL ;
19522013}
19532014
2015+ /* Pre-quantize input vector to Q8 format for int8×int8 matmul.
2016+ * Called once in transformer, result reused for Q/K/V/O projections.
2017+ * Stores int8 values in qs[n], per-block scales in ds[n/32]. */
2018+ void tq_preq_input_q8 (const float * x , int8_t * qs , float * ds , int n ) {
2019+ int nb = n / 32 ;
2020+ for (int b = 0 ; b < nb ; b ++ ) {
2021+ const float * xp = x + b * 32 ;
2022+ float amax = 0.0f ;
2023+ #if TQ_HAS_NEON
2024+ float32x4_t vmax = vdupq_n_f32 (0.0f );
2025+ for (int j = 0 ; j < 32 ; j += 4 ) {
2026+ float32x4_t vx = vld1q_f32 (xp + j );
2027+ vmax = vmaxq_f32 (vmax , vabsq_f32 (vx ));
2028+ }
2029+ amax = vmaxvq_f32 (vmax );
2030+ #else
2031+ for (int j = 0 ; j < 32 ; j ++ ) {
2032+ float a = xp [j ] < 0 ? - xp [j ] : xp [j ];
2033+ if (a > amax ) amax = a ;
2034+ }
2035+ #endif
2036+ float d = amax / 127.0f ;
2037+ ds [b ] = d ;
2038+ if (d > 0.0f ) {
2039+ float id = 127.0f / amax ;
2040+ #if TQ_HAS_NEON
2041+ float32x4_t vid = vdupq_n_f32 (id );
2042+ for (int j = 0 ; j < 32 ; j += 8 ) {
2043+ float32x4_t v0 = vmulq_f32 (vld1q_f32 (xp + j ), vid );
2044+ float32x4_t v1 = vmulq_f32 (vld1q_f32 (xp + j + 4 ), vid );
2045+ int32x4_t i0 = vcvtnq_s32_f32 (v0 );
2046+ int32x4_t i1 = vcvtnq_s32_f32 (v1 );
2047+ int16x4_t s0 = vqmovn_s32 (i0 );
2048+ int16x4_t s1 = vqmovn_s32 (i1 );
2049+ int8x8_t b8 = vqmovn_s16 (vcombine_s16 (s0 , s1 ));
2050+ vst1_s8 (qs + b * 32 + j , b8 );
2051+ }
2052+ #else
2053+ for (int j = 0 ; j < 32 ; j ++ ) {
2054+ int v = (int )roundf (xp [j ] * id );
2055+ qs [b * 32 + j ] = (int8_t )(v < -128 ? -128 : (v > 127 ? 127 : v ));
2056+ }
2057+ #endif
2058+ } else {
2059+ memset (qs + b * 32 , 0 , 32 );
2060+ }
2061+ }
2062+ }
2063+
/* Pre-quantized input for the current matmul call, published per thread.
 * tq_matmul_gguf reads these when the transformer has already run
 * tq_preq_input_q8 on the activation vector; NULL means "not available". */
#ifdef _MSC_VER
#define TQ_TLS __declspec(thread)
#else
#define TQ_TLS __thread
#endif

static TQ_TLS const int8_t *g_preq_qs = NULL;
static TQ_TLS const float *g_preq_ds = NULL;

/* Publish a pre-quantized input (int8 values + per-block scales). */
void tq_set_preq(const int8_t *qs, const float *ds) {
    g_preq_qs = qs;
    g_preq_ds = ds;
}

/* Invalidate the published pre-quantized input. */
void tq_clear_preq(void) {
    g_preq_qs = NULL;
    g_preq_ds = NULL;
}
2081+
19542082/* Q8×Q8 integer dot worker — processes a range of output rows using int8 multiply-accumulate */
19552083#if TQ_HAS_NEON
19562084typedef struct {
@@ -2036,10 +2164,40 @@ void tq_matmul_gguf(float* out, const float* x,
20362164 const int n_blocks = in_dim / block_elems ;
20372165 const size_t row_bytes = (size_t )n_blocks * block_bytes ;
20382166
2039- /* Q8×Q8 integer dot: benchmarked slower than float fused dot due to per-call
2040- * input quantization overhead. Keeping code in q8_int_dot_worker for future
2041- * "quantize-once-in-transformer" optimization.
2042- * TODO: move input Q8 quantization to transformer level, call once per layer. */
2167+ /* Q8×Q8 integer dot: input pre-quantized in transformer (once per layer).
2168+ * Uses int8×int8 vmull_s8+vpadalq_s16 — ~2x faster than float fused dot. */
2169+ #if TQ_HAS_NEON
2170+ if (weight_type == TQ_GGML_TYPE_Q8_0 && g_preq_qs != NULL && in_dim <= 4096 ) {
2171+ const int8_t * xqs = g_preq_qs ;
2172+ const float * xds = g_preq_ds ;
2173+
2174+ /* Always use thread pool — pass preq data via task struct (TLS doesn't propagate) */
2175+ int n_threads = tq_get_threads ();
2176+ if (n_threads > TQ_TP_MAX ) n_threads = TQ_TP_MAX ;
2177+ if (n_threads > out_dim ) n_threads = out_dim ;
2178+ if (n_threads < 1 ) n_threads = 1 ;
2179+
2180+ q8_int_task_t tasks [TQ_TP_MAX ];
2181+ void * ptrs [TQ_TP_MAX ];
2182+ int rows_per = out_dim / n_threads ;
2183+ for (int t = 0 ; t < n_threads ; t ++ ) {
2184+ tasks [t ] = (q8_int_task_t ){
2185+ .out = out , .weight = weight , .x_qs = xqs , .x_ds = xds ,
2186+ .row_bytes = row_bytes , .n_blocks = n_blocks ,
2187+ .start_row = t * rows_per ,
2188+ .end_row = (t == n_threads - 1 ) ? out_dim : (t + 1 ) * rows_per
2189+ };
2190+ ptrs [t ] = & tasks [t ];
2191+ }
2192+ if (n_threads == 1 ) {
2193+ q8_int_dot_worker (ptrs [0 ]);
2194+ } else {
2195+ extern void tq_tp_run (void * (* fn )(void * ), void * * args , int n );
2196+ tq_tp_run (q8_int_dot_worker , ptrs , n_threads );
2197+ }
2198+ return ;
2199+ }
2200+ #endif
20432201
20442202 /* ---- Q8_0×Q8 integer dot fast path — DISABLED (per-call overhead > benefit) ---- */
20452203#if 0 /* TQ_HAS_NEON */
@@ -2144,6 +2302,9 @@ void tq_matmul_gguf(float* out, const float* x,
21442302 case TQ_GGML_TYPE_Q8_0 :
21452303 fused_dot = fused_dot_q8_0 ;
21462304 break ;
2305+ case TQ_GGML_TYPE_Q3_K :
2306+ fused_dot = fused_dot_q3_k ;
2307+ break ;
21472308 case TQ_GGML_TYPE_Q4_K :
21482309 fused_dot = fused_dot_q4_k ;
21492310 break ;
0 commit comments