@@ -2640,14 +2640,40 @@ func (e *GPUEngine[T]) matMulBF16BWeight(ctx context.Context, a *tensor.TensorNu
26402640 return makeGPUResult [T ](e , outShape , devC , m * n , dst ... )
26412641}
26422642
2643+ // cpuMatMulToGPU runs MatMul on the CPU engine then uploads the result to GPU.
2644+ // This ensures callers always receive a GPU-resident tensor, maintaining device
2645+ // consistency when the GPU engine falls back to CPU for unsupported quant types.
2646+ func (e * GPUEngine [T ]) cpuMatMulToGPU (ctx context.Context , a , b * tensor.TensorNumeric [T ], dst ... * tensor.TensorNumeric [T ]) (* tensor.TensorNumeric [T ], error ) {
2647+ result , err := e .cpu .MatMul (ctx , a , b , dst ... )
2648+ if err != nil {
2649+ return nil , err
2650+ }
2651+ // If already GPU-resident (e.g., dst was provided with GPUStorage), return as-is.
2652+ if _ , ok := result .GetStorage ().(* tensor.GPUStorage [T ]); ok {
2653+ return result , nil
2654+ }
2655+ // Upload CPU result to GPU.
2656+ data := result .Data ()
2657+ byteSize := len (data ) * int (unsafe .Sizeof (* new (T )))
2658+ devPtr , err := e .pool .Alloc (e .deviceID , byteSize )
2659+ if err != nil {
2660+ return result , nil // fallback: return CPU tensor if GPU alloc fails
2661+ }
2662+ if err := e .runtime .Memcpy (devPtr , unsafe .Pointer (& data [0 ]), byteSize , gpuapi .MemcpyHostToDevice ); err != nil {
2663+ e .pool .Free (e .deviceID , devPtr , byteSize )
2664+ return result , nil
2665+ }
2666+ return makeGPUResult [T ](e , result .Shape (), devPtr , len (data ), dst ... )
2667+ }
2668+
26432669// matMulMmap handles MatMul where A has MmapStorage. Routes to the appropriate
26442670// quantized kernel based on QType, using the pre-uploaded GPU pointer from
26452671// UploadWeights or uploading raw bytes on the fly.
26462672func (e * GPUEngine [T ]) matMulMmap (ctx context.Context , ms * tensor.MmapStorage , a , b * tensor.TensorNumeric [T ], dst ... * tensor.TensorNumeric [T ]) (* tensor.TensorNumeric [T ], error ) {
26472673 aShape := a .Shape ()
26482674 bShape := b .Shape ()
26492675 if len (aShape ) < 2 || len (bShape ) < 2 || len (aShape ) > 2 || len (bShape ) > 2 {
2650- return e .cpu . MatMul (ctx , a , b , dst ... )
2676+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
26512677 }
26522678
26532679 m := aShape [0 ]
@@ -2658,7 +2684,7 @@ func (e *GPUEngine[T]) matMulMmap(ctx context.Context, ms *tensor.MmapStorage, a
26582684 // Acquire GPU pointer for the quantized weight data.
26592685 devW , freeW , err := e .mmapDevicePtr (ms )
26602686 if err != nil {
2661- return e .cpu . MatMul (ctx , a , b , dst ... )
2687+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
26622688 }
26632689 defer freeW ()
26642690
@@ -2668,90 +2694,90 @@ func (e *GPUEngine[T]) matMulMmap(ctx context.Context, ms *tensor.MmapStorage, a
26682694 if n == 1 {
26692695 devX , cleanupX , err := getDevicePtr (e , b )
26702696 if err != nil {
2671- return e .cpu . MatMul (ctx , a , b , dst ... )
2697+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
26722698 }
26732699 defer cleanupX ()
26742700
26752701 f32Size := int (unsafe .Sizeof (float32 (0 )))
26762702 cSize := m * f32Size
26772703 devY , err := e .pool .Alloc (e .deviceID , cSize )
26782704 if err != nil {
2679- return e .cpu . MatMul (ctx , a , b , dst ... )
2705+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
26802706 }
26812707
26822708 var kerr error
26832709 switch qtype {
26842710 case tensor .GGMLTypeQ4_K :
26852711 if k % 256 != 0 {
26862712 e .pool .Free (e .deviceID , devY , cSize )
2687- return e .cpu . MatMul (ctx , a , b , dst ... )
2713+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
26882714 }
26892715 kerr = e .kernels .GemvQ4KF32 (devW , devX , devY , m , k , e .stream )
26902716 case tensor .GGMLTypeQ4_0 :
26912717 if k % 32 != 0 {
26922718 e .pool .Free (e .deviceID , devY , cSize )
2693- return e .cpu . MatMul (ctx , a , b , dst ... )
2719+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
26942720 }
26952721 totalBlocks := (m * k ) / 32
26962722 dataOff := tensor .Q4GPUDataOffset (totalBlocks )
26972723 kerr = e .kernels .GemmQ4F32 (devW , devX , devY , m , k , 1 , dataOff , e .stream )
26982724 case tensor .GGMLTypeQ8_0 :
26992725 if k % 32 != 0 {
27002726 e .pool .Free (e .deviceID , devY , cSize )
2701- return e .cpu . MatMul (ctx , a , b , dst ... )
2727+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27022728 }
27032729 kerr = e .kernels .GemmQ8F32 (devW , devX , devY , m , k , 1 , e .stream )
27042730 case tensor .GGMLTypeQ6_K :
27052731 if k % 256 != 0 {
27062732 e .pool .Free (e .deviceID , devY , cSize )
2707- return e .cpu . MatMul (ctx , a , b , dst ... )
2733+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27082734 }
27092735 kerr = e .kernels .GemvQ6KF32 (devW , devX , devY , m , k , e .stream )
27102736 case tensor .GGMLTypeQ5_K :
27112737 if k % 256 != 0 {
27122738 e .pool .Free (e .deviceID , devY , cSize )
2713- return e .cpu . MatMul (ctx , a , b , dst ... )
2739+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27142740 }
27152741 kerr = e .kernels .GemvQ5KF32 (devW , devX , devY , m , k , e .stream )
27162742 default :
27172743 e .pool .Free (e .deviceID , devY , cSize )
2718- return e .cpu . MatMul (ctx , a , b , dst ... )
2744+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27192745 }
27202746 if kerr != nil {
27212747 e .pool .Free (e .deviceID , devY , cSize )
2722- return e .cpu . MatMul (ctx , a , b , dst ... )
2748+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27232749 }
27242750 return makeGPUResult [T ](e , []int {m , n }, devY , m * n , dst ... )
27252751 }
27262752
27272753 // General GEMM: dequantize Q4_K on GPU, then cuBLAS Sgemm.
27282754 // Only Q4_K has a GPU dequant kernel; others fall back to CPU.
27292755 if qtype != tensor .GGMLTypeQ4_K {
2730- return e .cpu . MatMul (ctx , a , b , dst ... )
2756+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27312757 }
27322758
27332759 f32Size := int (unsafe .Sizeof (float32 (0 )))
27342760 dequantSize := m * k * f32Size
27352761 devAF32 , err := e .pool .Alloc (e .deviceID , dequantSize )
27362762 if err != nil {
2737- return e .cpu . MatMul (ctx , a , b , dst ... )
2763+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27382764 }
27392765 defer e .pool .Free (e .deviceID , devAF32 , dequantSize )
27402766
27412767 if err := e .kernels .DequantQ4KF32 (devW , devAF32 , m , k , e .stream ); err != nil {
2742- return e .cpu . MatMul (ctx , a , b , dst ... )
2768+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27432769 }
27442770
27452771 devB , cleanupB , err := getDevicePtr (e , b )
27462772 if err != nil {
2747- return e .cpu . MatMul (ctx , a , b , dst ... )
2773+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27482774 }
27492775 defer cleanupB ()
27502776
27512777 cSize := m * n * f32Size
27522778 devC , err := e .pool .Alloc (e .deviceID , cSize )
27532779 if err != nil {
2754- return e .cpu . MatMul (ctx , a , b , dst ... )
2780+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27552781 }
27562782
27572783 if err := e .blas .Sgemm (m , n , k , 1.0 , devAF32 , devB , 0.0 , devC ); err != nil {
@@ -2768,7 +2794,7 @@ func (e *GPUEngine[T]) matMulMmapB(ctx context.Context, a *tensor.TensorNumeric[
27682794 aShape := a .Shape ()
27692795 bShape := b .Shape ()
27702796 if len (aShape ) < 2 || len (bShape ) < 2 || len (bShape ) > 2 {
2771- return e .cpu . MatMul (ctx , a , b , dst ... )
2797+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27722798 }
27732799
27742800 // B is virtual-transposed: logical [K, N], physical [N, K].
@@ -2783,7 +2809,7 @@ func (e *GPUEngine[T]) matMulMmapB(ctx context.Context, a *tensor.TensorNumeric[
27832809
27842810 devW , freeW , err := e .mmapDevicePtr (ms )
27852811 if err != nil {
2786- return e .cpu . MatMul (ctx , a , b , dst ... )
2812+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27872813 }
27882814 defer freeW ()
27892815
@@ -2794,58 +2820,58 @@ func (e *GPUEngine[T]) matMulMmapB(ctx context.Context, a *tensor.TensorNumeric[
27942820 if m == 1 {
27952821 devX , cleanupX , err := getDevicePtr (e , a )
27962822 if err != nil {
2797- return e .cpu . MatMul (ctx , a , b , dst ... )
2823+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
27982824 }
27992825 defer cleanupX ()
28002826
28012827 f32Size := int (unsafe .Sizeof (float32 (0 )))
28022828 cSize := n * f32Size
28032829 devY , err := e .pool .Alloc (e .deviceID , cSize )
28042830 if err != nil {
2805- return e .cpu . MatMul (ctx , a , b , dst ... )
2831+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28062832 }
28072833
28082834 var kerr error
28092835 switch qtype {
28102836 case tensor .GGMLTypeQ4_K :
28112837 if k % 256 != 0 {
28122838 e .pool .Free (e .deviceID , devY , cSize )
2813- return e .cpu . MatMul (ctx , a , b , dst ... )
2839+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28142840 }
28152841 kerr = e .kernels .GemvQ4KF32 (devW , devX , devY , nPhys , k , e .stream )
28162842 case tensor .GGMLTypeQ4_0 :
28172843 if k % 32 != 0 {
28182844 e .pool .Free (e .deviceID , devY , cSize )
2819- return e .cpu . MatMul (ctx , a , b , dst ... )
2845+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28202846 }
28212847 totalBlocks := (nPhys * k ) / 32
28222848 dataOff := tensor .Q4GPUDataOffset (totalBlocks )
28232849 kerr = e .kernels .GemmQ4F32 (devW , devX , devY , nPhys , k , 1 , dataOff , e .stream )
28242850 case tensor .GGMLTypeQ8_0 :
28252851 if k % 32 != 0 {
28262852 e .pool .Free (e .deviceID , devY , cSize )
2827- return e .cpu . MatMul (ctx , a , b , dst ... )
2853+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28282854 }
28292855 kerr = e .kernels .GemmQ8F32 (devW , devX , devY , nPhys , k , 1 , e .stream )
28302856 case tensor .GGMLTypeQ6_K :
28312857 if k % 256 != 0 {
28322858 e .pool .Free (e .deviceID , devY , cSize )
2833- return e .cpu . MatMul (ctx , a , b , dst ... )
2859+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28342860 }
28352861 kerr = e .kernels .GemvQ6KF32 (devW , devX , devY , nPhys , k , e .stream )
28362862 case tensor .GGMLTypeQ5_K :
28372863 if k % 256 != 0 {
28382864 e .pool .Free (e .deviceID , devY , cSize )
2839- return e .cpu . MatMul (ctx , a , b , dst ... )
2865+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28402866 }
28412867 kerr = e .kernels .GemvQ5KF32 (devW , devX , devY , nPhys , k , e .stream )
28422868 default :
28432869 e .pool .Free (e .deviceID , devY , cSize )
2844- return e .cpu . MatMul (ctx , a , b , dst ... )
2870+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28452871 }
28462872 if kerr != nil {
28472873 e .pool .Free (e .deviceID , devY , cSize )
2848- return e .cpu . MatMul (ctx , a , b , dst ... )
2874+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28492875 }
28502876
28512877 outShape := make ([]int , len (aShape ))
@@ -2857,31 +2883,31 @@ func (e *GPUEngine[T]) matMulMmapB(ctx context.Context, a *tensor.TensorNumeric[
28572883 // General GEMM: dequantize Q4_K on GPU, then cuBLAS SgemmNT.
28582884 // Only Q4_K has a GPU dequant kernel; others fall back to CPU.
28592885 if qtype != tensor .GGMLTypeQ4_K {
2860- return e .cpu . MatMul (ctx , a , b , dst ... )
2886+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28612887 }
28622888
28632889 f32Size := int (unsafe .Sizeof (float32 (0 )))
28642890 dequantSize := nPhys * k * f32Size
28652891 devBF32 , err := e .pool .Alloc (e .deviceID , dequantSize )
28662892 if err != nil {
2867- return e .cpu . MatMul (ctx , a , b , dst ... )
2893+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28682894 }
28692895 defer e .pool .Free (e .deviceID , devBF32 , dequantSize )
28702896
28712897 if err := e .kernels .DequantQ4KF32 (devW , devBF32 , nPhys , k , e .stream ); err != nil {
2872- return e .cpu . MatMul (ctx , a , b , dst ... )
2898+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28732899 }
28742900
28752901 devA , cleanupA , err := getDevicePtr (e , a )
28762902 if err != nil {
2877- return e .cpu . MatMul (ctx , a , b , dst ... )
2903+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28782904 }
28792905 defer cleanupA ()
28802906
28812907 cSize := m * n * f32Size
28822908 devC , err := e .pool .Alloc (e .deviceID , cSize )
28832909 if err != nil {
2884- return e .cpu . MatMul (ctx , a , b , dst ... )
2910+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
28852911 }
28862912
28872913 outShape := make ([]int , len (aShape ))
@@ -2899,7 +2925,7 @@ func (e *GPUEngine[T]) matMulMmapB(ctx context.Context, a *tensor.TensorNumeric[
28992925
29002926 // Fallback: CPU MatMul.
29012927 e .pool .Free (e .deviceID , devC , cSize )
2902- return e .cpu . MatMul (ctx , a , b , dst ... )
2928+ return e .cpuMatMulToGPU (ctx , a , b , dst ... )
29032929}
29042930
29052931// mmapDevicePtr returns the GPU device pointer for MmapStorage data. If the data
0 commit comments