Commit 26bbd49

fix(compute): reuse dst GPU memory instead of allocating per call (#84)
GPU ops (gpuBinaryOp, gpuUnaryOp, gpuScalarOp, Transpose, MatMul, Sum) were allocating fresh device memory via pool.Alloc on every call, even when a pre-sized dst tensor was provided, then swapping dst's storage to the new allocation. The old GPUStorage was orphaned and depended on Go's GC finalizer to call pool.Free. At large training shapes, with hundreds of batches and ~20 ops per batch, orphaned allocations piled up faster than the GC could reclaim them, causing unbounded GPU memory growth and OOM.

Fix: add a tryReuseDstPtr helper that checks whether dst[0] already has a GPUStorage with sufficient capacity. If so, the kernel writes directly into the existing device pointer: no pool.Alloc, no orphaned storage, no GC pressure. When dst is nil or undersized, the existing alloc path is preserved unchanged.

Applied to the six hot-path op families that cover PatchTST GPU training:

- gpuBinaryOp (Add, Sub, Mul same-shape)
- gpuUnaryOp (Exp, Log, Sin, Cos, Tanh, Sqrt)
- gpuScalarOp (MulScalar, AddScalar, DivScalar)
- Transpose (gpu_engine_memory.go)
- MatMul standard float32 path (gpu_engine.go)
- Sum/ReduceSum (gpu_kernels.go)

Other ops (broadcast, Q4/Q8/BF16 matmul, fused kernels) continue to use the existing alloc path and can be converted incrementally.

Full ztensor test suite passes on the CPU host.

Closes #84
Refs zerfoo/zerfoo#373
1 parent 18a53fe commit 26bbd49

3 files changed: 124 additions & 31 deletions
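The leak mechanics are easy to model off-GPU. The sketch below is illustrative only: toyPool, storage, and the sizes are stand-ins invented for this note, not code from the repo, and the finalizer's counter update is unsynchronized. What it reproduces is the key property the commit message relies on: orphaned device memory is invisible to Go's GC, so finalizer-driven reclamation has no reason to keep pace with allocation.

```go
package main

import (
    "fmt"
    "runtime"
)

// toyPool stands in for the GPU memory pool. Device memory is modeled
// as a plain byte counter: it is invisible to Go's GC, which is exactly
// why finalizer-based reclamation lags behind.
type toyPool struct{ live int }

// storage mimics the pre-#84 GPUStorage: a tiny Go object owning a
// large "device" allocation, returned to the pool only by a finalizer.
type storage struct {
    pool *toyPool
    size int
}

func newStorage(p *toyPool, n int) *storage {
    s := &storage{pool: p, size: n}
    p.live += n
    runtime.SetFinalizer(s, func(s *storage) { s.pool.live -= s.size })
    return s
}

func main() {
    const ops, bytesPerOp = 1000, 1 << 20 // ~20 ops/batch over many batches

    // Old path: every op allocates a fresh storage and orphans the last.
    // Each orphan is a few bytes of Go heap, so the GC sees no pressure,
    // almost no finalizers run, and "device" memory balloons.
    oldPool := &toyPool{}
    var dst *storage
    for i := 0; i < ops; i++ {
        dst = newStorage(oldPool, bytesPerOp)
    }
    _ = dst
    fmt.Printf("alloc per call: ~%d MiB live\n", oldPool.live>>20)

    // New path: one allocation, every op writes into it in place.
    newPool := &toyPool{}
    reused := newStorage(newPool, bytesPerOp)
    _ = reused
    fmt.Printf("reuse dst:       %d MiB live\n", newPool.live>>20)
}
```

A typical run prints roughly 1000 MiB live for the alloc-per-call loop against 1 MiB for the reuse path; exact numbers depend on GC timing.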

compute/gpu_engine.go (21 additions & 8 deletions)
```diff
@@ -978,13 +978,16 @@ func (e *GPUEngine[T]) MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T]
         return nil, err
     }
 
-    // Allocate device output.
-    devCTotal, err := e.pool.Alloc(e.deviceID, outputBytes)
-    if err != nil {
-        e.oomFallbackCount.Add(1)
-        e.logger.Warn("MatMul: GPU output alloc failed, falling back to CPU", "error", err.Error())
+    // Reuse dst's existing GPU memory when possible (#84).
+    devCTotal, reusedC := tryReuseDstPtr[T](batchSize*cMatSize, dst)
+    if !reusedC {
+        devCTotal, err = e.pool.Alloc(e.deviceID, outputBytes)
+        if err != nil {
+            e.oomFallbackCount.Add(1)
+            e.logger.Warn("MatMul: GPU output alloc failed, falling back to CPU", "error", err.Error())
 
-        return e.cpu.MatMul(ctx, a, b, dst...)
+            return e.cpu.MatMul(ctx, a, b, dst...)
+        }
     }
 
     // Use strided batched GEMM when available for float32 with batch > 1.
@@ -1013,9 +1016,14 @@ func (e *GPUEngine[T]) MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T]
         if err := batched.SgemmStridedBatched(m, n, k, 1.0,
             devA, strideA, devB, strideBVal, 0.0,
             devCTotal, strideC, batchSize); err != nil {
-            e.pool.Free(e.deviceID, devCTotal, outputBytes)
+            if !reusedC {
+                e.pool.Free(e.deviceID, devCTotal, outputBytes)
+            }
             return nil, fmt.Errorf("MatMul: batched GEMM: %w", err)
         }
+        if reusedC {
+            return finishReusedDst[T](dst[0], outShape), nil
+        }
         return makeGPUResult[T](e, outShape, devCTotal, batchSize*cMatSize, dst...)
     }
 }
@@ -1052,12 +1060,17 @@ func (e *GPUEngine[T]) MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T]
         }
 
         if blasErr != nil {
-            e.pool.Free(e.deviceID, devCTotal, outputBytes)
+            if !reusedC {
+                e.pool.Free(e.deviceID, devCTotal, outputBytes)
+            }
 
             return nil, fmt.Errorf("MatMul: BLAS batch %d: %w", batch, blasErr)
         }
     }
 
+    if reusedC {
+        return finishReusedDst[T](dst[0], outShape), nil
+    }
     return makeGPUResult[T](e, outShape, devCTotal, batchSize*cMatSize, dst...)
 }
 
```

compute/gpu_engine_memory.go (20 additions & 5 deletions)
```diff
@@ -132,9 +132,14 @@ func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T]
     }
 
     byteSize := total * f32Size
-    devOut, err := e.pool.Alloc(e.deviceID, byteSize)
-    if err != nil {
-        return e.cpu.Transpose(ctx, a, axes, dst...)
+
+    // Reuse dst's existing GPU memory when possible (#84).
+    devOut, reused := tryReuseDstPtr[T](total, dst)
+    if !reused {
+        devOut, err = e.pool.Alloc(e.deviceID, byteSize)
+        if err != nil {
+            return e.cpu.Transpose(ctx, a, axes, dst...)
+        }
     }
 
     // Fast path: 2D transpose.
@@ -145,9 +150,14 @@ func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T]
                 "cols", fmt.Sprintf("%d", shape[1]))
         }
         if err := e.kernels.Transpose2D(devIn, devOut, shape[0], shape[1], e.stream); err != nil {
-            e.pool.Free(e.deviceID, devOut, byteSize)
+            if !reused {
+                e.pool.Free(e.deviceID, devOut, byteSize)
+            }
             return nil, err
         }
+        if reused {
+            return finishReusedDst[T](dst[0], outShape), nil
+        }
         return makeGPUResult[T](e, outShape, devOut, total, dst...)
     }
 
@@ -175,10 +185,15 @@ func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T]
     }
 
     if err := e.kernels.TransposeND(devIn, devOut, inStrides32, outStrides32, perm32, rank, total, e.stream); err != nil {
-        e.pool.Free(e.deviceID, devOut, byteSize)
+        if !reused {
+            e.pool.Free(e.deviceID, devOut, byteSize)
+        }
         return nil, err
     }
 
+    if reused {
+        return finishReusedDst[T](dst[0], outShape), nil
+    }
     return makeGPUResult[T](e, outShape, devOut, total, dst...)
 }
 
```

compute/gpu_kernels.go (83 additions & 18 deletions)
```diff
@@ -115,6 +115,35 @@ func getDevicePtr[T tensor.Numeric](e *GPUEngine[T], t *tensor.TensorNumeric[T])
     return devPtr, cleanup, nil
 }
 
+// tryReuseDstPtr checks whether dst[0] already has a GPUStorage with at least
+// neededElems capacity. If so, it returns the existing device pointer so the
+// caller can write kernel output directly into it, avoiding a pool.Alloc and
+// the resulting GC-pressure from orphaned GPUStorage objects. See ztensor#84.
+func tryReuseDstPtr[T tensor.Numeric](neededElems int, dst []*tensor.TensorNumeric[T]) (unsafe.Pointer, bool) {
+    if len(dst) == 0 || dst[0] == nil {
+        return nil, false
+    }
+    gs, ok := dst[0].GetStorage().(*tensor.GPUStorage[T])
+    if !ok || gs.Len() < neededElems {
+        return nil, false
+    }
+    return gs.Ptr(), true
+}
+
+// finishReusedDst updates dst's shape and strides in place after a kernel has
+// written into dst's existing device memory. No new GPUStorage is created.
+func finishReusedDst[T tensor.Numeric](dst *tensor.TensorNumeric[T], shape []int) *tensor.TensorNumeric[T] {
+    strides := make([]int, len(shape))
+    stride := 1
+    for i := len(shape) - 1; i >= 0; i-- {
+        strides[i] = stride
+        stride *= shape[i]
+    }
+    dst.SetShape(shape)
+    dst.SetStrides(strides)
+    return dst
+}
+
 // makeGPUResult creates a tensor with pool-backed GPUStorage wrapping the given
 // device pointer. When the tensor is freed, the pointer is returned to the pool
 // for reuse instead of calling cudaFree.
```
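The stride loop in finishReusedDst is the standard row-major (C-order) computation: the last axis gets stride 1, and each earlier axis strides over the product of all later dimensions. A standalone re-derivation of that loop as a sanity check (the helper name here is hypothetical, the body is the same):

```go
package main

import "fmt"

// rowMajorStrides mirrors the stride loop in finishReusedDst.
func rowMajorStrides(shape []int) []int {
    strides := make([]int, len(shape))
    stride := 1
    for i := len(shape) - 1; i >= 0; i-- {
        strides[i] = stride
        stride *= shape[i]
    }
    return strides
}

func main() {
    fmt.Println(rowMajorStrides([]int{2, 3, 4})) // [12 4 1]
}
```

Note that finishReusedDst always resets dst to a contiguous row-major layout, whatever strides dst carried before the kernel ran.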
```diff
@@ -522,17 +551,26 @@ func gpuBinaryOp[T tensor.Numeric](
 
     byteSize := n * f32Size
 
-    devC, err := e.pool.Alloc(e.deviceID, byteSize)
-    if err != nil {
-        return nil, err
+    // Reuse dst's existing GPU memory when possible (#84).
+    devC, reused := tryReuseDstPtr[T](n, dst)
+    if !reused {
+        devC, err = e.pool.Alloc(e.deviceID, byteSize)
+        if err != nil {
+            return nil, err
+        }
     }
 
     if err := kernelFn(devA, devB, devC, n, e.stream); err != nil {
-        e.pool.Free(e.deviceID, devC, byteSize)
+        if !reused {
+            e.pool.Free(e.deviceID, devC, byteSize)
+        }
 
         return nil, err
     }
 
+    if reused {
+        return finishReusedDst[T](dst[0], a.Shape()), nil
+    }
     return makeGPUResult[T](e, a.Shape(), devC, n, dst...)
 }
 
```
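From the caller's side the intended steady state is: the first call passes a nil or storage-less dst and takes the pool.Alloc path, which attaches a pool-backed GPUStorage to the result; feeding that result back in as dst makes every later call hit tryReuseDstPtr and write in place. A hypothetical sketch of that loop; the op type is an abstraction invented here, since the engine's public method signatures are not part of this diff:

```go
package caller

import "context"

// binaryOp abstracts an engine method backed by gpuBinaryOp (Add, Sub,
// Mul); the concrete signature is an assumption, not the repo's API.
type binaryOp[T any] func(ctx context.Context, a, b, dst *T) (*T, error)

// runSteps shows the reuse pattern: dst is nil on the first call (alloc
// path), then aliases the same device memory on every later call.
func runSteps[T any](ctx context.Context, op binaryOp[T], a, b *T, steps int) error {
    var dst *T
    for i := 0; i < steps; i++ {
        out, err := op(ctx, a, b, dst)
        if err != nil {
            return err
        }
        dst = out // subsequent calls write in place via tryReuseDstPtr
    }
    return nil
}
```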

```diff
@@ -559,17 +597,26 @@ func gpuUnaryOp[T tensor.Numeric](
 
     byteSize := n * f32Size
 
-    devC, err := e.pool.Alloc(e.deviceID, byteSize)
-    if err != nil {
-        return nil, err
+    // Reuse dst's existing GPU memory when possible (#84).
+    devC, reused := tryReuseDstPtr[T](n, dst)
+    if !reused {
+        devC, err = e.pool.Alloc(e.deviceID, byteSize)
+        if err != nil {
+            return nil, err
+        }
     }
 
     if err := kernelFn(devA, devC, n, e.stream); err != nil {
-        e.pool.Free(e.deviceID, devC, byteSize)
+        if !reused {
+            e.pool.Free(e.deviceID, devC, byteSize)
+        }
 
         return nil, err
     }
 
+    if reused {
+        return finishReusedDst[T](dst[0], a.Shape()), nil
+    }
     return makeGPUResult[T](e, a.Shape(), devC, n, dst...)
 }
 
```

```diff
@@ -597,17 +644,26 @@ func gpuScalarOp[T tensor.Numeric](
 
     byteSize := n * f32Size
 
-    devC, err := e.pool.Alloc(e.deviceID, byteSize)
-    if err != nil {
-        return nil, err
+    // Reuse dst's existing GPU memory when possible (#84).
+    devC, reused := tryReuseDstPtr[T](n, dst)
+    if !reused {
+        devC, err = e.pool.Alloc(e.deviceID, byteSize)
+        if err != nil {
+            return nil, err
+        }
     }
 
     if err := kernelFn(devA, scalar, devC, n, e.stream); err != nil {
-        e.pool.Free(e.deviceID, devC, byteSize)
+        if !reused {
+            e.pool.Free(e.deviceID, devC, byteSize)
+        }
 
         return nil, err
     }
 
+    if reused {
+        return finishReusedDst[T](dst[0], a.Shape()), nil
+    }
     return makeGPUResult[T](e, a.Shape(), devC, n, dst...)
 }
 
```

```diff
@@ -957,20 +1013,29 @@ func (e *GPUEngine[T]) gpuSum(ctx context.Context, a *tensor.TensorNumeric[T], a
 
     outByteSize := numStripes * f32Size
 
-    devOut, err := e.pool.Alloc(e.deviceID, outByteSize)
-    if err != nil {
-        e.oomFallbackCount.Add(1)
-        e.logger.Warn("Sum: GPU output alloc failed, falling back to CPU", "error", err.Error())
+    // Reuse dst's existing GPU memory when possible (#84).
+    devOut, reused := tryReuseDstPtr[T](numStripes, dst)
+    if !reused {
+        devOut, err = e.pool.Alloc(e.deviceID, outByteSize)
+        if err != nil {
+            e.oomFallbackCount.Add(1)
+            e.logger.Warn("Sum: GPU output alloc failed, falling back to CPU", "error", err.Error())
 
-        return e.cpu.Sum(ctx, a, axis, keepDims, dst...)
+            return e.cpu.Sum(ctx, a, axis, keepDims, dst...)
+        }
     }
 
     if err := e.kernels.SumAxis(devIn, devOut, outer, inner, axisSize, e.stream); err != nil {
-        e.pool.Free(e.deviceID, devOut, outByteSize)
+        if !reused {
+            e.pool.Free(e.deviceID, devOut, outByteSize)
+        }
 
         return nil, err
     }
 
+    if reused {
+        return finishReusedDst[T](dst[0], newShape), nil
+    }
     return makeGPUResult[T](e, newShape, devOut, numStripes, dst...)
 }
 
```
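One subtlety worth noting: tryReuseDstPtr gates on capacity (gs.Len() >= neededElems), not exact size, so an oversized dst also qualifies. That matters for reductions, whose outputs are far smaller than their inputs; finishReusedDst then shrinks the logical shape while the larger storage stays attached. A toy check of the arithmetic, assuming numStripes is outer*inner as the SumAxis arguments suggest:

```go
package main

import "fmt"

func main() {
    // Sum over axis 1 of a [32, 128] tensor: outer=32, axisSize=128, inner=1.
    outer, axisSize, inner := 32, 128, 1
    numStripes := outer * inner        // output elements written by SumAxis
    dstCap := outer * axisSize * inner // dst previously sized for the input

    // The reuse gate in tryReuseDstPtr: capacity, not exact size.
    fmt.Printf("need %d elems, dst holds %d: reuse=%v\n",
        numStripes, dstCap, dstCap >= numStripes)
}
```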
