Skip to content

Commit 8d6c90b

Browse files
committed
feat(compute): add AllocDeviceFloat32 and CopyToDevice to FusedEncoderProvider
Enable callers to allocate persistent GPU buffers and upload weight data for the fused encoder kernel. Without this, CPU-backed weight tensors have no device pointer and the fused path always falls back to per-op. - AllocDeviceFloat32: pool-managed GPU allocation - CopyToDevice: host-to-device memcpy for float32 arrays
1 parent 716bbd6 commit 8d6c90b

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

compute/fused_encoder.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ type FusedEncoderProvider interface {
1717
// FusedEncoderAvailable returns true if the fused encoder kernel is loaded.
1818
FusedEncoderAvailable() bool
1919

20+
// AllocDeviceFloat32 allocates numElements float32s on the GPU and returns
21+
// the device pointer. Memory is pool-managed and freed when the engine closes.
22+
AllocDeviceFloat32(numElements int) (unsafe.Pointer, error)
23+
24+
// CopyToDevice copies len(src) float32 values from host to a device pointer.
25+
CopyToDevice(dst unsafe.Pointer, src []float32) error
26+
2027
// FusedEncoderForward executes one encoder layer forward pass.
2128
// weights: [16]unsafe.Pointer to layer weights.
2229
// bufs: [16]unsafe.Pointer to pre-allocated forward cache buffers.

compute/gpu_fused_encoder.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"unsafe"
66

77
"github.com/zerfoo/ztensor/internal/cublas"
8+
"github.com/zerfoo/ztensor/internal/gpuapi"
89
)
910

1011
// blasHandlePtr extracts the raw cuBLAS handle pointer from the BLAS interface.
@@ -68,5 +69,17 @@ func (e *GPUEngine[T]) FusedEncoderBackward(
6869
totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches, e.stream)
6970
}
7071

72+
// AllocDeviceFloat32 allocates GPU memory for numElements float32 values.
73+
func (e *GPUEngine[T]) AllocDeviceFloat32(numElements int) (unsafe.Pointer, error) {
74+
e.setDevice()
75+
return e.pool.Alloc(e.deviceID, numElements*4)
76+
}
77+
78+
// CopyToDevice copies float32 data from host to device.
79+
func (e *GPUEngine[T]) CopyToDevice(dst unsafe.Pointer, src []float32) error {
80+
e.setDevice()
81+
return e.runtime.Memcpy(dst, unsafe.Pointer(&src[0]), len(src)*4, gpuapi.MemcpyHostToDevice)
82+
}
83+
7184
// Compile-time check that GPUEngine implements FusedEncoderProvider.
7285
var _ FusedEncoderProvider = (*GPUEngine[float32])(nil)

0 commit comments

Comments
 (0)