Skip to content

Commit 70f8fd5

Browse files
committed
fix(kernels): update GemvQ5_0F32 test to match qhOffset/qsOffset signature
The GemvQ5_0F32 kernel was updated to accept qhOffset and qsOffset parameters for the GPU-separated layout, but the test still called with the old 6-arg signature. Fix all 3 call sites: - TestGemvQ5_0F32_Parity - TestGemvQ5_0F32_MultipleSizes - BenchmarkGemvQ5_0F32_4096 Add q5_0ToGPULayout helper to convert standard block format to the GPU-separated layout (scales | qh | qs) needed by the kernel.
1 parent d03e9f3 commit 70f8fd5

1 file changed

Lines changed: 70 additions & 22 deletions

File tree

internal/cuda/kernels/gemv_q5_0_test.go

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,30 @@ import (
99
"github.com/zerfoo/ztensor/internal/cuda"
1010
)
1111

12+
// q5_0ToGPULayout converts standard Q5_0 block data (22 bytes/block: [d(2)|qh(4)|qs(16)])
// to the GPU-separated layout: scales(2*N, padded) | qh(4*N, padded) | qs(16*N).
// The scale and qh regions are each rounded up to a 16-byte boundary so the
// following region starts aligned; padding bytes are left zero.
func q5_0ToGPULayout(raw []byte, nBlocks int) []byte {
	const blockBytes = 22
	// pad16 rounds n up to the next multiple of 16.
	pad16 := func(n int) int { return (n + 15) &^ 15 }

	qhBase := pad16(nBlocks * 2)            // qh region starts after padded scales
	qsBase := qhBase + pad16(nBlocks*4)     // qs region starts after padded qh
	out := make([]byte, qsBase+nBlocks*16)

	for b := 0; b < nBlocks; b++ {
		src := raw[b*blockBytes:]
		copy(out[b*2:], src[:2])           // fp16 scale d (2 bytes)
		copy(out[qhBase+b*4:], src[2:6])   // high-bit plane qh (4 bytes)
		copy(out[qsBase+b*16:], src[6:22]) // packed low nibbles qs (16 bytes)
	}
	return out
}
35+
1236
// dequantizeQ5_0 dequantizes one Q5_0 block (22 bytes) into 32 float32 values.
1337
// Inlined here to avoid an import cycle with the tensor package.
1438
func dequantizeQ5_0(raw []byte, dst []float32) {
@@ -153,12 +177,6 @@ func TestGemvQ5_0F32_Parity(t *testing.T) {
153177
}
154178
defer func() { _ = stream.Destroy() }()
155179

156-
devW, err := cuda.Malloc(len(raw))
157-
if err != nil {
158-
t.Fatalf("cuda.Malloc W: %v", err)
159-
}
160-
defer func() { _ = cuda.Free(devW) }()
161-
162180
devX, err := cuda.Malloc(K * 4)
163181
if err != nil {
164182
t.Fatalf("cuda.Malloc x: %v", err)
@@ -171,14 +189,30 @@ func TestGemvQ5_0F32_Parity(t *testing.T) {
171189
}
172190
defer func() { _ = cuda.Free(devY) }()
173191

174-
if err := cuda.Memcpy(devW, unsafe.Pointer(&raw[0]), len(raw), cuda.MemcpyHostToDevice); err != nil {
175-
t.Fatalf("Memcpy W: %v", err)
176-
}
177192
if err := cuda.Memcpy(devX, unsafe.Pointer(&x[0]), K*4, cuda.MemcpyHostToDevice); err != nil {
178193
t.Fatalf("Memcpy x: %v", err)
179194
}
180195

181-
if err := GemvQ5_0F32(devW, devX, devY, M, K, stream.Ptr()); err != nil {
196+
// Convert standard Q5_0 blocks to GPU-separated layout (scales | qh | qs)
197+
// and compute region offsets for the kernel.
198+
nBlocks := M * (K / 32)
199+
gpuRaw := q5_0ToGPULayout(raw, nBlocks)
200+
scaleBytes := nBlocks * 2
201+
qhOffset := (scaleBytes + 15) &^ 15
202+
qhBytes := nBlocks * 4
203+
qsOffset := qhOffset + (qhBytes+15)&^15
204+
205+
// Re-upload GPU-layout data.
206+
devWGPU, err := cuda.Malloc(len(gpuRaw))
207+
if err != nil {
208+
t.Fatalf("cuda.Malloc W GPU: %v", err)
209+
}
210+
defer func() { _ = cuda.Free(devWGPU) }()
211+
if err := cuda.Memcpy(devWGPU, unsafe.Pointer(&gpuRaw[0]), len(gpuRaw), cuda.MemcpyHostToDevice); err != nil {
212+
t.Fatalf("Memcpy W GPU: %v", err)
213+
}
214+
215+
if err := GemvQ5_0F32(devWGPU, devX, devY, M, K, qhOffset, qsOffset, stream.Ptr()); err != nil {
182216
t.Fatalf("GemvQ5_0F32: %v", err)
183217
}
184218

@@ -241,12 +275,6 @@ func TestGemvQ5_0F32_MultipleSizes(t *testing.T) {
241275
}
242276
defer func() { _ = stream.Destroy() }()
243277

244-
devW, err := cuda.Malloc(len(raw))
245-
if err != nil {
246-
t.Fatalf("cuda.Malloc W: %v", err)
247-
}
248-
defer func() { _ = cuda.Free(devW) }()
249-
250278
devX, err := cuda.Malloc(tc.K * 4)
251279
if err != nil {
252280
t.Fatalf("cuda.Malloc x: %v", err)
@@ -259,14 +287,27 @@ func TestGemvQ5_0F32_MultipleSizes(t *testing.T) {
259287
}
260288
defer func() { _ = cuda.Free(devY) }()
261289

262-
if err := cuda.Memcpy(devW, unsafe.Pointer(&raw[0]), len(raw), cuda.MemcpyHostToDevice); err != nil {
263-
t.Fatalf("Memcpy W: %v", err)
290+
// Convert to GPU-separated layout and compute offsets.
291+
nBlocks := tc.M * (tc.K / 32)
292+
gpuRaw := q5_0ToGPULayout(raw, nBlocks)
293+
scaleBytes := nBlocks * 2
294+
qhOffset := (scaleBytes + 15) &^ 15
295+
qhBytes := nBlocks * 4
296+
qsOffset := qhOffset + (qhBytes+15)&^15
297+
298+
devWGPU, err := cuda.Malloc(len(gpuRaw))
299+
if err != nil {
300+
t.Fatalf("cuda.Malloc W GPU: %v", err)
301+
}
302+
defer func() { _ = cuda.Free(devWGPU) }()
303+
if err := cuda.Memcpy(devWGPU, unsafe.Pointer(&gpuRaw[0]), len(gpuRaw), cuda.MemcpyHostToDevice); err != nil {
304+
t.Fatalf("Memcpy W GPU: %v", err)
264305
}
265306
if err := cuda.Memcpy(devX, unsafe.Pointer(&x[0]), tc.K*4, cuda.MemcpyHostToDevice); err != nil {
266307
t.Fatalf("Memcpy x: %v", err)
267308
}
268309

269-
if err := GemvQ5_0F32(devW, devX, devY, tc.M, tc.K, stream.Ptr()); err != nil {
310+
if err := GemvQ5_0F32(devWGPU, devX, devY, tc.M, tc.K, qhOffset, qsOffset, stream.Ptr()); err != nil {
270311
t.Fatalf("GemvQ5_0F32: %v", err)
271312
}
272313

@@ -318,19 +359,26 @@ func BenchmarkGemvQ5_0F32_4096(b *testing.B) {
318359
}
319360
defer func() { _ = stream.Destroy() }()
320361

321-
devW, _ := cuda.Malloc(len(raw))
362+
nBlocks := M * (K / 32)
363+
gpuRaw := q5_0ToGPULayout(raw, nBlocks)
364+
scaleBytes := nBlocks * 2
365+
qhOffset := (scaleBytes + 15) &^ 15
366+
qhBytes := nBlocks * 4
367+
qsOffset := qhOffset + (qhBytes+15)&^15
368+
369+
devW, _ := cuda.Malloc(len(gpuRaw))
322370
defer func() { _ = cuda.Free(devW) }()
323371
devX, _ := cuda.Malloc(K * 4)
324372
defer func() { _ = cuda.Free(devX) }()
325373
devY, _ := cuda.Malloc(M * 4)
326374
defer func() { _ = cuda.Free(devY) }()
327375

328-
_ = cuda.Memcpy(devW, unsafe.Pointer(&raw[0]), len(raw), cuda.MemcpyHostToDevice)
376+
_ = cuda.Memcpy(devW, unsafe.Pointer(&gpuRaw[0]), len(gpuRaw), cuda.MemcpyHostToDevice)
329377
_ = cuda.Memcpy(devX, unsafe.Pointer(&x[0]), K*4, cuda.MemcpyHostToDevice)
330378

331379
b.ResetTimer()
332380
for b.Loop() {
333-
_ = GemvQ5_0F32(devW, devX, devY, M, K, stream.Ptr())
381+
_ = GemvQ5_0F32(devW, devX, devY, M, K, qhOffset, qsOffset, stream.Ptr())
334382
}
335383
_ = stream.Synchronize()
336384

0 commit comments

Comments
 (0)