 
 #include <string.h>
 #include <stdio.h>
+#include <stdlib.h>
 #include <math.h>
 
 #if defined(__ARM_NEON) || defined(__ARM_NEON__)
@@ -2337,6 +2338,13 @@ void* q8_int_dot_worker(void* arg) {
             const float wd = fp16_to_fp32(wblk[b].d);
             const int8_t* wqs = wblk[b].qs;
             const int8_t* xqs = t->x_qs + b * 32;
+#ifdef __ARM_FEATURE_DOTPROD
+            /* ARMv8.2 dotprod: 16 int8 MACs per instruction (Apple M1 and later have it). */
+            int32x4_t vd = vdupq_n_s32(0);
+            vd = vdotq_s32(vd, vld1q_s8(wqs + 0),  vld1q_s8(xqs + 0));
+            vd = vdotq_s32(vd, vld1q_s8(wqs + 16), vld1q_s8(xqs + 16));
+            row_sum += wd * t->x_ds[b] * (float)vaddvq_s32(vd);
+#else
             int32x4_t vd0 = vdupq_n_s32(0), vd1 = vdupq_n_s32(0);
             for (int j = 0; j < 32; j += 16) {
                 int8x16_t vw = vld1q_s8(wqs + j);
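
For reference, both branches of the Q8 path reduce to the same integer dot; a minimal scalar sketch of that reduction (helper name is illustrative, not from the patch):

```c
#include <stdint.h>

/* Scalar equivalent of one 32-element Q8 block dot: the vdotq_s32 path and
 * the vmull_s8 + vpadalq_s16 fallback both compute exactly this sum, which
 * the worker then scales by wd * x_ds[b]. */
static inline int32_t q8_block_isum(const int8_t* wqs, const int8_t* xqs) {
    int32_t acc = 0;
    for (int j = 0; j < 32; j++)
        acc += (int32_t)wqs[j] * (int32_t)xqs[j];
    return acc;
}
```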
@@ -2345,6 +2353,126 @@ void* q8_int_dot_worker(void* arg) {
                 vd1 = vpadalq_s16(vd1, vmull_s8(vget_high_s8(vw), vget_high_s8(vx)));
             }
             row_sum += wd * t->x_ds[b] * (float)vaddvq_s32(vaddq_s32(vd0, vd1));
+#endif
+        }
+        t->out[d] = row_sum;
+    }
+    return NULL;
+}
+
+/* Q4_K int8 dot worker — same idea as q8_int but with on-the-fly nibble unpack.
+ * Pre-quantized x: int8 array (x_qs), per-32-element scales (x_ds), and
+ * pre-summed int sums per 32-element block (x_isums) so the dmin*mn correction
+ * doesn't recompute sum(x_int8) per output row. */
+typedef struct {
+    float* out; const void* weight; const int8_t* x_qs; const float* x_ds;
+    const int32_t* x_isums; size_t row_bytes; int nb_super; int start_row; int end_row;
+} q4k_int_task_t;
+
+void* q4k_int_dot_worker(void* arg) {
+    q4k_int_task_t* t = (q4k_int_task_t*)arg;
+    const uint8x16_t mask_lo = vdupq_n_u8(0x0F);
+    for (int d = t->start_row; d < t->end_row; d++) {
+        const block_q4_K* wblk = (const block_q4_K*)((const uint8_t*)t->weight + (size_t)d * t->row_bytes);
+        float row_sum = 0.0f;
+        for (int sb = 0; sb < t->nb_super; sb++) {
+            const block_q4_K* blk = wblk + sb;
+            const float dW    = fp16_to_fp32(blk->d);
+            const float dminW = fp16_to_fp32(blk->dmin);
+
+            /* 6-bit packed sub-block scales (sc) and mins (mn).
+             * Layout matches fused_dot_q4_k. */
+            uint8_t sc[8], mn[8];
+            sc[0] = blk->scales[0] & 63;
+            sc[1] = blk->scales[1] & 63;
+            sc[2] = blk->scales[2] & 63;
+            sc[3] = blk->scales[3] & 63;
+            mn[0] = blk->scales[4] & 63;
+            mn[1] = blk->scales[5] & 63;
+            mn[2] = blk->scales[6] & 63;
+            mn[3] = blk->scales[7] & 63;
+            sc[4] = (blk->scales[8]  & 0x0F) | ((blk->scales[0] >> 6) << 4);
+            sc[5] = (blk->scales[9]  & 0x0F) | ((blk->scales[1] >> 6) << 4);
+            sc[6] = (blk->scales[10] & 0x0F) | ((blk->scales[2] >> 6) << 4);
+            sc[7] = (blk->scales[11] & 0x0F) | ((blk->scales[3] >> 6) << 4);
+            mn[4] = (blk->scales[8]  >> 4) | ((blk->scales[4] >> 6) << 4);
+            mn[5] = (blk->scales[9]  >> 4) | ((blk->scales[5] >> 6) << 4);
+            mn[6] = (blk->scales[10] >> 4) | ((blk->scales[6] >> 6) << 4);
+            mn[7] = (blk->scales[11] >> 4) | ((blk->scales[7] >> 6) << 4);
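+            /* (In the 12-byte scales field: bytes 0-3 hold the low 6 bits of
+             * sc[0..3], bytes 4-7 the low 6 bits of mn[0..3]; bytes 8-11 hold
+             * the low nibbles of sc[4..7] and mn[4..7], and the top 2 bits of
+             * bytes 0-7 supply bits 4-5 of sc[4..7] / mn[4..7].) */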
+
+            const uint8_t* q = blk->qs;
+            int sub_base = sb * 8;
+            int is = 0;
+
+            /* 4 j-iterations × 64 elements = one 256-element super-block.
+             * Each iteration handles two 32-element sub-blocks (lo + hi nibbles). */
+            for (int j = 0; j < 256; j += 64) {
+                int sub_idx_a = sub_base + is;     /* offsets j..j+31    */
+                int sub_idx_b = sub_base + is + 1; /* offsets j+32..j+63 */
+
+                /* Load 32 bytes of packed nibbles. */
+                uint8x16_t qa = vld1q_u8(q);
+                uint8x16_t qb = vld1q_u8(q + 16);
+                /* wa_lo: weights j..j+15, wa_hi: weights j+16..j+31 */
+                int8x16_t wa_lo = vreinterpretq_s8_u8(vandq_u8(qa, mask_lo));
+                int8x16_t wa_hi = vreinterpretq_s8_u8(vandq_u8(qb, mask_lo));
+                /* wb_lo: weights j+32..j+47, wb_hi: weights j+48..j+63 */
+                int8x16_t wb_lo = vreinterpretq_s8_u8(vshrq_n_u8(qa, 4));
+                int8x16_t wb_hi = vreinterpretq_s8_u8(vshrq_n_u8(qb, 4));
+
+                /* x for sub-block A: 32 int8 values starting at sub_idx_a*32 */
+                const int8_t* xa = t->x_qs + (size_t)sub_idx_a * 32;
+                int8x16_t xa_lo = vld1q_s8(xa);
+                int8x16_t xa_hi = vld1q_s8(xa + 16);
+                /* x for sub-block B */
+                const int8_t* xb = t->x_qs + (size_t)sub_idx_b * 32;
+                int8x16_t xb_lo = vld1q_s8(xb);
+                int8x16_t xb_hi = vld1q_s8(xb + 16);
+
+#ifdef __ARM_FEATURE_DOTPROD
+                /* ARMv8.2 dotprod: 16 int8 MACs per call; two calls cover a full sub-block. */
+                int32x4_t accA = vdotq_s32(vdupq_n_s32(0), wa_lo, xa_lo);
+                accA = vdotq_s32(accA, wa_hi, xa_hi);
+                int32_t isumA = vaddvq_s32(accA);
+                int32x4_t accB = vdotq_s32(vdupq_n_s32(0), wb_lo, xb_lo);
+                accB = vdotq_s32(accB, wb_hi, xb_hi);
+                int32_t isumB = vaddvq_s32(accB);
+#else
+                /* int8 dot for sub-block A: four widening multiplies; vpadalq_s16 accumulates. */
+                int32x4_t accA = vpadalq_s16(vdupq_n_s32(0),
+                                             vmull_s8(vget_low_s8(wa_lo), vget_low_s8(xa_lo)));
+                accA = vpadalq_s16(accA, vmull_s8(vget_high_s8(wa_lo), vget_high_s8(xa_lo)));
+                accA = vpadalq_s16(accA, vmull_s8(vget_low_s8(wa_hi),  vget_low_s8(xa_hi)));
+                accA = vpadalq_s16(accA, vmull_s8(vget_high_s8(wa_hi), vget_high_s8(xa_hi)));
+                int32_t isumA = vaddvq_s32(accA);
+
+                /* Same for sub-block B. */
+                int32x4_t accB = vpadalq_s16(vdupq_n_s32(0),
+                                             vmull_s8(vget_low_s8(wb_lo), vget_low_s8(xb_lo)));
+                accB = vpadalq_s16(accB, vmull_s8(vget_high_s8(wb_lo), vget_high_s8(xb_lo)));
+                accB = vpadalq_s16(accB, vmull_s8(vget_low_s8(wb_hi),  vget_low_s8(xb_hi)));
+                accB = vpadalq_s16(accB, vmull_s8(vget_high_s8(wb_hi), vget_high_s8(xb_hi)));
+                int32_t isumB = vaddvq_s32(accB);
+#endif
+
+                /* Combine. Each reconstructed weight is w_i = q_i*(d*sc) - (dmin*mn), so
+                 *   dot = sum(w_i * x_i)
+                 *       = (d*sc) * x_d * sum(q_i * x_int8_i)
+                 *       - (dmin*mn) * x_d * sum(x_int8_i)
+                 * The first term uses isum (just computed); the second uses the
+                 * precomputed x_isums to avoid re-summing per row. */
+                float   xdA  = t->x_ds[sub_idx_a];
+                float   xdB  = t->x_ds[sub_idx_b];
+                int32_t xisA = t->x_isums[sub_idx_a];
+                int32_t xisB = t->x_isums[sub_idx_b];
+
+                row_sum += (dW * sc[is + 0] * xdA) * (float)isumA
+                         - (dminW * mn[is + 0] * xdA) * (float)xisA;
+                row_sum += (dW * sc[is + 1] * xdB) * (float)isumB
+                         - (dminW * mn[is + 1] * xdB) * (float)xisB;
+
+                q  += 32;
+                is += 2;
+            }
         }
         t->out[d] = row_sum;
     }
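
The two-term form in the combine step can be checked against a plain float reconstruction; a minimal scalar sketch for one 32-element sub-block (hypothetical helper, assumes the nibbles have already been unpacked to 0..15):

```c
#include <stdint.h>

/* Reference: w_i = q_i*(d*sc) - dmin*mn dotted against x_i ≈ x_d * xq_i
 * expands to exactly the isum/xsum form the worker accumulates. */
static float q4k_subblock_dot_ref(float d, float dmin, uint8_t sc, uint8_t mn,
                                  const uint8_t* q, const int8_t* xq, float x_d) {
    int32_t isum = 0, xsum = 0;
    for (int j = 0; j < 32; j++) {
        isum += (int32_t)q[j] * xq[j];
        xsum += xq[j];
    }
    return (d * sc * x_d) * (float)isum - (dmin * mn * x_d) * (float)xsum;
}
```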
@@ -2556,6 +2684,76 @@ void tq_matmul_gguf(float* out, const float* x,
             return;
         }
     }
+
+    /* ---- Q4_K × int8 dot fast path (auto-quantize activation) ----
+     * Same pattern as Q8_0: quantize x once to int8 (32-element blocks),
+     * precompute per-block int sums for the dmin*mn correction, then
+     * run vmull_s8 + vpadalq_s16 dots over 4-bit nibbles unpacked to int8.
+     * Replaces the float fused_dot_q4_k path on Phi-3.5/Llama Q4_K_M models. */
+    if (weight_type == TQ_GGML_TYPE_Q4_K && in_dim >= 256 && in_dim <= 16384
+        && (in_dim % 256) == 0)
+    {
+        /* Stack buffers: x as int8 (16 KB max), per-block scales (512 floats =
+         * 2 KB), per-block int sums (512 ints = 2 KB) — about 20 KB total. */
+        int8_t  x_qs[16384];
+        float   x_ds[512];
+        int32_t x_isums[512];
+
+        /* Step 1: per-32-element-block absmax quantization of x to int8. */
+        const int n_blocks_x = in_dim / 32;
+        for (int b = 0; b < n_blocks_x; b++) {
+            const float* xp = x + b * 32;
+            float amax = 0.0f;
+            for (int j = 0; j < 32; j++) {
+                float a = xp[j] < 0 ? -xp[j] : xp[j];
+                if (a > amax) amax = a;
+            }
+            float d = amax / 127.0f;
+            x_ds[b] = d;
+            int32_t isum = 0;
+            if (d > 0.0f) {
+                float id = 1.0f / d;
+                for (int j = 0; j < 32; j++) {
+                    int v = (int)roundf(xp[j] * id);
+                    int8_t q = (int8_t)(v < -128 ? -128 : (v > 127 ? 127 : v));
+                    x_qs[b * 32 + j] = q;
+                    isum += q;
+                }
+            } else {
+                memset(x_qs + b * 32, 0, 32);
+            }
+            x_isums[b] = isum;
+        }
+
+        const int nb_super = in_dim / 256; /* number of 256-element super-blocks */
+        int n_threads = tq_get_threads();
+        if (n_threads > TQ_TP_MAX) n_threads = TQ_TP_MAX;
+        if (n_threads > out_dim)   n_threads = out_dim;
+        if (n_threads < 1)         n_threads = 1;
+
+        q4k_int_task_t tasks[TQ_TP_MAX];
+        void* ptrs[TQ_TP_MAX];
+        int rows_per = out_dim / n_threads;
+        for (int t = 0; t < n_threads; t++) {
+            tasks[t].out       = out;
+            tasks[t].weight    = weight;
+            tasks[t].x_qs      = x_qs;
+            tasks[t].x_ds      = x_ds;
+            tasks[t].x_isums   = x_isums;
+            tasks[t].row_bytes = row_bytes;
+            tasks[t].nb_super  = nb_super;
+            tasks[t].start_row = t * rows_per;
+            tasks[t].end_row   = (t == n_threads - 1) ? out_dim : (t + 1) * rows_per;
+            ptrs[t] = &tasks[t];
+        }
+        if (n_threads == 1) {
+            q4k_int_dot_worker(ptrs[0]);
+        } else {
+            extern void tq_tp_run(void* (*fn)(void*), void** args, int n);
+            tq_tp_run(q4k_int_dot_worker, ptrs, n_threads);
+        }
+        return;
+    }
 #endif
 
     /* ---- Fused fast paths: dequant + dot in one pass, no tmp buffer ---- */
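
As a sanity check on the absmax activation quantization used by the fast path, a small standalone test (entirely illustrative, not part of the patch) comparing a float dot against its int8 reconstruction:

```c
#include <math.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    float x[32], w[32];
    for (int j = 0; j < 32; j++) { x[j] = sinf(0.3f * j); w[j] = cosf(0.7f * j); }

    float ref = 0.0f;                       /* float reference dot */
    for (int j = 0; j < 32; j++) ref += w[j] * x[j];

    float amax = 0.0f;                      /* absmax-quantize x, as the fast path does */
    for (int j = 0; j < 32; j++) if (fabsf(x[j]) > amax) amax = fabsf(x[j]);
    float d = amax / 127.0f, id = (d > 0.0f) ? 1.0f / d : 0.0f;
    int8_t xq[32];
    for (int j = 0; j < 32; j++) xq[j] = (int8_t)lroundf(x[j] * id);

    float approx = 0.0f;                    /* dot against dequantized d * xq */
    for (int j = 0; j < 32; j++) approx += w[j] * (d * xq[j]);

    printf("ref=%f int8=%f err=%g\n", ref, approx, fabs((double)ref - approx));
    return 0;
}
```

Per-element quantization error is bounded by d/2 ≈ amax/254, which is why quantizing the activation once per matmul is cheap enough to pay for itself.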