Skip to content

Commit 9a94e0a

Browse files
unamedkr and claude
committed
perf: Q5_K int8 dot + weight-row prefetching
Two more wins on top of the 66% Q4_K int8 jump:

1. **Q5_K int8 fused dot worker.** Was generic dequant-row + FP32 dot. Same pattern as Q4_K: vdotq_s32 over nibble + qh-bit unpacked to int8. Used by Qwen3.5-4B DeltaNet attn_qkv/attn_gate matmuls.

2. **Explicit __builtin_prefetch for the next weight row.** Q4_K/Q5_K/Q8_0 workers now prefetch the next row's first 256 bytes (4 cache lines) while processing the current row. The M1 Pro hardware prefetcher does not always pick up the row-stride pattern between matmul iterations.

vs llama.cpp on the same hardware (4-thread defaults retained):
- Phi-3.5 Q4_K_M: 3.2 → 5.9 tok/s (+85% session total)
- Qwen3.5-4B Q4_K_M: 3.5 → 5.7 tok/s (+63%)

11/11 STRICT+COHERENT+Metal-ON pass; 6/6 long-seq pass. No regression.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8f5784a commit 9a94e0a

1 file changed

Lines changed: 163 additions & 4 deletions

File tree

src/engine/tq_gguf_quants.c

Lines changed: 163 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2333,6 +2333,15 @@ void* q8_int_dot_worker(void* arg) {
23332333
q8_int_task_t* t = (q8_int_task_t*)arg;
23342334
for (int d = t->start_row; d < t->end_row; d++) {
23352335
const block_q8_0* wblk = (const block_q8_0*)((const uint8_t*)t->weight + (size_t)d * t->row_bytes);
2336+
if (d + 1 < t->end_row) {
2337+
const uint8_t* next = (const uint8_t*)t->weight + (size_t)(d + 1) * t->row_bytes;
2338+
/* Q8_0 row is n_blocks * 34 bytes — for hidden=3072, ~3.4 KB.
2339+
* Prefetch first 4 cache lines; HW prefetcher takes the rest. */
2340+
__builtin_prefetch(next + 0, 0, 0);
2341+
__builtin_prefetch(next + 64, 0, 0);
2342+
__builtin_prefetch(next + 128, 0, 0);
2343+
__builtin_prefetch(next + 192, 0, 0);
2344+
}
23362345
float row_sum = 0.0f;
23372346
for (int b = 0; b < t->n_blocks; b++) {
23382347
const float wd = fp16_to_fp32(wblk[b].d);
@@ -2369,11 +2378,159 @@ typedef struct {
23692378
const int32_t* x_isums; size_t row_bytes; int nb_super; int start_row; int end_row;
23702379
} q4k_int_task_t;
23712380

2381+
/* Q5_K int8 dot worker — same pattern as Q4_K, plus 5th bit from qh.
 * The qh array has 1 bit per element across 8 sub-blocks; bit position
 * shifts by 2 per j-iteration (u1 = 1<<(2*iter), u2 = 2<<(2*iter)).
 *
 * Computes out[d] = dot(row d of Q5_K weights, int8-quantized x) for
 * rows [start_row, end_row). Relies on x having been pre-quantized to
 * int8 per 32-element sub-block (t->x_qs), with per-sub-block scales
 * (t->x_ds) and int element sums (t->x_isums) for the dmin correction.
 * Task layout is identical to Q4_K, so the struct is reused. */
typedef q4k_int_task_t q5k_int_task_t;

void* q5k_int_dot_worker(void* arg) {
    q5k_int_task_t* t = (q5k_int_task_t*)arg;
    const uint8x16_t mask_lo = vdupq_n_u8(0x0F);
    for (int d = t->start_row; d < t->end_row; d++) {
        const block_q5_K* wblk = (const block_q5_K*)((const uint8_t*)t->weight + (size_t)d * t->row_bytes);
        /* Prefetch the first 4 cache lines of the next weight row while
         * this one is being processed (read, no temporal locality hint). */
        if (d + 1 < t->end_row) {
            const uint8_t* next = (const uint8_t*)t->weight + (size_t)(d + 1) * t->row_bytes;
            __builtin_prefetch(next + 0, 0, 0);
            __builtin_prefetch(next + 64, 0, 0);
            __builtin_prefetch(next + 128, 0, 0);
            __builtin_prefetch(next + 192, 0, 0);
        }
        float row_sum = 0.0f;
        for (int sb = 0; sb < t->nb_super; sb++) {
            const block_q5_K* blk = wblk + sb;
            const float dW = fp16_to_fp32(blk->d);
            const float dminW = fp16_to_fp32(blk->dmin);

            /* Unpack the 12-byte packed 6-bit scales/mins (K-quant layout):
             * entries 0..3 sit in the low 6 bits of scales[0..3] (sc) and
             * scales[4..7] (mn); entries 4..7 are split — low/high nibble
             * of scales[8..11] plus the top 2 bits of scales[0..7]. */
            uint8_t sc[8], mn[8];
            sc[0] = blk->scales[0] & 63;
            sc[1] = blk->scales[1] & 63;
            sc[2] = blk->scales[2] & 63;
            sc[3] = blk->scales[3] & 63;
            mn[0] = blk->scales[4] & 63;
            mn[1] = blk->scales[5] & 63;
            mn[2] = blk->scales[6] & 63;
            mn[3] = blk->scales[7] & 63;
            sc[4] = (blk->scales[8] & 0x0F) | ((blk->scales[0] >> 6) << 4);
            sc[5] = (blk->scales[9] & 0x0F) | ((blk->scales[1] >> 6) << 4);
            sc[6] = (blk->scales[10] & 0x0F) | ((blk->scales[2] >> 6) << 4);
            sc[7] = (blk->scales[11] & 0x0F) | ((blk->scales[3] >> 6) << 4);
            mn[4] = (blk->scales[8] >> 4) | ((blk->scales[4] >> 6) << 4);
            mn[5] = (blk->scales[9] >> 4) | ((blk->scales[5] >> 6) << 4);
            mn[6] = (blk->scales[10] >> 4) | ((blk->scales[6] >> 6) << 4);
            mn[7] = (blk->scales[11] >> 4) | ((blk->scales[7] >> 6) << 4);

            const uint8_t* ql = blk->qs; /* 128 bytes of low nibbles */
            const uint8_t* qh = blk->qh; /* 32 bytes of high bits */
            int sub_base = sb * 8;       /* first 32-element sub-block index of this super-block */
            int is = 0;                  /* sub-block index within the super-block (0..7) */
            uint8_t u1 = 1, u2 = 2;      /* qh bit masks for sub-blocks A and B of this j-step */

            /* Each 64-element step handles two 32-element sub-blocks:
             * A = low nibbles of ql[0..31], B = high nibbles of ql[0..31]. */
            for (int j = 0; j < 256; j += 64) {
                int sub_idx_a = sub_base + is;
                int sub_idx_b = sub_base + is + 1;

                /* Low 4 bits of weights */
                uint8x16_t qa = vld1q_u8(ql);
                uint8x16_t qb = vld1q_u8(ql + 16);
                uint8x16_t lo_a = vandq_u8(qa, mask_lo);
                uint8x16_t lo_b = vandq_u8(qb, mask_lo);
                uint8x16_t hi_a = vshrq_n_u8(qa, 4);
                uint8x16_t hi_b = vshrq_n_u8(qb, 4);

                /* 5th bit from qh: extract bit u1 (sub-block A) and u2 (B).
                 * qh[0..15] covers elements 0..15, qh[16..31] covers 16..31.
                 * For sub-block A (low nibbles), bit position = log2(u1).
                 * Convert "bit set" → byte value 16 by shifting to bit 4. */
                uint8x16_t qh_a = vld1q_u8(qh);
                uint8x16_t qh_b = vld1q_u8(qh + 16);
                uint8x16_t u1v = vdupq_n_u8(u1);
                uint8x16_t u2v = vdupq_n_u8(u2);
                /* Test bit, then convert to 0 or 16:
                 * masked_a = qh & u1 → 0 or u1 (in {1,4,16,64})
                 * want bit 4 set when u1 bit set → multiply masked_a by (16/u1)
                 * For u1 in {1,4,16,64}: 16/u1 in {16,4,1,1/4}.
                 * Easier: vceqq + select between 0 and 16. */
                uint8x16_t bit_a_lo = vceqq_u8(vandq_u8(qh_a, u1v), u1v); /* 0xFF or 0x00 */
                uint8x16_t bit_a_hi = vceqq_u8(vandq_u8(qh_b, u1v), u1v);
                uint8x16_t bit_b_lo = vceqq_u8(vandq_u8(qh_a, u2v), u2v);
                uint8x16_t bit_b_hi = vceqq_u8(vandq_u8(qh_b, u2v), u2v);
                uint8x16_t v16 = vdupq_n_u8(16);
                /* And with 16 to get 0 or 16 */
                lo_a = vorrq_u8(lo_a, vandq_u8(bit_a_lo, v16));
                lo_b = vorrq_u8(lo_b, vandq_u8(bit_a_hi, v16));
                hi_a = vorrq_u8(hi_a, vandq_u8(bit_b_lo, v16));
                hi_b = vorrq_u8(hi_b, vandq_u8(bit_b_hi, v16));

                /* Values are 0..31, so the reinterpret to signed is lossless
                 * and safe for the signed dot below. */
                int8x16_t wa_lo = vreinterpretq_s8_u8(lo_a);
                int8x16_t wa_hi = vreinterpretq_s8_u8(lo_b);
                int8x16_t wb_lo = vreinterpretq_s8_u8(hi_a);
                int8x16_t wb_hi = vreinterpretq_s8_u8(hi_b);

                const int8_t* xa = t->x_qs + (size_t)sub_idx_a * 32;
                int8x16_t xa_lo = vld1q_s8(xa);
                int8x16_t xa_hi = vld1q_s8(xa + 16);
                const int8_t* xb = t->x_qs + (size_t)sub_idx_b * 32;
                int8x16_t xb_lo = vld1q_s8(xb);
                int8x16_t xb_hi = vld1q_s8(xb + 16);

#ifdef __ARM_FEATURE_DOTPROD
                int32x4_t accA = vdotq_s32(vdupq_n_s32(0), wa_lo, xa_lo);
                accA = vdotq_s32(accA, wa_hi, xa_hi);
                int32_t isumA = vaddvq_s32(accA);
                int32x4_t accB = vdotq_s32(vdupq_n_s32(0), wb_lo, xb_lo);
                accB = vdotq_s32(accB, wb_hi, xb_hi);
                int32_t isumB = vaddvq_s32(accB);
#else
                /* Fallback: widen i8*i8 → i16 products, pairwise-accumulate
                 * into i32 lanes, then horizontal add. */
                int32x4_t accA = vpadalq_s16(vdupq_n_s32(0),
                    vmull_s8(vget_low_s8(wa_lo), vget_low_s8(xa_lo)));
                accA = vpadalq_s16(accA, vmull_s8(vget_high_s8(wa_lo), vget_high_s8(xa_lo)));
                accA = vpadalq_s16(accA, vmull_s8(vget_low_s8(wa_hi), vget_low_s8(xa_hi)));
                accA = vpadalq_s16(accA, vmull_s8(vget_high_s8(wa_hi), vget_high_s8(xa_hi)));
                int32_t isumA = vaddvq_s32(accA);
                int32x4_t accB = vpadalq_s16(vdupq_n_s32(0),
                    vmull_s8(vget_low_s8(wb_lo), vget_low_s8(xb_lo)));
                accB = vpadalq_s16(accB, vmull_s8(vget_high_s8(wb_lo), vget_high_s8(xb_lo)));
                accB = vpadalq_s16(accB, vmull_s8(vget_low_s8(wb_hi), vget_low_s8(xb_hi)));
                accB = vpadalq_s16(accB, vmull_s8(vget_high_s8(wb_hi), vget_high_s8(xb_hi)));
                int32_t isumB = vaddvq_s32(accB);
#endif

                float xdA = t->x_ds[sub_idx_a];
                float xdB = t->x_ds[sub_idx_b];
                int32_t xisA = t->x_isums[sub_idx_a];
                int32_t xisB = t->x_isums[sub_idx_b];

                /* y = d*sc*q - dmin*mn per element; with x ≈ xd*xq the
                 * sub-block dot is d*sc*xd*isum - dmin*mn*xd*(sum of xq). */
                row_sum += (dW * sc[is + 0] * xdA) * (float)isumA
                         - (dminW * mn[is + 0] * xdA) * (float)xisA;
                row_sum += (dW * sc[is + 1] * xdB) * (float)isumB
                         - (dminW * mn[is + 1] * xdB) * (float)xisB;

                ql += 32;
                is += 2;
                u1 <<= 2;
                u2 <<= 2;
            }
        }
        t->out[d] = row_sum;
    }
    return NULL;
}
2518+
23722519
void* q4k_int_dot_worker(void* arg) {
23732520
q4k_int_task_t* t = (q4k_int_task_t*)arg;
23742521
const uint8x16_t mask_lo = vdupq_n_u8(0x0F);
23752522
for (int d = t->start_row; d < t->end_row; d++) {
23762523
const block_q4_K* wblk = (const block_q4_K*)((const uint8_t*)t->weight + (size_t)d * t->row_bytes);
2524+
/* Prefetch the next row's weights while we're processing this one.
2525+
* Each row is nb_super * 144 bytes; for typical Phi-3.5 (in_dim=3072,
2526+
* 12 super-blocks) that's 1728 bytes. Prefetch 4 cache lines ahead. */
2527+
if (d + 1 < t->end_row) {
2528+
const uint8_t* next = (const uint8_t*)t->weight + (size_t)(d + 1) * t->row_bytes;
2529+
__builtin_prefetch(next + 0, 0, 0);
2530+
__builtin_prefetch(next + 64, 0, 0);
2531+
__builtin_prefetch(next + 128, 0, 0);
2532+
__builtin_prefetch(next + 192, 0, 0);
2533+
}
23772534
float row_sum = 0.0f;
23782535
for (int sb = 0; sb < t->nb_super; sb++) {
23792536
const block_q4_K* blk = wblk + sb;
@@ -2690,8 +2847,8 @@ void tq_matmul_gguf(float* out, const float* x,
26902847
* precompute per-block int sums for the dmin*mn correction, then
26912848
* run vmull_s8 + vpadalq_s16 dots over 4-bit nibbles unpacked to int8.
26922849
* Replaces the float fused_dot_q4_k path on Phi-3.5/Llama Q4_K_M models. */
2693-
if (weight_type == TQ_GGML_TYPE_Q4_K && in_dim >= 256 && in_dim <= 16384
2694-
&& (in_dim % 256) == 0)
2850+
if ((weight_type == TQ_GGML_TYPE_Q4_K || weight_type == TQ_GGML_TYPE_Q5_K)
2851+
&& in_dim >= 256 && in_dim <= 16384 && (in_dim % 256) == 0)
26952852
{
26962853
/* Stack buffers: x as int8 (16KB), per-block scales (512 floats =
26972854
* 2KB), per-block int sums (512 ints = 2KB). Total ~20KB. */
@@ -2746,11 +2903,13 @@ void tq_matmul_gguf(float* out, const float* x,
27462903
tasks[t].end_row = (t == n_threads - 1) ? out_dim : (t + 1) * rows_per;
27472904
ptrs[t] = &tasks[t];
27482905
}
2906+
void* (*worker)(void*) = (weight_type == TQ_GGML_TYPE_Q5_K)
2907+
? q5k_int_dot_worker : q4k_int_dot_worker;
27492908
if (n_threads == 1) {
2750-
q4k_int_dot_worker(ptrs[0]);
2909+
worker(ptrs[0]);
27512910
} else {
27522911
extern void tq_tp_run(void* (*fn)(void*), void** args, int n);
2753-
tq_tp_run(q4k_int_dot_worker, ptrs, n_threads);
2912+
tq_tp_run(worker, ptrs, n_threads);
27542913
}
27552914
return;
27562915
}

0 commit comments

Comments
 (0)