@@ -2333,6 +2333,15 @@ void* q8_int_dot_worker(void* arg) {
23332333 q8_int_task_t * t = (q8_int_task_t * )arg ;
23342334 for (int d = t -> start_row ; d < t -> end_row ; d ++ ) {
23352335 const block_q8_0 * wblk = (const block_q8_0 * )((const uint8_t * )t -> weight + (size_t )d * t -> row_bytes );
2336+ if (d + 1 < t -> end_row ) {
2337+ const uint8_t * next = (const uint8_t * )t -> weight + (size_t )(d + 1 ) * t -> row_bytes ;
2338+ /* Q8_0 row is n_blocks * 34 bytes — for hidden=3072, ~3.4 KB.
2339+ * Prefetch first 4 cache lines; HW prefetcher takes the rest. */
2340+ __builtin_prefetch (next + 0 , 0 , 0 );
2341+ __builtin_prefetch (next + 64 , 0 , 0 );
2342+ __builtin_prefetch (next + 128 , 0 , 0 );
2343+ __builtin_prefetch (next + 192 , 0 , 0 );
2344+ }
23362345 float row_sum = 0.0f ;
23372346 for (int b = 0 ; b < t -> n_blocks ; b ++ ) {
23382347 const float wd = fp16_to_fp32 (wblk [b ].d );
@@ -2369,11 +2378,159 @@ typedef struct {
23692378 const int32_t * x_isums ; size_t row_bytes ; int nb_super ; int start_row ; int end_row ;
23702379} q4k_int_task_t ;
23712380
/* Q5_K int8 dot worker — same pattern as Q4_K, plus the 5th weight bit
 * taken from qh.  One task covers rows [start_row, end_row) of a Q5_K
 * weight matrix; the activation vector has already been quantized to
 * int8 with 32 values per sub-block (t->x_qs), a per-sub-block scale
 * (t->x_ds), and a per-sub-block integer sum (t->x_isums) used for the
 * dmin*mn correction term.
 * The qh array has 1 bit per element across 8 sub-blocks; the bit
 * position shifts by 2 per j-iteration (u1 = 1<<(2*iter), u2 = 2<<(2*iter)).
 * NOTE(review): reuses q4k_int_task_t verbatim — assumes the Q4_K task
 * struct's field set is sufficient for Q5_K; confirm at the call site. */
typedef q4k_int_task_t q5k_int_task_t;

void *q5k_int_dot_worker(void *arg) {
    q5k_int_task_t *t = (q5k_int_task_t *)arg;
    const uint8x16_t mask_lo = vdupq_n_u8(0x0F);  /* low-nibble mask */
    for (int d = t->start_row; d < t->end_row; d++) {
        /* Row d = consecutive Q5_K super-blocks (256 weights each). */
        const block_q5_K *wblk = (const block_q5_K *)((const uint8_t *)t->weight + (size_t)d * t->row_bytes);
        /* Prefetch the first 4 cache lines of the next row (read,
         * non-temporal); the HW prefetcher is left to follow the rest
         * of the stream. */
        if (d + 1 < t->end_row) {
            const uint8_t *next = (const uint8_t *)t->weight + (size_t)(d + 1) * t->row_bytes;
            __builtin_prefetch(next + 0, 0, 0);
            __builtin_prefetch(next + 64, 0, 0);
            __builtin_prefetch(next + 128, 0, 0);
            __builtin_prefetch(next + 192, 0, 0);
        }
        float row_sum = 0.0f;
        for (int sb = 0; sb < t->nb_super; sb++) {
            const block_q5_K *blk = wblk + sb;
            const float dW = fp16_to_fp32(blk->d);        /* super-block scale */
            const float dminW = fp16_to_fp32(blk->dmin);  /* super-block min scale */

            /* Unpack the eight 6-bit sub-block scales (sc) and mins (mn)
             * from the 12-byte packed scales[] array (K-quant packing:
             * entries 0-3 sit in the low 6 bits of bytes 0-7; entries
             * 4-7 combine a nibble of bytes 8-11 with the top 2 bits of
             * bytes 0-7). */
            uint8_t sc[8], mn[8];
            sc[0] = blk->scales[0] & 63;
            sc[1] = blk->scales[1] & 63;
            sc[2] = blk->scales[2] & 63;
            sc[3] = blk->scales[3] & 63;
            mn[0] = blk->scales[4] & 63;
            mn[1] = blk->scales[5] & 63;
            mn[2] = blk->scales[6] & 63;
            mn[3] = blk->scales[7] & 63;
            sc[4] = (blk->scales[8] & 0x0F) | ((blk->scales[0] >> 6) << 4);
            sc[5] = (blk->scales[9] & 0x0F) | ((blk->scales[1] >> 6) << 4);
            sc[6] = (blk->scales[10] & 0x0F) | ((blk->scales[2] >> 6) << 4);
            sc[7] = (blk->scales[11] & 0x0F) | ((blk->scales[3] >> 6) << 4);
            mn[4] = (blk->scales[8] >> 4) | ((blk->scales[4] >> 6) << 4);
            mn[5] = (blk->scales[9] >> 4) | ((blk->scales[5] >> 6) << 4);
            mn[6] = (blk->scales[10] >> 4) | ((blk->scales[6] >> 6) << 4);
            mn[7] = (blk->scales[11] >> 4) | ((blk->scales[7] >> 6) << 4);

            const uint8_t *ql = blk->qs;  /* 128 bytes of low nibbles */
            const uint8_t *qh = blk->qh;  /* 32 bytes of high (5th) bits */
            int sub_base = sb * 8;        /* first sub-block index of this super-block */
            int is = 0;                   /* sub-block pair cursor: 0,2,4,6 */
            uint8_t u1 = 1, u2 = 2;       /* qh bit masks for sub-blocks A and B */

            /* 4 iterations x 64 elements: each consumes 32 ql bytes —
             * low nibbles form sub-block A, high nibbles sub-block B.
             * qh is NOT advanced; instead u1/u2 select higher bits each
             * pass (A: bits 0,2,4,6; B: bits 1,3,5,7). */
            for (int j = 0; j < 256; j += 64) {
                int sub_idx_a = sub_base + is;
                int sub_idx_b = sub_base + is + 1;

                /* Low 4 bits of weights */
                uint8x16_t qa = vld1q_u8(ql);
                uint8x16_t qb = vld1q_u8(ql + 16);
                uint8x16_t lo_a = vandq_u8(qa, mask_lo);
                uint8x16_t lo_b = vandq_u8(qb, mask_lo);
                uint8x16_t hi_a = vshrq_n_u8(qa, 4);
                uint8x16_t hi_b = vshrq_n_u8(qb, 4);

                /* 5th bit from qh: extract bit u1 (sub-block A) and u2 (B).
                 * qh[0..15] covers elements 0..15, qh[16..31] covers 16..31.
                 * For sub-block A (low nibbles), bit position = log2(u1).
                 * Convert "bit set" -> byte value 16 (the 5th weight bit). */
                uint8x16_t qh_a = vld1q_u8(qh);
                uint8x16_t qh_b = vld1q_u8(qh + 16);
                uint8x16_t u1v = vdupq_n_u8(u1);
                uint8x16_t u2v = vdupq_n_u8(u2);
                /* Test the bit, then turn it into 0 or 16:
                 * masked = qh & u1 gives 0 or u1 (u1 in {1,4,16,64}), so a
                 * plain shift-to-bit-4 would need a per-iteration shift
                 * amount; vceqq + and-with-16 avoids that. */
                uint8x16_t bit_a_lo = vceqq_u8(vandq_u8(qh_a, u1v), u1v);  /* 0xFF or 0x00 */
                uint8x16_t bit_a_hi = vceqq_u8(vandq_u8(qh_b, u1v), u1v);
                uint8x16_t bit_b_lo = vceqq_u8(vandq_u8(qh_a, u2v), u2v);
                uint8x16_t bit_b_hi = vceqq_u8(vandq_u8(qh_b, u2v), u2v);
                uint8x16_t v16 = vdupq_n_u8(16);
                /* AND the 0xFF/0x00 mask with 16, OR into the nibbles ->
                 * full 5-bit weights in 0..31 (safe to reinterpret as s8). */
                lo_a = vorrq_u8(lo_a, vandq_u8(bit_a_lo, v16));
                lo_b = vorrq_u8(lo_b, vandq_u8(bit_a_hi, v16));
                hi_a = vorrq_u8(hi_a, vandq_u8(bit_b_lo, v16));
                hi_b = vorrq_u8(hi_b, vandq_u8(bit_b_hi, v16));

                int8x16_t wa_lo = vreinterpretq_s8_u8(lo_a);
                int8x16_t wa_hi = vreinterpretq_s8_u8(lo_b);
                int8x16_t wb_lo = vreinterpretq_s8_u8(hi_a);
                int8x16_t wb_hi = vreinterpretq_s8_u8(hi_b);

                /* int8 activations for the two 32-element sub-blocks */
                const int8_t *xa = t->x_qs + (size_t)sub_idx_a * 32;
                int8x16_t xa_lo = vld1q_s8(xa);
                int8x16_t xa_hi = vld1q_s8(xa + 16);
                const int8_t *xb = t->x_qs + (size_t)sub_idx_b * 32;
                int8x16_t xb_lo = vld1q_s8(xb);
                int8x16_t xb_hi = vld1q_s8(xb + 16);

#ifdef __ARM_FEATURE_DOTPROD
                /* SDOT path: 4-way int8 dot products per lane. */
                int32x4_t accA = vdotq_s32(vdupq_n_s32(0), wa_lo, xa_lo);
                accA = vdotq_s32(accA, wa_hi, xa_hi);
                int32_t isumA = vaddvq_s32(accA);
                int32x4_t accB = vdotq_s32(vdupq_n_s32(0), wb_lo, xb_lo);
                accB = vdotq_s32(accB, wb_hi, xb_hi);
                int32_t isumB = vaddvq_s32(accB);
#else
                /* Fallback: widening multiply to s16, pairwise-add into s32. */
                int32x4_t accA = vpadalq_s16(vdupq_n_s32(0),
                    vmull_s8(vget_low_s8(wa_lo), vget_low_s8(xa_lo)));
                accA = vpadalq_s16(accA, vmull_s8(vget_high_s8(wa_lo), vget_high_s8(xa_lo)));
                accA = vpadalq_s16(accA, vmull_s8(vget_low_s8(wa_hi), vget_low_s8(xa_hi)));
                accA = vpadalq_s16(accA, vmull_s8(vget_high_s8(wa_hi), vget_high_s8(xa_hi)));
                int32_t isumA = vaddvq_s32(accA);
                int32x4_t accB = vpadalq_s16(vdupq_n_s32(0),
                    vmull_s8(vget_low_s8(wb_lo), vget_low_s8(xb_lo)));
                accB = vpadalq_s16(accB, vmull_s8(vget_high_s8(wb_lo), vget_high_s8(xb_lo)));
                accB = vpadalq_s16(accB, vmull_s8(vget_low_s8(wb_hi), vget_low_s8(xb_hi)));
                accB = vpadalq_s16(accB, vmull_s8(vget_high_s8(wb_hi), vget_high_s8(xb_hi)));
                int32_t isumB = vaddvq_s32(accB);
#endif

                float xdA = t->x_ds[sub_idx_a];
                float xdB = t->x_ds[sub_idx_b];
                int32_t xisA = t->x_isums[sub_idx_a];
                int32_t xisB = t->x_isums[sub_idx_b];

                /* Accumulate d*sc*dot(x,q) minus the dmin*mn*sum(x)
                 * correction (K-quant weights are stored unsigned; the
                 * min term re-centers them). */
                row_sum += (dW * sc[is + 0] * xdA) * (float)isumA
                         - (dminW * mn[is + 0] * xdA) * (float)xisA;
                row_sum += (dW * sc[is + 1] * xdB) * (float)isumB
                         - (dminW * mn[is + 1] * xdB) * (float)xisB;

                ql += 32;
                is += 2;
                u1 <<= 2;
                u2 <<= 2;
            }
        }
        t->out[d] = row_sum;
    }
    return NULL;
}
2518+
23722519void * q4k_int_dot_worker (void * arg ) {
23732520 q4k_int_task_t * t = (q4k_int_task_t * )arg ;
23742521 const uint8x16_t mask_lo = vdupq_n_u8 (0x0F );
23752522 for (int d = t -> start_row ; d < t -> end_row ; d ++ ) {
23762523 const block_q4_K * wblk = (const block_q4_K * )((const uint8_t * )t -> weight + (size_t )d * t -> row_bytes );
2524+ /* Prefetch the next row's weights while we're processing this one.
2525+ * Each row is nb_super * 144 bytes; for typical Phi-3.5 (in_dim=3072,
2526+ * 12 super-blocks) that's 1728 bytes. Prefetch 4 cache lines ahead. */
2527+ if (d + 1 < t -> end_row ) {
2528+ const uint8_t * next = (const uint8_t * )t -> weight + (size_t )(d + 1 ) * t -> row_bytes ;
2529+ __builtin_prefetch (next + 0 , 0 , 0 );
2530+ __builtin_prefetch (next + 64 , 0 , 0 );
2531+ __builtin_prefetch (next + 128 , 0 , 0 );
2532+ __builtin_prefetch (next + 192 , 0 , 0 );
2533+ }
23772534 float row_sum = 0.0f ;
23782535 for (int sb = 0 ; sb < t -> nb_super ; sb ++ ) {
23792536 const block_q4_K * blk = wblk + sb ;
@@ -2690,8 +2847,8 @@ void tq_matmul_gguf(float* out, const float* x,
26902847 * precompute per-block int sums for the dmin*mn correction, then
26912848 * run vmull_s8 + vpadalq_s16 dots over 4-bit nibbles unpacked to int8.
26922849 * Replaces the float fused_dot_q4_k path on Phi-3.5/Llama Q4_K_M models. */
2693- if (weight_type == TQ_GGML_TYPE_Q4_K && in_dim >= 256 && in_dim <= 16384
2694- && (in_dim % 256 ) == 0 )
2850+ if (( weight_type == TQ_GGML_TYPE_Q4_K || weight_type == TQ_GGML_TYPE_Q5_K )
2851+ && in_dim >= 256 && in_dim <= 16384 && (in_dim % 256 ) == 0 )
26952852 {
26962853 /* Stack buffers: x as int8 (16KB), per-block scales (512 floats =
26972854 * 2KB), per-block int sums (512 ints = 2KB). Total ~20KB. */
@@ -2746,11 +2903,13 @@ void tq_matmul_gguf(float* out, const float* x,
27462903 tasks [t ].end_row = (t == n_threads - 1 ) ? out_dim : (t + 1 ) * rows_per ;
27472904 ptrs [t ] = & tasks [t ];
27482905 }
2906+ void * (* worker )(void * ) = (weight_type == TQ_GGML_TYPE_Q5_K )
2907+ ? q5k_int_dot_worker : q4k_int_dot_worker ;
27492908 if (n_threads == 1 ) {
2750- q4k_int_dot_worker (ptrs [0 ]);
2909+ worker (ptrs [0 ]);
27512910 } else {
27522911 extern void tq_tp_run (void * (* fn )(void * ), void * * args , int n );
2753- tq_tp_run (q4k_int_dot_worker , ptrs , n_threads );
2912+ tq_tp_run (worker , ptrs , n_threads );
27542913 }
27552914 return ;
27562915 }
0 commit comments