Skip to content

Commit 5f21cbb

Browse files
committed
fix(compute): use dequant+cuBLAS for Q4_K when K%256!=0
Q4_K GEMV requires K to be a multiple of 256 (super-block size). For models where hidden_size is not 256-aligned (e.g., Gemma3-1B with hidden_size=1152, 1152%256=128), all Q4_K matmuls fell back to CPU. Remove the hard k%256!=0 → CPU fallback. Instead, only use the GEMV fast path when k%256==0, and fall through to the dequant+cuBLAS path (DequantQ4KF32 + SgemmNT) for unaligned K. The dequant kernel handles ceil(K/256) super-blocks, and cuBLAS handles any dimensions.
1 parent d0d3a82 commit 5f21cbb

1 file changed

Lines changed: 5 additions & 13 deletions

File tree

compute/gpu_engine.go

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,11 +1439,6 @@ func (e *GPUEngine[T]) matMulQ4K(ctx context.Context, qs *tensor.Q4KStorage, a,
14391439
k := aShape[1]
14401440
n := bShape[1]
14411441

1442-
// K must be a multiple of 256 for Q4_K super-blocks.
1443-
if k%256 != 0 {
1444-
return e.cpu.MatMul(ctx, a, b, dst...)
1445-
}
1446-
14471442
e.setDevice()
14481443

14491444
// Get Q4_K device pointer (pre-uploaded or upload now).
@@ -1467,8 +1462,8 @@ func (e *GPUEngine[T]) matMulQ4K(ctx context.Context, qs *tensor.Q4KStorage, a,
14671462
}
14681463
defer freeW()
14691464

1470-
// Fused GEMV path: y = dequant(W_q4k) * x, when n==1.
1471-
if n == 1 {
1465+
// Fused GEMV path: y = dequant(W_q4k) * x, when n==1 and K is 256-aligned.
1466+
if n == 1 && k%256 == 0 {
14721467
devX, cleanupX, err := getDevicePtr(e, b)
14731468
if err != nil {
14741469
return e.cpu.MatMul(ctx, a, b, dst...)
@@ -1554,11 +1549,6 @@ func (e *GPUEngine[T]) matMulQ4KBWeight(ctx context.Context, a *tensor.TensorNum
15541549
}
15551550
n := bShape[1] // columns of B (after virtual transpose)
15561551

1557-
// K must be a multiple of 256 for Q4_K super-blocks.
1558-
if k%256 != 0 {
1559-
return e.cpu.MatMul(ctx, a, b, dst...)
1560-
}
1561-
15621552
// Build output shape: [batch..., m_last, n].
15631553
outShape := make([]int, len(aShape))
15641554
copy(outShape, aShape[:len(aShape)-1])
@@ -1589,7 +1579,9 @@ func (e *GPUEngine[T]) matMulQ4KBWeight(ctx context.Context, a *tensor.TensorNum
15891579
defer freeQ4K()
15901580

15911581
// Fused GEMV path: y[n] = sum_k dequant(B_q4k[n, k]) * x[k], when m==1.
1592-
if m == 1 {
1582+
// Requires K % 256 == 0 for Q4_K super-block alignment.
1583+
// When K is not aligned, falls through to the general dequant+cuBLAS path.
1584+
if m == 1 && k%256 == 0 {
15931585
devX, cleanupX, err := getDevicePtr(e, a)
15941586
if err != nil {
15951587
return e.cpu.MatMul(ctx, a, b, dst...)

0 commit comments

Comments
 (0)