@@ -2516,21 +2516,161 @@ void* q5k_int_dot_worker(void* arg) {
25162516 return NULL ;
25172517}
25182518
/* Two-row variant of the q4k_int inner loop.
 *
 * Accumulates the dot product of one quantized activation vector against two
 * Q4_K weight rows in a single pass. Both rows read the same x_qs / x_ds /
 * x_isums data and differ only in their weight blocks; interleaving them is
 * meant to give the CPU independent work to overlap with the weight loads.
 *
 *   blk0, blk1 : first super-block of each weight row (nb_super blocks each)
 *   x_qs       : int8 activations, read in groups of 32 (8 groups per
 *                super-block)
 *   x_ds       : one float factor per 32-value group
 *   x_isums    : one int32 per 32-value group; it is multiplied by the
 *                weight minimum term (presumably the sum of that group's
 *                x_qs values -- confirm at the quantization site)
 *   nb_super   : number of super-blocks per row
 *   out0, out1 : receive the finished sums for row 0 and row 1
 */
static inline void q4k_int_dot_two_rows(
    const block_q4_K *blk0, const block_q4_K *blk1,
    const int8_t *x_qs, const float *x_ds, const int32_t *x_isums,
    int nb_super, float *out0, float *out1)
{
    const uint8x16_t nibble_mask = vdupq_n_u8(0x0F);
    float sum0 = 0.0f;
    float sum1 = 0.0f;

    for (int sb = 0; sb < nb_super; ++sb) {
        const block_q4_K *b0 = &blk0[sb];
        const block_q4_K *b1 = &blk1[sb];

        const float dW0 = fp16_to_fp32(b0->d);
        const float dminW0 = fp16_to_fp32(b0->dmin);
        const float dW1 = fp16_to_fp32(b1->d);
        const float dminW1 = fp16_to_fp32(b1->dmin);

        /* Unpack the eight 6-bit scales and eight 6-bit mins of each row.
         * Entries 0..3 sit in the low six bits of scales[0..7]; entries 4..7
         * combine a nibble of scales[8..11] with the top two bits of the
         * matching byte in scales[0..7]. */
        uint8_t sc0[8], mn0[8], sc1[8], mn1[8];
        for (int k = 0; k < 4; ++k) {
            sc0[k] = b0->scales[k] & 63;
            mn0[k] = b0->scales[k + 4] & 63;
            sc0[k + 4] = (b0->scales[k + 8] & 0x0F) | ((b0->scales[k] >> 6) << 4);
            mn0[k + 4] = (b0->scales[k + 8] >> 4) | ((b0->scales[k + 4] >> 6) << 4);

            sc1[k] = b1->scales[k] & 63;
            mn1[k] = b1->scales[k + 4] & 63;
            sc1[k + 4] = (b1->scales[k + 8] & 0x0F) | ((b1->scales[k] >> 6) << 4);
            mn1[k + 4] = (b1->scales[k + 8] >> 4) | ((b1->scales[k + 4] >> 6) << 4);
        }

        /* Each pass consumes 32 packed bytes per row = 64 weights, i.e. two
         * 32-value sub-blocks: A from the low nibbles, B from the high ones. */
        for (int pair = 0; pair < 4; ++pair) {
            const int is = 2 * pair;
            const int sub_idx_a = sb * 8 + is;
            const int sub_idx_b = sub_idx_a + 1;

            const uint8_t *q0 = b0->qs + 32 * pair;
            const uint8_t *q1 = b1->qs + 32 * pair;

            /* Packed weight bytes for both rows. */
            const uint8x16_t r0_first = vld1q_u8(q0);
            const uint8x16_t r0_second = vld1q_u8(q0 + 16);
            const uint8x16_t r1_first = vld1q_u8(q1);
            const uint8x16_t r1_second = vld1q_u8(q1 + 16);

            /* Row 0: low nibbles feed sub-block A, high nibbles sub-block B. */
            const int8x16_t w0a_lo = vreinterpretq_s8_u8(vandq_u8(r0_first, nibble_mask));
            const int8x16_t w0a_hi = vreinterpretq_s8_u8(vandq_u8(r0_second, nibble_mask));
            const int8x16_t w0b_lo = vreinterpretq_s8_u8(vshrq_n_u8(r0_first, 4));
            const int8x16_t w0b_hi = vreinterpretq_s8_u8(vshrq_n_u8(r0_second, 4));

            /* Row 1: same split. */
            const int8x16_t w1a_lo = vreinterpretq_s8_u8(vandq_u8(r1_first, nibble_mask));
            const int8x16_t w1a_hi = vreinterpretq_s8_u8(vandq_u8(r1_second, nibble_mask));
            const int8x16_t w1b_lo = vreinterpretq_s8_u8(vshrq_n_u8(r1_first, 4));
            const int8x16_t w1b_hi = vreinterpretq_s8_u8(vshrq_n_u8(r1_second, 4));

            /* Activations for sub-blocks A and B, shared by both rows. */
            const int8_t *act_a = x_qs + (size_t)sub_idx_a * 32;
            const int8x16_t xa_lo = vld1q_s8(act_a);
            const int8x16_t xa_hi = vld1q_s8(act_a + 16);
            const int8_t *act_b = x_qs + (size_t)sub_idx_b * 32;
            const int8x16_t xb_lo = vld1q_s8(act_b);
            const int8x16_t xb_hi = vld1q_s8(act_b + 16);

            const int32x4_t zero = vdupq_n_s32(0);
#ifdef __ARM_FEATURE_DOTPROD
            /* Dot-product instruction path: two vdotq per 32-value sub-block. */
            const int32_t idot_a0 = vaddvq_s32(
                vdotq_s32(vdotq_s32(zero, w0a_lo, xa_lo), w0a_hi, xa_hi));
            const int32_t idot_a1 = vaddvq_s32(
                vdotq_s32(vdotq_s32(zero, w1a_lo, xa_lo), w1a_hi, xa_hi));
            const int32_t idot_b0 = vaddvq_s32(
                vdotq_s32(vdotq_s32(zero, w0b_lo, xb_lo), w0b_hi, xb_hi));
            const int32_t idot_b1 = vaddvq_s32(
                vdotq_s32(vdotq_s32(zero, w1b_lo, xb_lo), w1b_hi, xb_hi));
#else
            /* Fallback without the dot-product extension: widening 8-bit
             * multiplies followed by pairwise accumulation into 32-bit lanes. */
            int32x4_t acc_a0 = vpadalq_s16(zero, vmull_s8(vget_low_s8(w0a_lo), vget_low_s8(xa_lo)));
            acc_a0 = vpadalq_s16(acc_a0, vmull_s8(vget_high_s8(w0a_lo), vget_high_s8(xa_lo)));
            acc_a0 = vpadalq_s16(acc_a0, vmull_s8(vget_low_s8(w0a_hi), vget_low_s8(xa_hi)));
            acc_a0 = vpadalq_s16(acc_a0, vmull_s8(vget_high_s8(w0a_hi), vget_high_s8(xa_hi)));

            int32x4_t acc_a1 = vpadalq_s16(zero, vmull_s8(vget_low_s8(w1a_lo), vget_low_s8(xa_lo)));
            acc_a1 = vpadalq_s16(acc_a1, vmull_s8(vget_high_s8(w1a_lo), vget_high_s8(xa_lo)));
            acc_a1 = vpadalq_s16(acc_a1, vmull_s8(vget_low_s8(w1a_hi), vget_low_s8(xa_hi)));
            acc_a1 = vpadalq_s16(acc_a1, vmull_s8(vget_high_s8(w1a_hi), vget_high_s8(xa_hi)));

            int32x4_t acc_b0 = vpadalq_s16(zero, vmull_s8(vget_low_s8(w0b_lo), vget_low_s8(xb_lo)));
            acc_b0 = vpadalq_s16(acc_b0, vmull_s8(vget_high_s8(w0b_lo), vget_high_s8(xb_lo)));
            acc_b0 = vpadalq_s16(acc_b0, vmull_s8(vget_low_s8(w0b_hi), vget_low_s8(xb_hi)));
            acc_b0 = vpadalq_s16(acc_b0, vmull_s8(vget_high_s8(w0b_hi), vget_high_s8(xb_hi)));

            int32x4_t acc_b1 = vpadalq_s16(zero, vmull_s8(vget_low_s8(w1b_lo), vget_low_s8(xb_lo)));
            acc_b1 = vpadalq_s16(acc_b1, vmull_s8(vget_high_s8(w1b_lo), vget_high_s8(xb_lo)));
            acc_b1 = vpadalq_s16(acc_b1, vmull_s8(vget_low_s8(w1b_hi), vget_low_s8(xb_hi)));
            acc_b1 = vpadalq_s16(acc_b1, vmull_s8(vget_high_s8(w1b_hi), vget_high_s8(xb_hi)));

            const int32_t idot_a0 = vaddvq_s32(acc_a0);
            const int32_t idot_a1 = vaddvq_s32(acc_a1);
            const int32_t idot_b0 = vaddvq_s32(acc_b0);
            const int32_t idot_b1 = vaddvq_s32(acc_b1);
#endif

            const float xdA = x_ds[sub_idx_a];
            const float xdB = x_ds[sub_idx_b];
            const int32_t xisA = x_isums[sub_idx_a];
            const int32_t xisB = x_isums[sub_idx_b];

            /* Per sub-block: (d * scale * xd) * int_dot - (dmin * min * xd) * x_isum.
             * The expression shape is kept fixed so float rounding is unchanged. */
            sum0 += (dW0 * sc0[is + 0] * xdA) * (float)idot_a0
                  - (dminW0 * mn0[is + 0] * xdA) * (float)xisA;
            sum0 += (dW0 * sc0[is + 1] * xdB) * (float)idot_b0
                  - (dminW0 * mn0[is + 1] * xdB) * (float)xisB;
            sum1 += (dW1 * sc1[is + 0] * xdA) * (float)idot_a1
                  - (dminW1 * mn1[is + 0] * xdA) * (float)xisA;
            sum1 += (dW1 * sc1[is + 1] * xdB) * (float)idot_b1
                  - (dminW1 * mn1[is + 1] * xdB) * (float)xisB;
        }
    }

    *out0 = sum0;
    *out1 = sum1;
}
2652+
25192653void * q4k_int_dot_worker (void * arg ) {
25202654 q4k_int_task_t * t = (q4k_int_task_t * )arg ;
25212655 const uint8x16_t mask_lo = vdupq_n_u8 (0x0F );
2522- for ( int d = t -> start_row ; d < t -> end_row ; d ++ ) {
2523- const block_q4_K * wblk = ( const block_q4_K * )(( const uint8_t * ) t -> weight + ( size_t ) d * t -> row_bytes );
2524- /* Prefetch the next row's weights while we're processing this one.
2525- * Each row is nb_super * 144 bytes; for typical Phi-3.5 (in_dim=3072,
2526- * 12 super-blocks) that's 1728 bytes. Prefetch 4 cache lines ahead. */
2527- if (d + 1 < t -> end_row ) {
2528- const uint8_t * next = (const uint8_t * )t -> weight + (size_t )(d + 1 ) * t -> row_bytes ;
2656+ int d = t -> start_row ;
2657+ /* 2-row inner loop while we have pairs */
2658+ for (; d + 1 < t -> end_row ; d += 2 ) {
2659+ const block_q4_K * wblk0 = ( const block_q4_K * )(( const uint8_t * ) t -> weight + ( size_t ) d * t -> row_bytes );
2660+ const block_q4_K * wblk1 = ( const block_q4_K * )(( const uint8_t * ) t -> weight + ( size_t )( d + 1 ) * t -> row_bytes );
2661+ if (d + 2 < t -> end_row ) {
2662+ const uint8_t * next = (const uint8_t * )t -> weight + (size_t )(d + 2 ) * t -> row_bytes ;
25292663 __builtin_prefetch (next + 0 , 0 , 0 );
25302664 __builtin_prefetch (next + 64 , 0 , 0 );
25312665 __builtin_prefetch (next + 128 , 0 , 0 );
25322666 __builtin_prefetch (next + 192 , 0 , 0 );
25332667 }
2668+ q4k_int_dot_two_rows (wblk0 , wblk1 , t -> x_qs , t -> x_ds , t -> x_isums ,
2669+ t -> nb_super , & t -> out [d ], & t -> out [d + 1 ]);
2670+ }
2671+ /* Tail: single row */
2672+ for (; d < t -> end_row ; d ++ ) {
2673+ const block_q4_K * wblk = (const block_q4_K * )((const uint8_t * )t -> weight + (size_t )d * t -> row_bytes );
25342674 float row_sum = 0.0f ;
25352675 for (int sb = 0 ; sb < t -> nb_super ; sb ++ ) {
25362676 const block_q4_K * blk = wblk + sb ;
0 commit comments