@@ -9,6 +9,30 @@ import (
99 "github.com/zerfoo/ztensor/internal/cuda"
1010)
1111
12+ // q5_0ToGPULayout converts standard Q5_0 block data (22 bytes/block: [d(2)|qh(4)|qs(16)])
13+ // to the GPU-separated layout: scales(2*N, padded) | qh(4*N, padded) | qs(16*N).
14+ func q5_0ToGPULayout (raw []byte , nBlocks int ) []byte {
15+ const blockBytes = 22
16+ scaleBytes := nBlocks * 2
17+ paddedScaleBytes := (scaleBytes + 15 ) &^ 15
18+ qhBytes := nBlocks * 4
19+ paddedQhBytes := (qhBytes + 15 ) &^ 15
20+ qsBytes := nBlocks * 16
21+ total := paddedScaleBytes + paddedQhBytes + qsBytes
22+
23+ out := make ([]byte , total )
24+ for i := range nBlocks {
25+ blockOff := i * blockBytes
26+ // scale: 2 bytes at blockOff+0
27+ copy (out [i * 2 :i * 2 + 2 ], raw [blockOff :blockOff + 2 ])
28+ // qh: 4 bytes at blockOff+2
29+ copy (out [paddedScaleBytes + i * 4 :paddedScaleBytes + i * 4 + 4 ], raw [blockOff + 2 :blockOff + 6 ])
30+ // qs: 16 bytes at blockOff+6
31+ copy (out [paddedScaleBytes + paddedQhBytes + i * 16 :paddedScaleBytes + paddedQhBytes + i * 16 + 16 ], raw [blockOff + 6 :blockOff + 22 ])
32+ }
33+ return out
34+ }
35+
1236// dequantizeQ5_0 dequantizes one Q5_0 block (22 bytes) into 32 float32 values.
1337// Inlined here to avoid an import cycle with the tensor package.
1438func dequantizeQ5_0 (raw []byte , dst []float32 ) {
@@ -153,12 +177,6 @@ func TestGemvQ5_0F32_Parity(t *testing.T) {
153177 }
154178 defer func () { _ = stream .Destroy () }()
155179
156- devW , err := cuda .Malloc (len (raw ))
157- if err != nil {
158- t .Fatalf ("cuda.Malloc W: %v" , err )
159- }
160- defer func () { _ = cuda .Free (devW ) }()
161-
162180 devX , err := cuda .Malloc (K * 4 )
163181 if err != nil {
164182 t .Fatalf ("cuda.Malloc x: %v" , err )
@@ -171,14 +189,30 @@ func TestGemvQ5_0F32_Parity(t *testing.T) {
171189 }
172190 defer func () { _ = cuda .Free (devY ) }()
173191
174- if err := cuda .Memcpy (devW , unsafe .Pointer (& raw [0 ]), len (raw ), cuda .MemcpyHostToDevice ); err != nil {
175- t .Fatalf ("Memcpy W: %v" , err )
176- }
177192 if err := cuda .Memcpy (devX , unsafe .Pointer (& x [0 ]), K * 4 , cuda .MemcpyHostToDevice ); err != nil {
178193 t .Fatalf ("Memcpy x: %v" , err )
179194 }
180195
181- if err := GemvQ5_0F32 (devW , devX , devY , M , K , stream .Ptr ()); err != nil {
196+ // Convert standard Q5_0 blocks to GPU-separated layout (scales | qh | qs)
197+ // and compute region offsets for the kernel.
198+ nBlocks := M * (K / 32 )
199+ gpuRaw := q5_0ToGPULayout (raw , nBlocks )
200+ scaleBytes := nBlocks * 2
201+ qhOffset := (scaleBytes + 15 ) &^ 15
202+ qhBytes := nBlocks * 4
203+ qsOffset := qhOffset + (qhBytes + 15 )&^15
204+
205+ // Re-upload GPU-layout data.
206+ devWGPU , err := cuda .Malloc (len (gpuRaw ))
207+ if err != nil {
208+ t .Fatalf ("cuda.Malloc W GPU: %v" , err )
209+ }
210+ defer func () { _ = cuda .Free (devWGPU ) }()
211+ if err := cuda .Memcpy (devWGPU , unsafe .Pointer (& gpuRaw [0 ]), len (gpuRaw ), cuda .MemcpyHostToDevice ); err != nil {
212+ t .Fatalf ("Memcpy W GPU: %v" , err )
213+ }
214+
215+ if err := GemvQ5_0F32 (devWGPU , devX , devY , M , K , qhOffset , qsOffset , stream .Ptr ()); err != nil {
182216 t .Fatalf ("GemvQ5_0F32: %v" , err )
183217 }
184218
@@ -241,12 +275,6 @@ func TestGemvQ5_0F32_MultipleSizes(t *testing.T) {
241275 }
242276 defer func () { _ = stream .Destroy () }()
243277
244- devW , err := cuda .Malloc (len (raw ))
245- if err != nil {
246- t .Fatalf ("cuda.Malloc W: %v" , err )
247- }
248- defer func () { _ = cuda .Free (devW ) }()
249-
250278 devX , err := cuda .Malloc (tc .K * 4 )
251279 if err != nil {
252280 t .Fatalf ("cuda.Malloc x: %v" , err )
@@ -259,14 +287,27 @@ func TestGemvQ5_0F32_MultipleSizes(t *testing.T) {
259287 }
260288 defer func () { _ = cuda .Free (devY ) }()
261289
262- if err := cuda .Memcpy (devW , unsafe .Pointer (& raw [0 ]), len (raw ), cuda .MemcpyHostToDevice ); err != nil {
263- t .Fatalf ("Memcpy W: %v" , err )
290+ // Convert to GPU-separated layout and compute offsets.
291+ nBlocks := tc .M * (tc .K / 32 )
292+ gpuRaw := q5_0ToGPULayout (raw , nBlocks )
293+ scaleBytes := nBlocks * 2
294+ qhOffset := (scaleBytes + 15 ) &^ 15
295+ qhBytes := nBlocks * 4
296+ qsOffset := qhOffset + (qhBytes + 15 )&^15
297+
298+ devWGPU , err := cuda .Malloc (len (gpuRaw ))
299+ if err != nil {
300+ t .Fatalf ("cuda.Malloc W GPU: %v" , err )
301+ }
302+ defer func () { _ = cuda .Free (devWGPU ) }()
303+ if err := cuda .Memcpy (devWGPU , unsafe .Pointer (& gpuRaw [0 ]), len (gpuRaw ), cuda .MemcpyHostToDevice ); err != nil {
304+ t .Fatalf ("Memcpy W GPU: %v" , err )
264305 }
265306 if err := cuda .Memcpy (devX , unsafe .Pointer (& x [0 ]), tc .K * 4 , cuda .MemcpyHostToDevice ); err != nil {
266307 t .Fatalf ("Memcpy x: %v" , err )
267308 }
268309
269- if err := GemvQ5_0F32 (devW , devX , devY , tc .M , tc .K , stream .Ptr ()); err != nil {
310+ if err := GemvQ5_0F32 (devWGPU , devX , devY , tc .M , tc .K , qhOffset , qsOffset , stream .Ptr ()); err != nil {
270311 t .Fatalf ("GemvQ5_0F32: %v" , err )
271312 }
272313
@@ -318,19 +359,26 @@ func BenchmarkGemvQ5_0F32_4096(b *testing.B) {
318359 }
319360 defer func () { _ = stream .Destroy () }()
320361
321- devW , _ := cuda .Malloc (len (raw ))
362+ nBlocks := M * (K / 32 )
363+ gpuRaw := q5_0ToGPULayout (raw , nBlocks )
364+ scaleBytes := nBlocks * 2
365+ qhOffset := (scaleBytes + 15 ) &^ 15
366+ qhBytes := nBlocks * 4
367+ qsOffset := qhOffset + (qhBytes + 15 )&^15
368+
369+ devW , _ := cuda .Malloc (len (gpuRaw ))
322370 defer func () { _ = cuda .Free (devW ) }()
323371 devX , _ := cuda .Malloc (K * 4 )
324372 defer func () { _ = cuda .Free (devX ) }()
325373 devY , _ := cuda .Malloc (M * 4 )
326374 defer func () { _ = cuda .Free (devY ) }()
327375
328- _ = cuda .Memcpy (devW , unsafe .Pointer (& raw [0 ]), len (raw ), cuda .MemcpyHostToDevice )
376+ _ = cuda .Memcpy (devW , unsafe .Pointer (& gpuRaw [0 ]), len (gpuRaw ), cuda .MemcpyHostToDevice )
329377 _ = cuda .Memcpy (devX , unsafe .Pointer (& x [0 ]), K * 4 , cuda .MemcpyHostToDevice )
330378
331379 b .ResetTimer ()
332380 for b .Loop () {
333- _ = GemvQ5_0F32 (devW , devX , devY , M , K , stream .Ptr ())
381+ _ = GemvQ5_0F32 (devW , devX , devY , M , K , qhOffset , qsOffset , stream .Ptr ())
334382 }
335383 _ = stream .Synchronize ()
336384
0 commit comments