Skip to content

Commit d456c39

Browse files
committed
perf(cuda): separated GPU layout for Q5_0 GEMV
Q5_0 is the dominant weight type in Gemma3-1B Q4_K_M (117 of 170 weight tensors). The interleaved 22-byte block layout required byte-wise __ldg loads on ARM64 Grace Hopper (blocks not 4-byte aligned after block 0). This caused ~40% throughput regression. Introduce a separated GPU layout (scales | qh | qs) where each region is naturally aligned. The GEMV kernel now reads fp16 scales at 2-byte boundaries and uint32 qh at 4-byte boundaries with a single __ldg instruction each, instead of byte-at-a-time reconstruction of the 2- and 4-byte fields. Also add RawBytesGPU() to Q5_0Storage for the separated layout, matching the pattern used by Q4Storage.
1 parent dc63c8f commit d456c39

14 files changed

Lines changed: 132 additions & 70 deletions

compute/gpu_engine.go

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -393,13 +393,14 @@ func (e *GPUEngine[T]) UploadWeights(tensors []*tensor.TensorNumeric[float32]) e
393393
q4Uploaded++
394394
continue
395395
}
396-
// Upload Q5_0 raw bytes to GPU for fused GEMV kernel.
397-
// Q5_0 blocks (22 bytes per 32 values) are uploaded contiguously.
396+
// Upload Q5_0 in separated GPU layout (scales | qh | qs) for fast GEMV.
397+
// The separated layout aligns fp16 and uint32 fields naturally, avoiding
398+
// byte-wise loads on ARM64 Grace Hopper.
398399
if qs, ok := any(t.GetStorage()).(*tensor.Q5_0Storage); ok {
399400
if ptr, _, _ := qs.GPUPtr(); ptr != nil {
400401
continue // already on GPU
401402
}
402-
rawBytes := qs.RawBytes()
403+
rawBytes := qs.RawBytesGPU()
403404
devPtr, err := e.allocWeight(len(rawBytes))
404405
if err != nil {
405406
return fmt.Errorf("alloc Q5_0 GPU (shape %v): %w", t.Shape(), err)
@@ -2133,11 +2134,12 @@ func (e *GPUEngine[T]) matMulQ5_0(ctx context.Context, qs *tensor.Q5_0Storage, a
21332134

21342135
var devW unsafe.Pointer
21352136
var freeW func()
2137+
nBlocks := qs.NumBlocks()
21362138
if ptr, _, _ := qs.GPUPtr(); ptr != nil {
21372139
devW = ptr
21382140
freeW = func() {}
21392141
} else {
2140-
rawBytes := qs.RawBytes()
2142+
rawBytes := qs.RawBytesGPU()
21412143
var err error
21422144
devW, err = e.pool.Alloc(e.deviceID, len(rawBytes))
21432145
if err != nil {
@@ -2151,6 +2153,9 @@ func (e *GPUEngine[T]) matMulQ5_0(ctx context.Context, qs *tensor.Q5_0Storage, a
21512153
}
21522154
defer freeW()
21532155

2156+
qhOff := tensor.Q5_0GPUQhOffset(nBlocks)
2157+
qsOff := tensor.Q5_0GPUQsOffset(nBlocks)
2158+
21542159
if n == 1 {
21552160
devX, cleanupX, err := getDevicePtr(e, b)
21562161
if err != nil {
@@ -2165,7 +2170,7 @@ func (e *GPUEngine[T]) matMulQ5_0(ctx context.Context, qs *tensor.Q5_0Storage, a
21652170
return e.cpu.MatMul(ctx, a, b, dst...)
21662171
}
21672172

2168-
if err := e.kernels.GemvQ5_0F32(devW, devX, devY, m, k, e.stream); err != nil {
2173+
if err := e.kernels.GemvQ5_0F32(devW, devX, devY, m, k, qhOff, qsOff, e.stream); err != nil {
21692174
e.pool.Free(e.deviceID, devY, cSize)
21702175
return e.cpu.MatMul(ctx, a, b, dst...)
21712176
}
@@ -2240,11 +2245,12 @@ func (e *GPUEngine[T]) matMulQ5_0BWeight(ctx context.Context, a *tensor.TensorNu
22402245

22412246
var devQ5_0 unsafe.Pointer
22422247
var freeQ5_0 func()
2248+
nBlocks := qs.NumBlocks()
22432249
if ptr, _, _ := qs.GPUPtr(); ptr != nil {
22442250
devQ5_0 = ptr
22452251
freeQ5_0 = func() {}
22462252
} else {
2247-
rawBytes := qs.RawBytes()
2253+
rawBytes := qs.RawBytesGPU()
22482254
var err error
22492255
devQ5_0, err = e.pool.Alloc(e.deviceID, len(rawBytes))
22502256
if err != nil {
@@ -2258,6 +2264,9 @@ func (e *GPUEngine[T]) matMulQ5_0BWeight(ctx context.Context, a *tensor.TensorNu
22582264
}
22592265
defer freeQ5_0()
22602266

2267+
qhOff := tensor.Q5_0GPUQhOffset(nBlocks)
2268+
qsOff := tensor.Q5_0GPUQsOffset(nBlocks)
2269+
22612270
if m == 1 {
22622271
devX, cleanupX, err := getDevicePtr(e, a)
22632272
if err != nil {
@@ -2272,7 +2281,7 @@ func (e *GPUEngine[T]) matMulQ5_0BWeight(ctx context.Context, a *tensor.TensorNu
22722281
return e.cpu.MatMul(ctx, a, b, dst...)
22732282
}
22742283

2275-
if err := e.kernels.GemvQ5_0F32(devQ5_0, devX, devY, n, k, e.stream); err != nil {
2284+
if err := e.kernels.GemvQ5_0F32(devQ5_0, devX, devY, n, k, qhOff, qsOff, e.stream); err != nil {
22762285
e.pool.Free(e.deviceID, devY, cSize)
22772286
return e.cpu.MatMul(ctx, a, b, dst...)
22782287
}

internal/cuda/kernels/gemv_q5_0.cu

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
/* Q5_0 fused dequant-GEMV kernel for single-token decode (batch=1).
22
*
3-
* Reads Q5_0 blocks directly, dequantizes in registers (no global
4-
* memory intermediary), multiplies by the activation vector, and accumulates
5-
* in FP32. This halves memory traffic compared to separate dequant + GEMV.
3+
* GPU-optimized SEPARATED layout (from Q5_0Storage.RawBytesGPU):
4+
* Region 1: [nBlocks * 2 bytes] fp16 scales, padded to 16-byte boundary
5+
* Region 2: [nBlocks * 4 bytes] uint32 qh values, padded to 16-byte boundary
6+
* Region 3: [nBlocks * 16 bytes] packed nibbles (qs)
67
*
7-
* Q5_0 block (22 bytes, 32 values):
8-
* [0:2] fp16 d -- block scale
9-
* [2:6] uint32 qh -- 32 high bits (one per element)
10-
* [6:22] 16 bytes qs -- packed nibbles (two 4-bit values per byte)
8+
* This layout ensures natural alignment: fp16 at 2-byte, uint32 at 4-byte.
9+
* Eliminates the byte-wise loads required for the interleaved 22-byte layout
10+
* on ARM64 Grace Hopper.
1111
*
1212
* Dequantization (matching llama.cpp dequantize_row_q5_0):
1313
* For j in 0..15:
@@ -25,26 +25,22 @@
2525
#include <stdint.h>
2626

2727
#define Q5_0_BLOCK_SIZE 32
28-
#define Q5_0_BLOCK_BYTES 22
2928
#define Q5_0_WARPS_PER_BLOCK 4
3029
#define Q5_0_WARP_SIZE 32
3130

32-
/* ---------- Fused GEMV kernel ----------
31+
/* ---------- Fused GEMV kernel (separated GPU layout) ----------
3332
*
3433
* y[row] = sum_k dequant(W_q5_0[row, k]) * x[k]
3534
*
36-
* Strategy:
37-
* - Load input vector x into shared memory (all threads cooperate).
38-
* - One warp per row for simplicity and good occupancy.
39-
* - Each lane processes a strided subset of blocks.
40-
* - Within each block, 16 packed bytes yield 32 dequantized values.
41-
* - Warp shuffle reduction produces the final dot product.
35+
* W_q5_0 points to the separated layout base. qhOffset and qsOffset
36+
* are byte offsets to the qh and qs regions respectively.
4237
*/
4338
__global__ void gemv_q5_0_kernel(
4439
const uint8_t* __restrict__ W_q5_0,
4540
const float* __restrict__ x,
4641
float* __restrict__ y,
47-
int M, int K)
42+
int M, int K,
43+
int qhOffset, int qsOffset)
4844
{
4945
extern __shared__ float sx[];
5046

@@ -62,27 +58,20 @@ __global__ void gemv_q5_0_kernel(
6258
if (row >= M) return;
6359

6460
int blocks_per_row = K / Q5_0_BLOCK_SIZE;
65-
const uint8_t* row_data = W_q5_0 + (size_t)row * blocks_per_row * Q5_0_BLOCK_BYTES;
61+
62+
/* Pointers to the three separated regions for this row. */
63+
const __half* row_scales = (const __half*)(W_q5_0 + row * blocks_per_row * 2);
64+
const uint32_t* row_qh = (const uint32_t*)(W_q5_0 + qhOffset + row * blocks_per_row * 4);
65+
const uint8_t* row_qs = W_q5_0 + qsOffset + (size_t)row * blocks_per_row * 16;
6666

6767
float acc = 0.0f;
6868

6969
/* Each lane handles a strided subset of blocks. */
7070
for (int bi = lane_id; bi < blocks_per_row; bi += Q5_0_WARP_SIZE) {
71-
const uint8_t* blk = row_data + bi * Q5_0_BLOCK_BYTES;
72-
73-
/* Read fp16 d using byte-wise load (ARM64 alignment safety).
74-
* Q5_0 blocks are 22 bytes — not a multiple of 4, so blk may
75-
* be misaligned for uint16/uint32 casts after the first block. */
76-
uint16_t d_bits = (uint16_t)__ldg(&blk[0]) | ((uint16_t)__ldg(&blk[1]) << 8);
77-
float d = __half2float(*reinterpret_cast<const __half*>(&d_bits));
78-
79-
/* Read qh (32 high bits) using byte-wise load. */
80-
uint32_t qh = (uint32_t)__ldg(&blk[2])
81-
| ((uint32_t)__ldg(&blk[3]) << 8)
82-
| ((uint32_t)__ldg(&blk[4]) << 16)
83-
| ((uint32_t)__ldg(&blk[5]) << 24);
84-
85-
const uint8_t* qs = blk + 6;
71+
/* All loads are naturally aligned in the separated layout. */
72+
float d = __half2float(__ldg(&row_scales[bi]));
73+
uint32_t qh = __ldg(&row_qh[bi]);
74+
const uint8_t* qs = row_qs + bi * 16;
8675
int k_base = bi * Q5_0_BLOCK_SIZE;
8776

8877
/* Process 16 packed bytes -> 32 dequantized values. */
@@ -119,6 +108,7 @@ __global__ void gemv_q5_0_kernel(
119108
extern "C" cudaError_t gemv_q5_0_f32(
120109
const void* W_q5_0, const float* x, float* y,
121110
int M, int K,
111+
int qhOffset, int qsOffset,
122112
cudaStream_t stream)
123113
{
124114
if (K % Q5_0_BLOCK_SIZE != 0) {
@@ -130,7 +120,7 @@ extern "C" cudaError_t gemv_q5_0_f32(
130120
int smem = K * sizeof(float);
131121

132122
gemv_q5_0_kernel<<<grid, threads, smem, stream>>>(
133-
(const uint8_t*)W_q5_0, x, y, M, K);
123+
(const uint8_t*)W_q5_0, x, y, M, K, qhOffset, qsOffset);
134124

135125
return cudaGetLastError();
136126
}

internal/cuda/kernels/gemv_q5_0.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@ import (
1414
)
1515

1616
// GemvQ5_0F32 performs Q5_0 fused dequant-GEMV: y = dequant(W_q5_0) * x.
17-
// W_q5_0 is raw Q5_0 blocks for matrix [M, K] (row-major block layout).
18-
// x is [K] FP32 input vector. y is [M] FP32 output vector.
19-
// K must be a multiple of 32.
17+
// W_q5_0 is the separated GPU layout (scales | qh | qs).
18+
// qhOffset and qsOffset are byte offsets to the qh and qs regions.
2019
func GemvQ5_0F32(
2120
W_q5_0, x, y unsafe.Pointer,
22-
M, K int,
21+
M, K, qhOffset, qsOffset int,
2322
stream unsafe.Pointer,
2423
) error {
2524
err := C.gemv_q5_0_f32(
2625
W_q5_0, (*C.float)(x), (*C.float)(y),
2726
C.int(M), C.int(K),
27+
C.int(qhOffset), C.int(qsOffset),
2828
C.cudaStream_t(stream),
2929
)
3030
if err != C.cudaSuccess {

internal/cuda/kernels/gemv_q5_0.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
/* Q5_0 fused dequant-GEMV kernel interface.
1+
/* Q5_0 fused dequant-GEMV kernel interface (separated GPU layout).
22
*
3-
* Q5_0 block format (22 bytes per 32 values):
4-
* - 2 bytes: fp16 d (block scale)
5-
* - 4 bytes: uint32 qh (32 high bits, one per element)
6-
* - 16 bytes: qs (packed nibbles, two 4-bit values per byte)
3+
* GPU layout (from Q5_0Storage.RawBytesGPU):
4+
* Region 1: [nBlocks * 2 bytes] fp16 scales, padded to 16-byte boundary
5+
* Region 2: [nBlocks * 4 bytes] uint32 qh values, padded to 16-byte boundary
6+
* Region 3: [nBlocks * 16 bytes] packed nibbles (qs)
77
*
88
* Computes: y[m] = sum_k( dequant(W_q5_0[m,k]) * x[k] )
9-
* W_q5_0 is raw Q5_0 blocks laid out row-major.
10-
* x is [K] FP32 input vector. y is [M] FP32 output vector.
119
* Batch=1 only (GEMV, not GEMM).
1210
*/
1311
#ifndef GEMV_Q5_0_H
@@ -22,16 +20,18 @@ extern "C" {
2220
/* gemv_q5_0_f32 performs Q5_0 fused dequant-GEMV:
2321
* y[m] = sum_k( dequant(W_q5_0[m,k]) * x[k] )
2422
*
25-
* W_q5_0: device pointer to raw Q5_0 blocks for matrix W [M, K].
26-
* M * ceil(K/32) blocks, each 22 bytes. Row-major layout.
27-
* x: device pointer to [K] float input vector.
28-
* y: device pointer to [M] float output vector.
29-
* M, K: matrix dimensions. K must be a multiple of 32.
30-
* stream: CUDA stream.
23+
* W_q5_0: device pointer to separated Q5_0 layout (scales | qh | qs).
24+
* x: device pointer to [K] float input vector.
25+
* y: device pointer to [M] float output vector.
26+
* M, K: matrix dimensions. K must be a multiple of 32.
27+
* qhOffset: byte offset from W_q5_0 to the qh region.
28+
* qsOffset: byte offset from W_q5_0 to the qs region.
29+
* stream: CUDA stream.
3130
*/
3231
cudaError_t gemv_q5_0_f32(
3332
const void* W_q5_0, const float* x, float* y,
3433
int M, int K,
34+
int qhOffset, int qsOffset,
3535
cudaStream_t stream);
3636

3737
#ifdef __cplusplus

internal/cuda/kernels/gemv_q5_0_purego.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@ import (
1010
)
1111

1212
// GemvQ5_0F32 performs Q5_0 fused dequant-GEMV: y = dequant(W_q5_0) * x.
13-
// W_q5_0 is raw Q5_0 blocks, x is [K] FP32, y is [M] FP32.
13+
// W_q5_0 is the separated GPU layout (scales | qh | qs).
14+
// qhOffset and qsOffset are byte offsets to the qh and qs regions.
1415
func GemvQ5_0F32(
1516
W_q5_0, x, y unsafe.Pointer, //nolint:gocritic // match CGo API
16-
M, K int, //nolint:gocritic // match CGo API
17+
M, K, qhOffset, qsOffset int, //nolint:gocritic // match CGo API
1718
stream unsafe.Pointer,
1819
) error {
1920
k := klib()
@@ -22,6 +23,8 @@ func GemvQ5_0F32(
2223
}
2324
ret := cuda.Ccall(k.launchGemvQ5_0F32,
2425
uintptr(W_q5_0), uintptr(x), uintptr(y),
25-
uintptr(M), uintptr(K), uintptr(stream))
26+
uintptr(M), uintptr(K),
27+
uintptr(qhOffset), uintptr(qsOffset),
28+
uintptr(stream))
2629
return checkKernel(ret, "gemv_q5_0_f32")
2730
}

internal/gpuapi/cuda_kernels.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ func (k *CUDAKernels) GemvQ6KF32(wQ6K, x, y unsafe.Pointer, M, K int, s Stream)
129129
return kernels.GemvQ6KF32(wQ6K, x, y, M, K, streamPtr(s))
130130
}
131131

132-
func (k *CUDAKernels) GemvQ5_0F32(wQ5_0, x, y unsafe.Pointer, M, K int, s Stream) error {
133-
return kernels.GemvQ5_0F32(wQ5_0, x, y, M, K, streamPtr(s))
132+
func (k *CUDAKernels) GemvQ5_0F32(wQ5_0, x, y unsafe.Pointer, M, K, qhOffset, qsOffset int, s Stream) error {
133+
return kernels.GemvQ5_0F32(wQ5_0, x, y, M, K, qhOffset, qsOffset, streamPtr(s))
134134
}
135135

136136
func (k *CUDAKernels) DequantQ4KF32(src, dst unsafe.Pointer, rows, K int, s Stream) error {

internal/gpuapi/fpga_kernels.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ func (k *FPGAKernels) GemvQ6KF32(_, _, _ unsafe.Pointer, _, _ int, _ Stream) err
115115
return fmt.Errorf("GemvQ6KF32: not implemented for FPGA")
116116
}
117117

118-
func (k *FPGAKernels) GemvQ5_0F32(_, _, _ unsafe.Pointer, _, _ int, _ Stream) error {
118+
func (k *FPGAKernels) GemvQ5_0F32(_, _, _ unsafe.Pointer, _, _, _, _ int, _ Stream) error {
119119
return fmt.Errorf("GemvQ5_0F32: not implemented for FPGA")
120120
}
121121

internal/gpuapi/gpuapi_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ func (stubKernelRunner) GemvQ5KF32(_, _, _ unsafe.Pointer, _, _ int, _ gpuapi.St
150150
func (stubKernelRunner) GemvQ6KF32(_, _, _ unsafe.Pointer, _, _ int, _ gpuapi.Stream) error {
151151
return nil
152152
}
153-
func (stubKernelRunner) GemvQ5_0F32(_, _, _ unsafe.Pointer, _, _ int, _ gpuapi.Stream) error {
153+
func (stubKernelRunner) GemvQ5_0F32(_, _, _ unsafe.Pointer, _, _, _, _ int, _ gpuapi.Stream) error {
154154
return nil
155155
}
156156
func (stubKernelRunner) DequantQ4KF32(_, _ unsafe.Pointer, _, _ int, _ gpuapi.Stream) error {

internal/gpuapi/kernels.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ type KernelRunner interface {
6363
GemvQ6KF32(wQ6K, x, y unsafe.Pointer, M, K int, stream Stream) error
6464

6565
// GemvQ5_0F32 performs Q5_0 fused dequant-GEMV: y = dequant(W_q5_0) * x.
66-
// W_q5_0 is raw Q5_0 blocks for matrix [M, K]. x is [K] float32.
66+
// W_q5_0 is the separated GPU layout (scales | qh | qs). x is [K] float32.
6767
// y is [M] float32. K must be a multiple of 32. Batch=1 only.
68-
GemvQ5_0F32(wQ5_0, x, y unsafe.Pointer, M, K int, stream Stream) error
68+
// qhOffset and qsOffset are byte offsets to the qh and qs regions.
69+
GemvQ5_0F32(wQ5_0, x, y unsafe.Pointer, M, K, qhOffset, qsOffset int, stream Stream) error
6970

7071
// DequantQ4KF32 dequantizes Q4_K super-blocks to FP32 in global memory.
7172
// src is raw Q4_K super-blocks for matrix [rows, K]. dst is [rows, K] float32.

internal/gpuapi/metal_kernels.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ func (k *MetalKernels) GemvQ6KF32(_, _, _ unsafe.Pointer, _, _ int, _ Stream) er
491491
return fmt.Errorf("GemvQ6KF32: not yet implemented for Metal")
492492
}
493493

494-
func (k *MetalKernels) GemvQ5_0F32(_, _, _ unsafe.Pointer, _, _ int, _ Stream) error {
494+
func (k *MetalKernels) GemvQ5_0F32(_, _, _ unsafe.Pointer, _, _, _, _ int, _ Stream) error {
495495
return fmt.Errorf("GemvQ5_0F32: not yet implemented for Metal")
496496
}
497497

0 commit comments

Comments
 (0)