diff --git a/compute/gpu_engine_memory.go b/compute/gpu_engine_memory.go index be6d118..1aef149 100644 --- a/compute/gpu_engine_memory.go +++ b/compute/gpu_engine_memory.go @@ -636,18 +636,44 @@ func (e *GPUEngine[T]) Reshape(ctx context.Context, a *tensor.TensorNumeric[T], // Float16Storage: zero-copy reshape (same GPU pointer, new shape). if e.dtype != DTypeF32 { if fs, ok := any(a.GetStorage()).(*tensor.Float16Storage); ok && newSize == currentSize { - return tensor.NewWithStorage[T](inferredShape, any(fs).(tensor.Storage[T])) + storage := any(fs).(tensor.Storage[T]) + if len(dst) > 0 && dst[0] != nil { + aliasReshapeDst(dst[0], inferredShape, storage) + return dst[0], nil + } + return tensor.NewWithStorage[T](inferredShape, storage) } } // GPUStorage[T]: zero-copy reshape. if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok && isFloat32[T]() && newSize == currentSize { - return tensor.NewWithStorage[T](inferredShape, gs.View(gs.Len())) + view := gs.View(gs.Len()) + if len(dst) > 0 && dst[0] != nil { + aliasReshapeDst(dst[0], inferredShape, view) + return dst[0], nil + } + return tensor.NewWithStorage[T](inferredShape, view) } return e.cpu.Reshape(ctx, a, shape, dst...) } +// aliasReshapeDst mutates dst to alias the given storage under inferredShape, +// honoring the compute.Engine Reshape contract when a caller-provided dst is +// passed. Fixes the silent-zero trap where the GPU zero-copy fast-path used to +// drop dst, leaving its pre-allocated storage stale. See zerfoo/ztensor#81. 
+func aliasReshapeDst[T tensor.Numeric](dst *tensor.TensorNumeric[T], inferredShape []int, storage tensor.Storage[T]) { + strides := make([]int, len(inferredShape)) + stride := 1 + for i := len(inferredShape) - 1; i >= 0; i-- { + strides[i] = stride + stride *= inferredShape[i] + } + dst.SetStorage(storage) + dst.SetShape(inferredShape) + dst.SetStrides(strides) +} + // ConvertFP16ToF32 converts a tensor with Float16Storage to a regular float32 // GPU tensor using the FP16->F32 kernel. Returns the input unchanged if it // does not have Float16Storage. diff --git a/compute/gpu_reshape_dst_test.go b/compute/gpu_reshape_dst_test.go new file mode 100644 index 0000000..734637a --- /dev/null +++ b/compute/gpu_reshape_dst_test.go @@ -0,0 +1,144 @@ +package compute + +import ( + "context" + "testing" + + "github.com/zerfoo/ztensor/internal/cuda" + "github.com/zerfoo/ztensor/numeric" + "github.com/zerfoo/ztensor/tensor" +) + +// TestGPUEngine_Reshape_HonorsDst is the regression test for zerfoo/ztensor#81. +// Pre-fix, GPUEngine.Reshape's zero-copy GPUStorage fast-path returned a fresh +// tensor aliasing the source storage but ignored the caller-provided dst, +// leaving dst's pre-allocated (zero) buffer untouched. Callers that discarded +// the return value (e.g. zerfoo PatchTST GPU backward) silently fed all-zero +// gradients into encoderBackward and froze training loss. The fix mutates dst +// to alias the reshaped view; this test asserts that contract. +func TestGPUEngine_Reshape_HonorsDst(t *testing.T) { + if !cuda.Available() { + t.Skip("CUDA not available") + } + + ops := numeric.Float32Ops{} + eng, err := NewGPUEngine[float32](ops) + if err != nil { + t.Fatalf("NewGPUEngine: %v", err) + } + defer func() { _ = eng.Close() }() + + ctx := context.Background() + + // Source: a [4,4] tensor on the GPU with non-zero data. 
+ src := make([]float32, 16) + for i := range src { + src[i] = float32(i + 1) // 1..16 + } + srcGS, err := tensor.NewGPUStorageFromSlice[float32](src) + if err != nil { + t.Fatalf("NewGPUStorageFromSlice src: %v", err) + } + srcGPU, err := tensor.NewWithStorage[float32]([]int{4, 4}, srcGS) + if err != nil { + t.Fatalf("NewWithStorage src: %v", err) + } + + // Destination: pre-allocate a [2,8] GPU tensor full of a poison value + // (-999.0, a recognisable non-zero sentinel). The pre-fix bug left this + // buffer untouched; the post-fix contract requires dst to reflect src. + poison := make([]float32, 16) + for i := range poison { + poison[i] = -999.0 + } + dstGS, err := tensor.NewGPUStorageFromSlice[float32](poison) + if err != nil { + t.Fatalf("NewGPUStorageFromSlice dst: %v", err) + } + dst, err := tensor.NewWithStorage[float32]([]int{2, 8}, dstGS) + if err != nil { + t.Fatalf("NewWithStorage dst: %v", err) + } + + // Reshape src into dst's shape, passing dst as the output buffer. The + // zerfoo call pattern that triggered #81 discarded the return value. + ret, err := eng.Reshape(ctx, srcGPU, []int{2, 8}, dst) + if err != nil { + t.Fatalf("Reshape: %v", err) + } + + // Contract 1: ret must be the same tensor object as dst (dst-honoring). + if ret != dst { + t.Errorf("Reshape returned a fresh tensor instead of mutating dst; "+ + "caller-provided dst was ignored. ret=%p dst=%p", ret, dst) + } + + // Contract 2: dst's shape must be the requested shape. + if got := dst.Shape(); len(got) != 2 || got[0] != 2 || got[1] != 8 { + t.Errorf("dst.Shape() = %v, want [2 8]", got) + } + + // Contract 3: dst's data must reflect src's data, not the poison pattern. 
+ dstStorage, ok := dst.GetStorage().(*tensor.GPUStorage[float32]) + if !ok { + t.Fatalf("dst storage is not *GPUStorage[float32]: %T", dst.GetStorage()) + } + got := dstStorage.Slice() + if len(got) != 16 { + t.Fatalf("dst.GetStorage().Slice() len = %d, want 16", len(got)) + } + for i, v := range got { + want := float32(i + 1) + if v != want { + t.Errorf("dst.Data()[%d] = %v, want %v "+ + "(stale pre-allocated buffer — Reshape ignored dst)", i, v, want) + } + } +} + +// TestGPUEngine_Reshape_NoDst preserves the no-dst behavior: Reshape returns a +// fresh tensor aliasing the source view. This is the fast-path most callers use. +func TestGPUEngine_Reshape_NoDst(t *testing.T) { + if !cuda.Available() { + t.Skip("CUDA not available") + } + + ops := numeric.Float32Ops{} + eng, err := NewGPUEngine[float32](ops) + if err != nil { + t.Fatalf("NewGPUEngine: %v", err) + } + defer func() { _ = eng.Close() }() + + ctx := context.Background() + + src := make([]float32, 12) + for i := range src { + src[i] = float32(i) + } + srcGS, err := tensor.NewGPUStorageFromSlice[float32](src) + if err != nil { + t.Fatalf("NewGPUStorageFromSlice: %v", err) + } + srcGPU, err := tensor.NewWithStorage[float32]([]int{3, 4}, srcGS) + if err != nil { + t.Fatalf("NewWithStorage: %v", err) + } + + out, err := eng.Reshape(ctx, srcGPU, []int{2, 6}) + if err != nil { + t.Fatalf("Reshape: %v", err) + } + if got := out.Shape(); len(got) != 2 || got[0] != 2 || got[1] != 6 { + t.Errorf("out.Shape() = %v, want [2 6]", got) + } + outGS, ok := out.GetStorage().(*tensor.GPUStorage[float32]) + if !ok { + t.Fatalf("out storage is not *GPUStorage[float32]: %T", out.GetStorage()) + } + for i, v := range outGS.Slice() { + if v != float32(i) { + t.Errorf("out.Data()[%d] = %v, want %v", i, v, float32(i)) + } + } +}