Skip to content

Commit aadd059

Browse files
unamedkr authored and claude committed
perf(q4k): 2-row ILP in q4k_int_dot_worker
Process two output rows in parallel within the same inner loop. Both rows share x_qs/x_ds/x_isums (read-only activation), so pairing them hides weight-load latency: while one row's vdotq_s32 is in flight the other's nibble unpack and activation broadcast can issue. Phi-3.5 Q4_K_M: 5.9 → 6.2 tok/s (+5%, total session +94% vs 3.2). Quality: 11/11 STRICT+COHERENT+Metal-ON pass. Identical generation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9a94e0a commit aadd059

1 file changed

Lines changed: 147 additions & 7 deletions

File tree

src/engine/tq_gguf_quants.c

Lines changed: 147 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2516,21 +2516,161 @@ void* q5k_int_dot_worker(void* arg) {
25162516
return NULL;
25172517
}
25182518

2519+
/* 2-row inner loop helper for q4k_int worker — processes 2 output rows
2520+
* in parallel for instruction-level parallelism. The two rows share the
2521+
* same x_qs / x_ds / x_isums (read-only activation), but have different
2522+
* weight rows. Parallelism hides load latency on the weight reads. */
2523+
static inline void q4k_int_dot_two_rows(
2524+
const block_q4_K* blk0, const block_q4_K* blk1,
2525+
const int8_t* x_qs, const float* x_ds, const int32_t* x_isums,
2526+
int nb_super, float* out0, float* out1)
2527+
{
2528+
const uint8x16_t mask_lo = vdupq_n_u8(0x0F);
2529+
float sum0 = 0.0f, sum1 = 0.0f;
2530+
for (int sb = 0; sb < nb_super; sb++) {
2531+
const block_q4_K* b0 = blk0 + sb;
2532+
const block_q4_K* b1 = blk1 + sb;
2533+
const float dW0 = fp16_to_fp32(b0->d);
2534+
const float dminW0 = fp16_to_fp32(b0->dmin);
2535+
const float dW1 = fp16_to_fp32(b1->d);
2536+
const float dminW1 = fp16_to_fp32(b1->dmin);
2537+
2538+
uint8_t sc0[8], mn0[8], sc1[8], mn1[8];
2539+
#define UNPACK(SC, MN, BLK) \
2540+
do { \
2541+
SC[0] = BLK->scales[0] & 63; \
2542+
SC[1] = BLK->scales[1] & 63; \
2543+
SC[2] = BLK->scales[2] & 63; \
2544+
SC[3] = BLK->scales[3] & 63; \
2545+
MN[0] = BLK->scales[4] & 63; \
2546+
MN[1] = BLK->scales[5] & 63; \
2547+
MN[2] = BLK->scales[6] & 63; \
2548+
MN[3] = BLK->scales[7] & 63; \
2549+
SC[4] = (BLK->scales[8] & 0x0F) | ((BLK->scales[0] >> 6) << 4); \
2550+
SC[5] = (BLK->scales[9] & 0x0F) | ((BLK->scales[1] >> 6) << 4); \
2551+
SC[6] = (BLK->scales[10] & 0x0F) | ((BLK->scales[2] >> 6) << 4); \
2552+
SC[7] = (BLK->scales[11] & 0x0F) | ((BLK->scales[3] >> 6) << 4); \
2553+
MN[4] = (BLK->scales[8] >> 4) | ((BLK->scales[4] >> 6) << 4); \
2554+
MN[5] = (BLK->scales[9] >> 4) | ((BLK->scales[5] >> 6) << 4); \
2555+
MN[6] = (BLK->scales[10] >> 4) | ((BLK->scales[6] >> 6) << 4); \
2556+
MN[7] = (BLK->scales[11] >> 4) | ((BLK->scales[7] >> 6) << 4); \
2557+
} while (0)
2558+
UNPACK(sc0, mn0, b0);
2559+
UNPACK(sc1, mn1, b1);
2560+
#undef UNPACK
2561+
2562+
const uint8_t* q0 = b0->qs;
2563+
const uint8_t* q1 = b1->qs;
2564+
int sub_base = sb * 8;
2565+
int is = 0;
2566+
2567+
for (int j = 0; j < 256; j += 64) {
2568+
int sub_idx_a = sub_base + is;
2569+
int sub_idx_b = sub_base + is + 1;
2570+
2571+
/* Load both rows' nibbles in parallel */
2572+
uint8x16_t qa0 = vld1q_u8(q0);
2573+
uint8x16_t qb0 = vld1q_u8(q0 + 16);
2574+
uint8x16_t qa1 = vld1q_u8(q1);
2575+
uint8x16_t qb1 = vld1q_u8(q1 + 16);
2576+
int8x16_t wa_lo0 = vreinterpretq_s8_u8(vandq_u8(qa0, mask_lo));
2577+
int8x16_t wa_hi0 = vreinterpretq_s8_u8(vandq_u8(qb0, mask_lo));
2578+
int8x16_t wb_lo0 = vreinterpretq_s8_u8(vshrq_n_u8(qa0, 4));
2579+
int8x16_t wb_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(qb0, 4));
2580+
int8x16_t wa_lo1 = vreinterpretq_s8_u8(vandq_u8(qa1, mask_lo));
2581+
int8x16_t wa_hi1 = vreinterpretq_s8_u8(vandq_u8(qb1, mask_lo));
2582+
int8x16_t wb_lo1 = vreinterpretq_s8_u8(vshrq_n_u8(qa1, 4));
2583+
int8x16_t wb_hi1 = vreinterpretq_s8_u8(vshrq_n_u8(qb1, 4));
2584+
2585+
const int8_t* xa = x_qs + (size_t)sub_idx_a * 32;
2586+
int8x16_t xa_lo = vld1q_s8(xa);
2587+
int8x16_t xa_hi = vld1q_s8(xa + 16);
2588+
const int8_t* xb = x_qs + (size_t)sub_idx_b * 32;
2589+
int8x16_t xb_lo = vld1q_s8(xb);
2590+
int8x16_t xb_hi = vld1q_s8(xb + 16);
2591+
2592+
#ifdef __ARM_FEATURE_DOTPROD
2593+
int32x4_t accA0 = vdotq_s32(vdupq_n_s32(0), wa_lo0, xa_lo);
2594+
accA0 = vdotq_s32(accA0, wa_hi0, xa_hi);
2595+
int32x4_t accA1 = vdotq_s32(vdupq_n_s32(0), wa_lo1, xa_lo);
2596+
accA1 = vdotq_s32(accA1, wa_hi1, xa_hi);
2597+
int32x4_t accB0 = vdotq_s32(vdupq_n_s32(0), wb_lo0, xb_lo);
2598+
accB0 = vdotq_s32(accB0, wb_hi0, xb_hi);
2599+
int32x4_t accB1 = vdotq_s32(vdupq_n_s32(0), wb_lo1, xb_lo);
2600+
accB1 = vdotq_s32(accB1, wb_hi1, xb_hi);
2601+
int32_t isumA0 = vaddvq_s32(accA0);
2602+
int32_t isumA1 = vaddvq_s32(accA1);
2603+
int32_t isumB0 = vaddvq_s32(accB0);
2604+
int32_t isumB1 = vaddvq_s32(accB1);
2605+
#else
2606+
/* Fallback: same as single-row but doubled; relies on compiler
2607+
* to find ILP. Already a win on M1 if vmull/vpadalq saturate. */
2608+
int32x4_t accA0 = vpadalq_s16(vdupq_n_s32(0), vmull_s8(vget_low_s8(wa_lo0), vget_low_s8(xa_lo)));
2609+
accA0 = vpadalq_s16(accA0, vmull_s8(vget_high_s8(wa_lo0), vget_high_s8(xa_lo)));
2610+
accA0 = vpadalq_s16(accA0, vmull_s8(vget_low_s8(wa_hi0), vget_low_s8(xa_hi)));
2611+
accA0 = vpadalq_s16(accA0, vmull_s8(vget_high_s8(wa_hi0), vget_high_s8(xa_hi)));
2612+
int32x4_t accA1 = vpadalq_s16(vdupq_n_s32(0), vmull_s8(vget_low_s8(wa_lo1), vget_low_s8(xa_lo)));
2613+
accA1 = vpadalq_s16(accA1, vmull_s8(vget_high_s8(wa_lo1), vget_high_s8(xa_lo)));
2614+
accA1 = vpadalq_s16(accA1, vmull_s8(vget_low_s8(wa_hi1), vget_low_s8(xa_hi)));
2615+
accA1 = vpadalq_s16(accA1, vmull_s8(vget_high_s8(wa_hi1), vget_high_s8(xa_hi)));
2616+
int32x4_t accB0 = vpadalq_s16(vdupq_n_s32(0), vmull_s8(vget_low_s8(wb_lo0), vget_low_s8(xb_lo)));
2617+
accB0 = vpadalq_s16(accB0, vmull_s8(vget_high_s8(wb_lo0), vget_high_s8(xb_lo)));
2618+
accB0 = vpadalq_s16(accB0, vmull_s8(vget_low_s8(wb_hi0), vget_low_s8(xb_hi)));
2619+
accB0 = vpadalq_s16(accB0, vmull_s8(vget_high_s8(wb_hi0), vget_high_s8(xb_hi)));
2620+
int32x4_t accB1 = vpadalq_s16(vdupq_n_s32(0), vmull_s8(vget_low_s8(wb_lo1), vget_low_s8(xb_lo)));
2621+
accB1 = vpadalq_s16(accB1, vmull_s8(vget_high_s8(wb_lo1), vget_high_s8(xb_lo)));
2622+
accB1 = vpadalq_s16(accB1, vmull_s8(vget_low_s8(wb_hi1), vget_low_s8(xb_hi)));
2623+
accB1 = vpadalq_s16(accB1, vmull_s8(vget_high_s8(wb_hi1), vget_high_s8(xb_hi)));
2624+
int32_t isumA0 = vaddvq_s32(accA0);
2625+
int32_t isumA1 = vaddvq_s32(accA1);
2626+
int32_t isumB0 = vaddvq_s32(accB0);
2627+
int32_t isumB1 = vaddvq_s32(accB1);
2628+
#endif
2629+
2630+
float xdA = x_ds[sub_idx_a];
2631+
float xdB = x_ds[sub_idx_b];
2632+
int32_t xisA = x_isums[sub_idx_a];
2633+
int32_t xisB = x_isums[sub_idx_b];
2634+
2635+
sum0 += (dW0 * sc0[is + 0] * xdA) * (float)isumA0
2636+
- (dminW0 * mn0[is + 0] * xdA) * (float)xisA;
2637+
sum0 += (dW0 * sc0[is + 1] * xdB) * (float)isumB0
2638+
- (dminW0 * mn0[is + 1] * xdB) * (float)xisB;
2639+
sum1 += (dW1 * sc1[is + 0] * xdA) * (float)isumA1
2640+
- (dminW1 * mn1[is + 0] * xdA) * (float)xisA;
2641+
sum1 += (dW1 * sc1[is + 1] * xdB) * (float)isumB1
2642+
- (dminW1 * mn1[is + 1] * xdB) * (float)xisB;
2643+
2644+
q0 += 32;
2645+
q1 += 32;
2646+
is += 2;
2647+
}
2648+
}
2649+
*out0 = sum0;
2650+
*out1 = sum1;
2651+
}
2652+
25192653
void* q4k_int_dot_worker(void* arg) {
25202654
q4k_int_task_t* t = (q4k_int_task_t*)arg;
25212655
const uint8x16_t mask_lo = vdupq_n_u8(0x0F);
2522-
for (int d = t->start_row; d < t->end_row; d++) {
2523-
const block_q4_K* wblk = (const block_q4_K*)((const uint8_t*)t->weight + (size_t)d * t->row_bytes);
2524-
/* Prefetch the next row's weights while we're processing this one.
2525-
* Each row is nb_super * 144 bytes; for typical Phi-3.5 (in_dim=3072,
2526-
* 12 super-blocks) that's 1728 bytes. Prefetch 4 cache lines ahead. */
2527-
if (d + 1 < t->end_row) {
2528-
const uint8_t* next = (const uint8_t*)t->weight + (size_t)(d + 1) * t->row_bytes;
2656+
int d = t->start_row;
2657+
/* 2-row inner loop while we have pairs */
2658+
for (; d + 1 < t->end_row; d += 2) {
2659+
const block_q4_K* wblk0 = (const block_q4_K*)((const uint8_t*)t->weight + (size_t)d * t->row_bytes);
2660+
const block_q4_K* wblk1 = (const block_q4_K*)((const uint8_t*)t->weight + (size_t)(d + 1) * t->row_bytes);
2661+
if (d + 2 < t->end_row) {
2662+
const uint8_t* next = (const uint8_t*)t->weight + (size_t)(d + 2) * t->row_bytes;
25292663
__builtin_prefetch(next + 0, 0, 0);
25302664
__builtin_prefetch(next + 64, 0, 0);
25312665
__builtin_prefetch(next + 128, 0, 0);
25322666
__builtin_prefetch(next + 192, 0, 0);
25332667
}
2668+
q4k_int_dot_two_rows(wblk0, wblk1, t->x_qs, t->x_ds, t->x_isums,
2669+
t->nb_super, &t->out[d], &t->out[d + 1]);
2670+
}
2671+
/* Tail: single row */
2672+
for (; d < t->end_row; d++) {
2673+
const block_q4_K* wblk = (const block_q4_K*)((const uint8_t*)t->weight + (size_t)d * t->row_bytes);
25342674
float row_sum = 0.0f;
25352675
for (int sb = 0; sb < t->nb_super; sb++) {
25362676
const block_q4_K* blk = wblk + sb;

0 commit comments

Comments
 (0)