Skip to content

Commit 18a53fe

Browse files
committed
fix(compute): GPUEngine.Reshape honors dst argument
The zero-copy GPUStorage and Float16Storage fast-paths in GPUEngine.Reshape used to construct a fresh tensor aliasing the source storage and silently drop the caller-provided dst, violating the compute.Engine.Reshape contract shared with CPUEngine. Callers that followed the common pattern of passing a pre-allocated dst and discarding the return value silently received stale pre-allocated storage. This is the root cause of zerfoo PatchTST GPU training convergence freeze (loss=0.268357 across all epochs on DGX GB10) tracked in #79. A one-line zerfoo workaround landed in zerfoo/zerfoo#371; this PR fixes the contract on the ztensor side so the next caller cannot hit the same trap. Fix: when dst is provided, mutate dst[0] to alias the reshaped view (SetStorage + SetShape + SetStrides) and return dst[0]. When no dst is provided, preserve current behavior (return a fresh tensor wrapping the view). The new aliasReshapeDst helper centralizes the dst-honoring logic for both Float16Storage and GPUStorage[T] fast-paths. Adds compute/gpu_reshape_dst_test.go with two regression tests: - TestGPUEngine_Reshape_HonorsDst: pre-allocates dst with a poison pattern, runs Reshape with dst, asserts (a) ret == dst, (b) shape is updated, (c) dst data reflects src not poison. - TestGPUEngine_Reshape_NoDst: preserves the no-dst fast-path behavior — returns a fresh tensor with correct shape and data. Both tests skip when CUDA is unavailable. Full ztensor test suite passes on CPU host. Closes #81 Refs #79 Refs zerfoo/zerfoo#371
1 parent 957e74c commit 18a53fe

2 files changed

Lines changed: 172 additions & 2 deletions

File tree

compute/gpu_engine_memory.go

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -636,18 +636,44 @@ func (e *GPUEngine[T]) Reshape(ctx context.Context, a *tensor.TensorNumeric[T],
636636
// Float16Storage: zero-copy reshape (same GPU pointer, new shape).
637637
if e.dtype != DTypeF32 {
638638
if fs, ok := any(a.GetStorage()).(*tensor.Float16Storage); ok && newSize == currentSize {
639-
return tensor.NewWithStorage[T](inferredShape, any(fs).(tensor.Storage[T]))
639+
storage := any(fs).(tensor.Storage[T])
640+
if len(dst) > 0 && dst[0] != nil {
641+
aliasReshapeDst(dst[0], inferredShape, storage)
642+
return dst[0], nil
643+
}
644+
return tensor.NewWithStorage[T](inferredShape, storage)
640645
}
641646
}
642647

643648
// GPUStorage[T]: zero-copy reshape.
644649
if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok && isFloat32[T]() && newSize == currentSize {
645-
return tensor.NewWithStorage[T](inferredShape, gs.View(gs.Len()))
650+
view := gs.View(gs.Len())
651+
if len(dst) > 0 && dst[0] != nil {
652+
aliasReshapeDst(dst[0], inferredShape, view)
653+
return dst[0], nil
654+
}
655+
return tensor.NewWithStorage[T](inferredShape, view)
646656
}
647657

648658
return e.cpu.Reshape(ctx, a, shape, dst...)
649659
}
650660

661+
// aliasReshapeDst mutates dst to alias the given storage under inferredShape,
662+
// honoring the compute.Engine Reshape contract when a caller-provided dst is
663+
// passed. Fixes the silent-zero trap where the GPU zero-copy fast-path used to
664+
// drop dst, leaving its pre-allocated storage stale. See zerfoo/ztensor#81.
665+
func aliasReshapeDst[T tensor.Numeric](dst *tensor.TensorNumeric[T], inferredShape []int, storage tensor.Storage[T]) {
666+
strides := make([]int, len(inferredShape))
667+
stride := 1
668+
for i := len(inferredShape) - 1; i >= 0; i-- {
669+
strides[i] = stride
670+
stride *= inferredShape[i]
671+
}
672+
dst.SetStorage(storage)
673+
dst.SetShape(inferredShape)
674+
dst.SetStrides(strides)
675+
}
676+
651677
// ConvertFP16ToF32 converts a tensor with Float16Storage to a regular float32
652678
// GPU tensor using the FP16->F32 kernel. Returns the input unchanged if it
653679
// does not have Float16Storage.

compute/gpu_reshape_dst_test.go

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
package compute
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/zerfoo/ztensor/internal/cuda"
8+
"github.com/zerfoo/ztensor/numeric"
9+
"github.com/zerfoo/ztensor/tensor"
10+
)
11+
12+
// TestGPUEngine_Reshape_HonorsDst is the regression test for zerfoo/ztensor#81.
13+
// Pre-fix, GPUEngine.Reshape's zero-copy GPUStorage fast-path returned a fresh
14+
// tensor aliasing the source storage but ignored the caller-provided dst,
15+
// leaving dst's pre-allocated (zero) buffer untouched. Callers that discarded
16+
// the return value (e.g. zerfoo PatchTST GPU backward) silently fed all-zero
17+
// gradients into encoderBackward and froze training loss. The fix mutates dst
18+
// to alias the reshaped view; this test asserts that contract.
19+
func TestGPUEngine_Reshape_HonorsDst(t *testing.T) {
20+
if !cuda.Available() {
21+
t.Skip("CUDA not available")
22+
}
23+
24+
ops := numeric.Float32Ops{}
25+
eng, err := NewGPUEngine[float32](ops)
26+
if err != nil {
27+
t.Fatalf("NewGPUEngine: %v", err)
28+
}
29+
defer func() { _ = eng.Close() }()
30+
31+
ctx := context.Background()
32+
33+
// Source: a [4,4] tensor on the GPU with non-zero data.
34+
src := make([]float32, 16)
35+
for i := range src {
36+
src[i] = float32(i + 1) // 1..16
37+
}
38+
srcGS, err := tensor.NewGPUStorageFromSlice[float32](src)
39+
if err != nil {
40+
t.Fatalf("NewGPUStorageFromSlice src: %v", err)
41+
}
42+
srcGPU, err := tensor.NewWithStorage[float32]([]int{4, 4}, srcGS)
43+
if err != nil {
44+
t.Fatalf("NewWithStorage src: %v", err)
45+
}
46+
47+
// Destination: pre-allocate a [2,8] GPU tensor full of poison (-999.0 as a
48+
// recognisable non-zero sentinel value). The pre-fix bug left this
49+
// buffer untouched; the post-fix contract requires dst to reflect src.
50+
poison := make([]float32, 16)
51+
for i := range poison {
52+
poison[i] = -999.0
53+
}
54+
dstGS, err := tensor.NewGPUStorageFromSlice[float32](poison)
55+
if err != nil {
56+
t.Fatalf("NewGPUStorageFromSlice dst: %v", err)
57+
}
58+
dst, err := tensor.NewWithStorage[float32]([]int{2, 8}, dstGS)
59+
if err != nil {
60+
t.Fatalf("NewWithStorage dst: %v", err)
61+
}
62+
63+
// Reshape src into dst's shape, passing dst as the output buffer. (The zerfoo
64+
// caller that triggered #81 discarded the return value; here we keep it to assert the contract.)
65+
ret, err := eng.Reshape(ctx, srcGPU, []int{2, 8}, dst)
66+
if err != nil {
67+
t.Fatalf("Reshape: %v", err)
68+
}
69+
70+
// Contract 1: ret must be the same tensor object as dst (dst-honoring).
71+
if ret != dst {
72+
t.Errorf("Reshape returned a fresh tensor instead of mutating dst; "+
73+
"caller-provided dst was ignored. ret=%p dst=%p", ret, dst)
74+
}
75+
76+
// Contract 2: dst's shape must be the requested shape.
77+
if got := dst.Shape(); len(got) != 2 || got[0] != 2 || got[1] != 8 {
78+
t.Errorf("dst.Shape() = %v, want [2 8]", got)
79+
}
80+
81+
// Contract 3: dst's data must reflect src's data, not the poison pattern.
82+
dstStorage, ok := dst.GetStorage().(*tensor.GPUStorage[float32])
83+
if !ok {
84+
t.Fatalf("dst storage is not *GPUStorage[float32]: %T", dst.GetStorage())
85+
}
86+
got := dstStorage.Slice()
87+
if len(got) != 16 {
88+
t.Fatalf("dst.GetStorage().Slice() len = %d, want 16", len(got))
89+
}
90+
for i, v := range got {
91+
want := float32(i + 1)
92+
if v != want {
93+
t.Errorf("dst.Data()[%d] = %v, want %v "+
94+
"(stale pre-allocated buffer — Reshape ignored dst)", i, v, want)
95+
}
96+
}
97+
}
98+
99+
// TestGPUEngine_Reshape_NoDst preserves the no-dst behavior: Reshape returns a
100+
// fresh tensor aliasing the source view. This is the fast-path most callers use.
101+
func TestGPUEngine_Reshape_NoDst(t *testing.T) {
102+
if !cuda.Available() {
103+
t.Skip("CUDA not available")
104+
}
105+
106+
ops := numeric.Float32Ops{}
107+
eng, err := NewGPUEngine[float32](ops)
108+
if err != nil {
109+
t.Fatalf("NewGPUEngine: %v", err)
110+
}
111+
defer func() { _ = eng.Close() }()
112+
113+
ctx := context.Background()
114+
115+
src := make([]float32, 12)
116+
for i := range src {
117+
src[i] = float32(i)
118+
}
119+
srcGS, err := tensor.NewGPUStorageFromSlice[float32](src)
120+
if err != nil {
121+
t.Fatalf("NewGPUStorageFromSlice: %v", err)
122+
}
123+
srcGPU, err := tensor.NewWithStorage[float32]([]int{3, 4}, srcGS)
124+
if err != nil {
125+
t.Fatalf("NewWithStorage: %v", err)
126+
}
127+
128+
out, err := eng.Reshape(ctx, srcGPU, []int{2, 6})
129+
if err != nil {
130+
t.Fatalf("Reshape: %v", err)
131+
}
132+
if got := out.Shape(); len(got) != 2 || got[0] != 2 || got[1] != 6 {
133+
t.Errorf("out.Shape() = %v, want [2 6]", got)
134+
}
135+
outGS, ok := out.GetStorage().(*tensor.GPUStorage[float32])
136+
if !ok {
137+
t.Fatalf("out storage is not *GPUStorage[float32]: %T", out.GetStorage())
138+
}
139+
for i, v := range outGS.Slice() {
140+
if v != float32(i) {
141+
t.Errorf("out.Data()[%d] = %v, want %v", i, v, float32(i))
142+
}
143+
}
144+
}

0 commit comments

Comments
 (0)