Skip to content

Commit f50ffa7

Browse files
committed
fix(compute): CPU dequant fallback for Q4_K when K%256!=0
The DequantQ4KF32 GPU kernel computes blocks_per_row = K/256 using integer division, which truncates when K is not a multiple of 256. For Gemma3-1B (hidden_size=1152, 1152%256=128), the kernel therefore skipped the last 128 values of each row, producing incorrect results. When K%256 != 0, the general GEMM path now falls back to CPU dequantization followed by a host-to-device (H2D) upload. The GEMV path already gates on k%256==0.
1 parent 5f21cbb commit f50ffa7

1 file changed

Lines changed: 22 additions & 4 deletions

File tree

compute/gpu_engine.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,8 +1497,16 @@ func (e *GPUEngine[T]) matMulQ4K(ctx context.Context, qs *tensor.Q4KStorage, a,
14971497
}
14981498
defer e.pool.Free(e.deviceID, devAF32, dequantSize)
14991499

1500-
if err := e.kernels.DequantQ4KF32(devW, devAF32, m, k, e.stream); err != nil {
1501-
return e.cpu.MatMul(ctx, a, b, dst...)
1500+
if k%256 == 0 {
1501+
if err := e.kernels.DequantQ4KF32(devW, devAF32, m, k, e.stream); err != nil {
1502+
return e.cpu.MatMul(ctx, a, b, dst...)
1503+
}
1504+
} else {
1505+
dequant := make([]float32, m*k)
1506+
qs.Dequantize(dequant)
1507+
if err := e.runtime.Memcpy(devAF32, unsafe.Pointer(&dequant[0]), dequantSize, gpuapi.MemcpyHostToDevice); err != nil {
1508+
return e.cpu.MatMul(ctx, a, b, dst...)
1509+
}
15021510
}
15031511

15041512
// Upload B to GPU.
@@ -1617,8 +1625,18 @@ func (e *GPUEngine[T]) matMulQ4KBWeight(ctx context.Context, a *tensor.TensorNum
16171625
}
16181626
defer e.pool.Free(e.deviceID, devBF32, dequantSize)
16191627

1620-
if err := e.kernels.DequantQ4KF32(devQ4K, devBF32, n, k, e.stream); err != nil {
1621-
return e.cpu.MatMul(ctx, a, b, dst...)
1628+
if k%256 == 0 {
1629+
// GPU dequant when K is super-block aligned.
1630+
if err := e.kernels.DequantQ4KF32(devQ4K, devBF32, n, k, e.stream); err != nil {
1631+
return e.cpu.MatMul(ctx, a, b, dst...)
1632+
}
1633+
} else {
1634+
// CPU dequant for unaligned K (super-block boundary doesn't match row boundary).
1635+
dequant := make([]float32, n*k)
1636+
qs.Dequantize(dequant)
1637+
if err := e.runtime.Memcpy(devBF32, unsafe.Pointer(&dequant[0]), dequantSize, gpuapi.MemcpyHostToDevice); err != nil {
1638+
return e.cpu.MatMul(ctx, a, b, dst...)
1639+
}
16221640
}
16231641

16241642
// Upload A to GPU.

0 commit comments

Comments
 (0)