
Commit 5bc914b

fix(compute): upload CPU fallback MatMul results to GPU for device consistency

1 parent 0d97c69 · commit 5bc914b

1 file changed: 59 additions & 33 deletions

compute/gpu_engine.go
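The diff below adds a cpuMatMulToGPU helper and reroutes every CPU fallback in the mmap-backed MatMul paths through it. As a minimal sketch of the invariant the commit enforces (the caller and the float32 instantiation are illustrative assumptions, not code from this repository):

    // Sketch: after this commit, a GPUEngine MatMul returns a GPU-resident
    // tensor even when the computation fell back to the CPU engine.
    out, err := engine.MatMul(ctx, weights, activations) // engine: *GPUEngine[float32] (assumed)
    if err != nil {
        return err
    }
    if _, ok := out.GetStorage().(*tensor.GPUStorage[float32]); !ok {
        // Reachable only when cpuMatMulToGPU's device alloc or upload fails;
        // before this commit, any unsupported quant type landed here.
    }

Note the helper is best-effort by design: if the device allocation or the host-to-device copy fails, it returns the CPU tensor with a nil error rather than failing the whole MatMul, so the fallback path degrades gracefully.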
@@ -2640,14 +2640,40 @@ func (e *GPUEngine[T]) matMulBF16BWeight(ctx context.Context, a *tensor.TensorNu
     return makeGPUResult[T](e, outShape, devC, m*n, dst...)
 }

+// cpuMatMulToGPU runs MatMul on the CPU engine then uploads the result to GPU.
+// This ensures callers always receive a GPU-resident tensor, maintaining device
+// consistency when the GPU engine falls back to CPU for unsupported quant types.
+func (e *GPUEngine[T]) cpuMatMulToGPU(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) {
+    result, err := e.cpu.MatMul(ctx, a, b, dst...)
+    if err != nil {
+        return nil, err
+    }
+    // If already GPU-resident (e.g., dst was provided with GPUStorage), return as-is.
+    if _, ok := result.GetStorage().(*tensor.GPUStorage[T]); ok {
+        return result, nil
+    }
+    // Upload CPU result to GPU.
+    data := result.Data()
+    byteSize := len(data) * int(unsafe.Sizeof(*new(T)))
+    devPtr, err := e.pool.Alloc(e.deviceID, byteSize)
+    if err != nil {
+        return result, nil // fallback: return CPU tensor if GPU alloc fails
+    }
+    if err := e.runtime.Memcpy(devPtr, unsafe.Pointer(&data[0]), byteSize, gpuapi.MemcpyHostToDevice); err != nil {
+        e.pool.Free(e.deviceID, devPtr, byteSize)
+        return result, nil
+    }
+    return makeGPUResult[T](e, result.Shape(), devPtr, len(data), dst...)
+}
+
 // matMulMmap handles MatMul where A has MmapStorage. Routes to the appropriate
 // quantized kernel based on QType, using the pre-uploaded GPU pointer from
 // UploadWeights or uploading raw bytes on the fly.
 func (e *GPUEngine[T]) matMulMmap(ctx context.Context, ms *tensor.MmapStorage, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) {
     aShape := a.Shape()
     bShape := b.Shape()
     if len(aShape) < 2 || len(bShape) < 2 || len(aShape) > 2 || len(bShape) > 2 {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }

     m := aShape[0]
@@ -2658,7 +2684,7 @@ func (e *GPUEngine[T]) matMulMmap(ctx context.Context, ms *tensor.MmapStorage, a
     // Acquire GPU pointer for the quantized weight data.
     devW, freeW, err := e.mmapDevicePtr(ms)
     if err != nil {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }
     defer freeW()

@@ -2668,90 +2694,90 @@ func (e *GPUEngine[T]) matMulMmap(ctx context.Context, ms *tensor.MmapStorage, a
     if n == 1 {
         devX, cleanupX, err := getDevicePtr(e, b)
         if err != nil {
-            return e.cpu.MatMul(ctx, a, b, dst...)
+            return e.cpuMatMulToGPU(ctx, a, b, dst...)
         }
         defer cleanupX()

         f32Size := int(unsafe.Sizeof(float32(0)))
         cSize := m * f32Size
         devY, err := e.pool.Alloc(e.deviceID, cSize)
         if err != nil {
-            return e.cpu.MatMul(ctx, a, b, dst...)
+            return e.cpuMatMulToGPU(ctx, a, b, dst...)
         }

         var kerr error
         switch qtype {
         case tensor.GGMLTypeQ4_K:
             if k%256 != 0 {
                 e.pool.Free(e.deviceID, devY, cSize)
-                return e.cpu.MatMul(ctx, a, b, dst...)
+                return e.cpuMatMulToGPU(ctx, a, b, dst...)
             }
             kerr = e.kernels.GemvQ4KF32(devW, devX, devY, m, k, e.stream)
         case tensor.GGMLTypeQ4_0:
             if k%32 != 0 {
                 e.pool.Free(e.deviceID, devY, cSize)
-                return e.cpu.MatMul(ctx, a, b, dst...)
+                return e.cpuMatMulToGPU(ctx, a, b, dst...)
             }
             totalBlocks := (m * k) / 32
             dataOff := tensor.Q4GPUDataOffset(totalBlocks)
             kerr = e.kernels.GemmQ4F32(devW, devX, devY, m, k, 1, dataOff, e.stream)
         case tensor.GGMLTypeQ8_0:
             if k%32 != 0 {
                 e.pool.Free(e.deviceID, devY, cSize)
-                return e.cpu.MatMul(ctx, a, b, dst...)
+                return e.cpuMatMulToGPU(ctx, a, b, dst...)
             }
             kerr = e.kernels.GemmQ8F32(devW, devX, devY, m, k, 1, e.stream)
         case tensor.GGMLTypeQ6_K:
             if k%256 != 0 {
                 e.pool.Free(e.deviceID, devY, cSize)
-                return e.cpu.MatMul(ctx, a, b, dst...)
+                return e.cpuMatMulToGPU(ctx, a, b, dst...)
             }
             kerr = e.kernels.GemvQ6KF32(devW, devX, devY, m, k, e.stream)
         case tensor.GGMLTypeQ5_K:
             if k%256 != 0 {
                 e.pool.Free(e.deviceID, devY, cSize)
-                return e.cpu.MatMul(ctx, a, b, dst...)
+                return e.cpuMatMulToGPU(ctx, a, b, dst...)
             }
             kerr = e.kernels.GemvQ5KF32(devW, devX, devY, m, k, e.stream)
         default:
             e.pool.Free(e.deviceID, devY, cSize)
-            return e.cpu.MatMul(ctx, a, b, dst...)
+            return e.cpuMatMulToGPU(ctx, a, b, dst...)
         }
         if kerr != nil {
             e.pool.Free(e.deviceID, devY, cSize)
-            return e.cpu.MatMul(ctx, a, b, dst...)
+            return e.cpuMatMulToGPU(ctx, a, b, dst...)
         }
         return makeGPUResult[T](e, []int{m, n}, devY, m*n, dst...)
     }

     // General GEMM: dequantize Q4_K on GPU, then cuBLAS Sgemm.
     // Only Q4_K has a GPU dequant kernel; others fall back to CPU.
     if qtype != tensor.GGMLTypeQ4_K {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }

     f32Size := int(unsafe.Sizeof(float32(0)))
     dequantSize := m * k * f32Size
     devAF32, err := e.pool.Alloc(e.deviceID, dequantSize)
     if err != nil {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }
     defer e.pool.Free(e.deviceID, devAF32, dequantSize)

     if err := e.kernels.DequantQ4KF32(devW, devAF32, m, k, e.stream); err != nil {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }

     devB, cleanupB, err := getDevicePtr(e, b)
     if err != nil {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }
     defer cleanupB()

     cSize := m * n * f32Size
     devC, err := e.pool.Alloc(e.deviceID, cSize)
     if err != nil {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }

     if err := e.blas.Sgemm(m, n, k, 1.0, devAF32, devB, 0.0, devC); err != nil {
@@ -2768,7 +2794,7 @@ func (e *GPUEngine[T]) matMulMmapB(ctx context.Context, a *tensor.TensorNumeric[
     aShape := a.Shape()
     bShape := b.Shape()
     if len(aShape) < 2 || len(bShape) < 2 || len(bShape) > 2 {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }

     // B is virtual-transposed: logical [K, N], physical [N, K].
@@ -2783,7 +2809,7 @@ func (e *GPUEngine[T]) matMulMmapB(ctx context.Context, a *tensor.TensorNumeric[

     devW, freeW, err := e.mmapDevicePtr(ms)
     if err != nil {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }
     defer freeW()

@@ -2794,58 +2820,58 @@ func (e *GPUEngine[T]) matMulMmapB(ctx context.Context, a *tensor.TensorNumeric[
     if m == 1 {
         devX, cleanupX, err := getDevicePtr(e, a)
         if err != nil {
-            return e.cpu.MatMul(ctx, a, b, dst...)
+            return e.cpuMatMulToGPU(ctx, a, b, dst...)
         }
         defer cleanupX()

         f32Size := int(unsafe.Sizeof(float32(0)))
         cSize := n * f32Size
         devY, err := e.pool.Alloc(e.deviceID, cSize)
         if err != nil {
-            return e.cpu.MatMul(ctx, a, b, dst...)
+            return e.cpuMatMulToGPU(ctx, a, b, dst...)
         }

         var kerr error
         switch qtype {
         case tensor.GGMLTypeQ4_K:
             if k%256 != 0 {
                 e.pool.Free(e.deviceID, devY, cSize)
-                return e.cpu.MatMul(ctx, a, b, dst...)
+                return e.cpuMatMulToGPU(ctx, a, b, dst...)
             }
             kerr = e.kernels.GemvQ4KF32(devW, devX, devY, nPhys, k, e.stream)
         case tensor.GGMLTypeQ4_0:
             if k%32 != 0 {
                 e.pool.Free(e.deviceID, devY, cSize)
-                return e.cpu.MatMul(ctx, a, b, dst...)
+                return e.cpuMatMulToGPU(ctx, a, b, dst...)
             }
             totalBlocks := (nPhys * k) / 32
             dataOff := tensor.Q4GPUDataOffset(totalBlocks)
             kerr = e.kernels.GemmQ4F32(devW, devX, devY, nPhys, k, 1, dataOff, e.stream)
         case tensor.GGMLTypeQ8_0:
             if k%32 != 0 {
                 e.pool.Free(e.deviceID, devY, cSize)
-                return e.cpu.MatMul(ctx, a, b, dst...)
+                return e.cpuMatMulToGPU(ctx, a, b, dst...)
             }
             kerr = e.kernels.GemmQ8F32(devW, devX, devY, nPhys, k, 1, e.stream)
         case tensor.GGMLTypeQ6_K:
             if k%256 != 0 {
                 e.pool.Free(e.deviceID, devY, cSize)
-                return e.cpu.MatMul(ctx, a, b, dst...)
+                return e.cpuMatMulToGPU(ctx, a, b, dst...)
             }
             kerr = e.kernels.GemvQ6KF32(devW, devX, devY, nPhys, k, e.stream)
         case tensor.GGMLTypeQ5_K:
             if k%256 != 0 {
                 e.pool.Free(e.deviceID, devY, cSize)
-                return e.cpu.MatMul(ctx, a, b, dst...)
+                return e.cpuMatMulToGPU(ctx, a, b, dst...)
             }
             kerr = e.kernels.GemvQ5KF32(devW, devX, devY, nPhys, k, e.stream)
         default:
             e.pool.Free(e.deviceID, devY, cSize)
-            return e.cpu.MatMul(ctx, a, b, dst...)
+            return e.cpuMatMulToGPU(ctx, a, b, dst...)
         }
         if kerr != nil {
             e.pool.Free(e.deviceID, devY, cSize)
-            return e.cpu.MatMul(ctx, a, b, dst...)
+            return e.cpuMatMulToGPU(ctx, a, b, dst...)
         }

         outShape := make([]int, len(aShape))
@@ -2857,31 +2883,31 @@ func (e *GPUEngine[T]) matMulMmapB(ctx context.Context, a *tensor.TensorNumeric[
     // General GEMM: dequantize Q4_K on GPU, then cuBLAS SgemmNT.
     // Only Q4_K has a GPU dequant kernel; others fall back to CPU.
     if qtype != tensor.GGMLTypeQ4_K {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }

     f32Size := int(unsafe.Sizeof(float32(0)))
     dequantSize := nPhys * k * f32Size
     devBF32, err := e.pool.Alloc(e.deviceID, dequantSize)
     if err != nil {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }
     defer e.pool.Free(e.deviceID, devBF32, dequantSize)

     if err := e.kernels.DequantQ4KF32(devW, devBF32, nPhys, k, e.stream); err != nil {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }

     devA, cleanupA, err := getDevicePtr(e, a)
     if err != nil {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }
     defer cleanupA()

     cSize := m * n * f32Size
     devC, err := e.pool.Alloc(e.deviceID, cSize)
     if err != nil {
-        return e.cpu.MatMul(ctx, a, b, dst...)
+        return e.cpuMatMulToGPU(ctx, a, b, dst...)
     }

     outShape := make([]int, len(aShape))
@@ -2899,7 +2925,7 @@ func (e *GPUEngine[T]) matMulMmapB(ctx context.Context, a *tensor.TensorNumeric[

     // Fallback: CPU MatMul.
     e.pool.Free(e.deviceID, devC, cSize)
-    return e.cpu.MatMul(ctx, a, b, dst...)
+    return e.cpuMatMulToGPU(ctx, a, b, dst...)
 }

 // mmapDevicePtr returns the GPU device pointer for MmapStorage data. If the data
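A regression test pinning the new behavior might look like the sketch below; newTestGPUEngine, loadMmapQuantTensor, and randomTensor are hypothetical helpers for illustration, not part of this commit:

    // Hypothetical regression test: an mmap weight whose quant type has no GPU
    // kernel forces the CPU fallback, and the result must still be GPU-resident.
    func TestCPUFallbackResultIsGPUResident(t *testing.T) {
        e := newTestGPUEngine[float32](t)                    // assumed helper
        w := loadMmapQuantTensor[float32](t, "weights.gguf") // assumed helper; unsupported qtype
        x := randomTensor[float32](t, []int{w.Shape()[1], 1})

        out, err := e.MatMul(context.Background(), w, x)
        if err != nil {
            t.Fatal(err)
        }
        // The CPU engine did the math, but the tensor must come back on device.
        if _, ok := out.GetStorage().(*tensor.GPUStorage[float32]); !ok {
            t.Fatalf("CPU fallback leaked a host-resident tensor: %T", out.GetStorage())
        }
    }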
