diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fcd19e4..79aff33 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,5 +18,5 @@ jobs: # Run go vet on all packages except those with intentional # unsafe.Pointer usage for GPU runtime bindings via purego/dlopen. # These warnings are expected and documented in docs/QUALITY.md. - go vet $(go list ./... | grep -v '/internal/cuda$' | grep -v '/internal/hip$' | grep -v '/internal/opencl$' | grep -v '/internal/cudnn$' | grep -v '/internal/tensorrt$' | grep -v '/internal/fpga$' | grep -v '/internal/sycl$' | grep -v '/internal/metal$' | grep -v '/internal/pjrt$' | grep -v '/internal/nccl$') + go vet $(go list ./... | grep -v '/internal/cuda$' | grep -v '/internal/cublas$' | grep -v '/internal/hip$' | grep -v '/internal/opencl$' | grep -v '/internal/cudnn$' | grep -v '/internal/tensorrt$' | grep -v '/internal/fpga$' | grep -v '/internal/sycl$' | grep -v '/internal/metal$' | grep -v '/internal/pjrt$' | grep -v '/internal/nccl$') - run: go test -race -timeout 300s ./... diff --git a/compute/fused_encoder.go b/compute/fused_encoder.go new file mode 100644 index 0000000..226ab2f --- /dev/null +++ b/compute/fused_encoder.go @@ -0,0 +1,47 @@ +package compute + +import "unsafe" + +// FusedEncoderProvider is implemented by engines that support fused PatchTST +// encoder layer forward and backward passes. The fused kernel replaces ~78 +// discrete engine operations per layer with a single orchestrated call, +// using cuBLAS for GEMMs and custom CUDA sub-kernels for LayerNorm, GELU, +// softmax, head transpose, and residual operations. +// +// Callers must pre-allocate all buffer arrays and pass device pointers. +// Buffer index constants (FEW_*, FEB_*, FEG_*, etc.) are defined in +// internal/cuda/kernels/fused_encoder_fwd_purego.go and fused_encoder_bwd_purego.go. +// +// This API is not covered by the v1 stability guarantee. +type FusedEncoderProvider interface { + // FusedEncoderAvailable returns true if the fused encoder kernel is loaded. + FusedEncoderAvailable() bool + + // FusedEncoderForward executes one encoder layer forward pass. + // weights: [16]unsafe.Pointer to layer weights. + // bufs: [16]unsafe.Pointer to pre-allocated forward cache buffers. + // input/output: [totalRows, dModel] device pointers. + FusedEncoderForward( + weights *[16]unsafe.Pointer, + bufs *[16]unsafe.Pointer, + input, output unsafe.Pointer, + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, + ) error + + // FusedEncoderBackward computes all gradients for one encoder layer. + // weights: [16]unsafe.Pointer to layer weights. + // weightT: [6]unsafe.Pointer to pre-transposed weights. + // fwdBufs: [16]unsafe.Pointer to forward cache (from FusedEncoderForward). + // bwdBufs: [15]unsafe.Pointer to backward scratch buffers. + // grads: [16]unsafe.Pointer to gradient accumulators (accumulated, not zeroed). + // dOutput: upstream gradient; dInput: output gradient; input: original layer input. + FusedEncoderBackward( + weights *[16]unsafe.Pointer, + weightT *[6]unsafe.Pointer, + fwdBufs *[16]unsafe.Pointer, + bwdBufs *[15]unsafe.Pointer, + grads *[16]unsafe.Pointer, + dOutput, dInput, input unsafe.Pointer, + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, + ) error +} diff --git a/compute/gpu_fused_encoder.go b/compute/gpu_fused_encoder.go new file mode 100644 index 0000000..f549305 --- /dev/null +++ b/compute/gpu_fused_encoder.go @@ -0,0 +1,72 @@ +package compute + +import ( + "fmt" + "unsafe" + + "github.com/zerfoo/ztensor/internal/cublas" +) + +// blasHandlePtr extracts the raw cuBLAS handle pointer from the BLAS interface. +// Returns nil if the BLAS is not backed by cuBLAS. +func blasHandlePtr(b interface{}) unsafe.Pointer { + type handleProvider interface { + Handle() *cublas.Handle + } + if hp, ok := b.(handleProvider); ok { + h := hp.Handle() + if h != nil { + return h.Ptr() + } + } + return nil +} + +// FusedEncoderAvailable returns true if the fused encoder kernel is loaded +// and the engine has a cuBLAS handle to pass to it. +func (e *GPUEngine[T]) FusedEncoderAvailable() bool { + return e.kernels.FusedEncoderFwdAvailable() && blasHandlePtr(e.blas) != nil +} + +// FusedEncoderForward executes one fused encoder layer forward pass. +func (e *GPUEngine[T]) FusedEncoderForward( + weights *[16]unsafe.Pointer, + bufs *[16]unsafe.Pointer, + input, output unsafe.Pointer, + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, +) error { + h := blasHandlePtr(e.blas) + if h == nil { + return fmt.Errorf("FusedEncoderForward: cuBLAS handle not available") + } + e.setDevice() + return e.kernels.FusedEncoderFwdF32(h, weights, bufs, input, output, + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches, e.stream) +} + +// FusedEncoderBackward computes all gradients for one fused encoder layer. +func (e *GPUEngine[T]) FusedEncoderBackward( + weights *[16]unsafe.Pointer, + weightT *[6]unsafe.Pointer, + fwdBufs *[16]unsafe.Pointer, + bwdBufs *[15]unsafe.Pointer, + grads *[16]unsafe.Pointer, + dOutput, dInput, input unsafe.Pointer, + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, +) error { + h := blasHandlePtr(e.blas) + if h == nil { + return fmt.Errorf("FusedEncoderBackward: cuBLAS handle not available") + } + e.setDevice() + // The KernelRunner interface uses *[16] for weightT, but we have *[6]. + // Convert via unsafe pointer. + var wt16 [16]unsafe.Pointer + copy(wt16[:6], weightT[:]) + return e.kernels.FusedEncoderBwdF32(h, weights, &wt16, fwdBufs, bwdBufs, grads, + dOutput, dInput, input, + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches, e.stream) +} + +// Compile-time check that GPUEngine implements FusedEncoderProvider. +var _ FusedEncoderProvider = (*GPUEngine[float32])(nil) diff --git a/internal/cublas/cublas_purego.go b/internal/cublas/cublas_purego.go index 287ea47..84d7ad1 100644 --- a/internal/cublas/cublas_purego.go +++ b/internal/cublas/cublas_purego.go @@ -107,6 +107,10 @@ type Handle struct { ptr uintptr // cublasHandle_t is a pointer } +// Ptr returns the raw cuBLAS handle pointer for passing to C functions +// (e.g., the fused encoder kernel orchestrator). +func (h *Handle) Ptr() unsafe.Pointer { return unsafe.Pointer(h.ptr) } + // CreateHandle creates a new cuBLAS context handle. func CreateHandle() (*Handle, error) { lib, err := getCublasLib() diff --git a/internal/cuda/kernels/Makefile b/internal/cuda/kernels/Makefile index 3e25b8d..12ceebc 100644 --- a/internal/cuda/kernels/Makefile +++ b/internal/cuda/kernels/Makefile @@ -11,7 +11,7 @@ ifeq ($(CUDA_ARCH),sm_121) NVCC_FLAGS += -DFLASH_BLOCK_SIZE=64 endif -SRCS = counter.cu dequant_q4k.cu dequant_q5_0.cu dequant_q5k.cu dequant_q6k.cu elementwise.cu elementwise_fp16.cu flash_attention.cu flash_attention2.cu flash_decode.cu fp4_gemv.cu fp8_gemm.cu fp8_ops.cu fused_add_rmsnorm.cu fused_norm_add.cu fused_qk_norm_rope.cu fused_repeat_interleave.cu fused_rope.cu fused_softmax_vmul.cu fused_swiglu.cu gather.cu gather_q8.cu gemm_int8.cu gemm_int4.cu gemm_q4.cu gemm_q8.cu gemv_q4k.cu gemv_q4k_sm121.cu gemv_q5k.cu gemv_q5_0.cu gemv_q6k.cu gemv_warp.cu megakernel_ops.cu offset_memcpy.cu paged_attention.cu ragged_attention.cu rope_select.cu scaled_softmax.cu selective_scan.cu sgemv_m1.cu ternary_gemv.cu transpose.cu rmsnorm.cu argmax.cu +SRCS = counter.cu dequant_q4k.cu dequant_q5_0.cu dequant_q5k.cu dequant_q6k.cu elementwise.cu elementwise_fp16.cu flash_attention.cu flash_attention2.cu flash_decode.cu fp4_gemv.cu fp8_gemm.cu fp8_ops.cu fused_add_rmsnorm.cu fused_encoder_fwd.cu fused_encoder_bwd.cu fused_norm_add.cu fused_qk_norm_rope.cu fused_repeat_interleave.cu fused_rope.cu fused_softmax_vmul.cu fused_swiglu.cu gather.cu gather_q8.cu gemm_int8.cu gemm_int4.cu gemm_q4.cu gemm_q8.cu gemv_q4k.cu gemv_q4k_sm121.cu gemv_q5k.cu gemv_q5_0.cu gemv_q6k.cu gemv_warp.cu megakernel_ops.cu offset_memcpy.cu paged_attention.cu ragged_attention.cu rope_select.cu scaled_softmax.cu selective_scan.cu sgemv_m1.cu ternary_gemv.cu transpose.cu rmsnorm.cu argmax.cu OBJS = $(SRCS:.cu=.o) PIC_OBJS = $(SRCS:.cu=.pic.o) LIB = libkernels.a @@ -27,7 +27,7 @@ $(LIB): $(OBJS) ar rcs $@ $^ $(SO): $(PIC_OBJS) - $(NVCC) -shared -o $@ $^ + $(NVCC) -shared -o $@ $^ -lcublas # Limit register pressure for kernels that benefit from higher occupancy. # gemm_q4: 40->32 regs/thread, no spills, occupancy 75%->100% (256-thread blocks). diff --git a/internal/cuda/kernels/fused_encoder_bwd.cu b/internal/cuda/kernels/fused_encoder_bwd.cu new file mode 100644 index 0000000..3c6e68a --- /dev/null +++ b/internal/cuda/kernels/fused_encoder_bwd.cu @@ -0,0 +1,731 @@ +/* fused_encoder_bwd.cu -- Fused PatchTST encoder layer backward pass. + * + * Host-side orchestrator that computes all gradients for one encoder layer + * in a single C function call. Uses the forward cache (FEB_* buffers) + * to avoid recomputation. + * + * Sub-kernel inventory (backward-specific): + * kernel_layernorm_bwd LayerNorm backward (dScale, dBias, dInput) + * kernel_gelu_bwd GELU derivative * upstream gradient + * kernel_softmax_bwd Softmax backward (Jacobian-vector product) + * kernel_bias_grad_reduce Sum rows to compute bias gradients + * kernel_add_elementwise Element-wise addition for residual gradients + * kernel_matmul_grad_accum Accumulate weight gradient: dW += A^T @ B + * + * cuBLAS calls (~14 total per layer): + * FFN2 backward: dW, dInput via Sgemm (2 calls) + * FFN1 backward: dW, dInput via Sgemm (2 calls) + * Output proj bwd: dW, dInput via Sgemm (2 calls) + * Attention bwd: dV, dScores, dQ, dK via batched Sgemm (4 calls) + * Q/K/V bwd: dW*3, dInput*3 via Sgemm (6 calls) + * + * Compile: nvcc -O3 --use_fast_math -arch=sm_121 -lcublas -c fused_encoder_bwd.cu + */ + +#include "fused_encoder_bwd.h" +#include +#include + +/* ------------------------------------------------------------------ */ +/* Helpers (same as fwd, repeated for compilation unit independence) */ +/* ------------------------------------------------------------------ */ + +static inline int next_pow2_bwd(int v) { + int b = 1; + while (b < v && b < 256) b <<= 1; + return b; +} + +/* Row-major C[M,N] = alpha * A[M,K] * B[K,N] + beta * C */ +static inline cublasStatus_t sgemm_nn( + cublasHandle_t h, int M, int N, int K, float alpha, + const float* A, const float* B, float beta, float* C) +{ + return cublasSgemm(h, CUBLAS_OP_N, CUBLAS_OP_N, + N, M, K, &alpha, B, N, A, K, &beta, C, N); +} + +/* Row-major C[M,N] = alpha * A^T[M,K] * B[K,N] + beta * C + * A is [K,M] (transposed to give [M,K]). */ +static inline cublasStatus_t sgemm_tn( + cublasHandle_t h, int M, int N, int K, float alpha, + const float* A, const float* B, float beta, float* C) +{ + /* Row-major A^T * B: column-major C^T = B^T * A + * cublas(CUBLAS_OP_N, CUBLAS_OP_T, N, M, K, alpha, B, N, A, M, beta, C, N) + * Wait, A is [K,M] row-major = [M,K] col-major. + * We want A^T[M,K] * B[K,N] in row-major. + * Col-major: C^T[N,M] = B_cm * (A^T)_cm + * B[K,N] row-major = B_cm[N,K] col-major. + * A^T: A is [K,M] row-major = A_cm[M,K] col-major. A^T in col-major = [K,M]. + * So (A^T)_cm[K,M] with op=N -> [K,M]. + * We need: C_cm[N,M] = ?[N,k] * ?[k,M] + * First: B_cm[N,K] with op=T -> [K,N]? No, we need [N,K]. + * B_cm[N,K] with op=N -> [N,K]. m=N, k... doesn't match. + * Let me use: cublas(transa, transb, m, n, k): + * C_cm[m,n] = op(A_cub)[m,k] * op(B_cub)[k,n] + * m=N, n=M, k=K + * op(A_cub)[N,K]: B row-major[K,N] -> col-major [N,K]. op=N. A_cub=B, lda=N. OK. + * op(B_cub)[K,M]: A row-major[K,M] -> col-major [M,K]. op=T -> [K,M]. B_cub=A, ldb=M. OK. + */ + return cublasSgemm(h, CUBLAS_OP_N, CUBLAS_OP_T, + N, M, K, &alpha, B, N, A, M, &beta, C, N); +} + +/* Row-major batched C[b,M,N] = alpha * A[b,M,K] * B[b,K,N]^T + beta * C */ +static inline cublasStatus_t sgemm_nt_batched( + cublasHandle_t h, int M, int N, int K, float alpha, + const float* A, long long sA, + const float* B, long long sB, + float beta, + float* C, long long sC, + int batch) +{ + return cublasSgemmStridedBatched(h, + CUBLAS_OP_T, CUBLAS_OP_N, + N, M, K, &alpha, + B, K, sB, + A, K, sA, + &beta, + C, N, sC, + batch); +} + +/* Row-major batched C[b,M,N] = alpha * A[b,M,K] * B[b,K,N] + beta * C */ +static inline cublasStatus_t sgemm_nn_batched( + cublasHandle_t h, int M, int N, int K, float alpha, + const float* A, long long sA, + const float* B, long long sB, + float beta, + float* C, long long sC, + int batch) +{ + return cublasSgemmStridedBatched(h, + CUBLAS_OP_N, CUBLAS_OP_N, + N, M, K, &alpha, + B, N, sB, + A, K, sA, + &beta, + C, N, sC, + batch); +} + +/* Row-major batched C[b,M,N] = alpha * A^T[b,M,K] * B[b,K,N] + beta * C + * A is [b,K,M], transposed per batch to [b,M,K]. */ +static inline cublasStatus_t sgemm_tn_batched( + cublasHandle_t h, int M, int N, int K, float alpha, + const float* A, long long sA, + const float* B, long long sB, + float beta, + float* C, long long sC, + int batch) +{ + /* Same derivation as sgemm_tn but batched. */ + return cublasSgemmStridedBatched(h, + CUBLAS_OP_N, CUBLAS_OP_T, + N, M, K, &alpha, + B, N, sB, + A, M, sA, + &beta, + C, N, sC, + batch); +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: LayerNorm backward */ +/* Given forward: y = (x - mean) * invstd * scale + bias */ +/* Computes: dScale, dBias (accumulated), dX */ +/* */ +/* Each block handles one row. */ +/* dScale[j] += sum_i (x[i,j] - mean[i]) * invstd[i] * dY[i,j] */ +/* dBias[j] += sum_i dY[i,j] */ +/* dX[i,j] = invstd * (scale * dY - mean(scale*dY) - centered * */ +/* mean(scale*dY*centered) * invstd^2) */ +/* */ +/* Simplified: uses the standard LayerNorm backward formula. */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_layernorm_bwd( + const float* __restrict__ dY, + const float* __restrict__ x, + const float* __restrict__ invstd, + const float* __restrict__ scale, + float* __restrict__ dX, + float* __restrict__ dScale, /* [D], atomicAdd */ + float* __restrict__ dBias, /* [D], atomicAdd */ + int D) +{ + int row = blockIdx.x; + const float* dy_r = dY + row * D; + const float* x_r = x + row * D; + float* dx_r = dX + row * D; + float inv = invstd[row]; + + extern __shared__ float smem[]; + + /* Compute mean of x for this row. */ + float local_sum = 0.0f; + for (int i = threadIdx.x; i < D; i += blockDim.x) { + local_sum += x_r[i]; + } + smem[threadIdx.x] = local_sum; + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s]; + __syncthreads(); + } + float mean = smem[0] / (float)D; + + /* Compute dScale and dBias contributions (atomicAdd across rows). */ + for (int i = threadIdx.x; i < D; i += blockDim.x) { + float centered = (x_r[i] - mean) * inv; + atomicAdd(&dScale[i], dy_r[i] * centered); + atomicAdd(&dBias[i], dy_r[i]); + } + + /* Compute dX using the standard LayerNorm backward formula: + * ds = sum(dy * scale * centered) + * dm = sum(dy * scale) + * dX = inv * (scale * dY - dm/D - centered * ds/D) */ + float ds_local = 0.0f; + float dm_local = 0.0f; + for (int i = threadIdx.x; i < D; i += blockDim.x) { + float centered = (x_r[i] - mean) * inv; + float scaled_dy = dy_r[i] * scale[i]; + ds_local += scaled_dy * centered; + dm_local += scaled_dy; + } + smem[threadIdx.x] = ds_local; + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s]; + __syncthreads(); + } + float ds = smem[0]; + + smem[threadIdx.x] = dm_local; + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s]; + __syncthreads(); + } + float dm = smem[0]; + + float inv_D = 1.0f / (float)D; + for (int i = threadIdx.x; i < D; i += blockDim.x) { + float centered = (x_r[i] - mean) * inv; + dx_r[i] = inv * (dy_r[i] * scale[i] - dm * inv_D - centered * ds * inv_D); + } +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: GELU backward */ +/* dX[i] = dY[i] * gelu'(pre_act[i]) */ +/* where gelu'(x) = 0.5*(1+tanh(u)) + 0.5*x*sech^2(u)*du/dx */ +/* u = sqrt(2/pi)*(x + 0.044715*x^3), du/dx = sqrt(2/pi)*(1+3*0.044715*x^2) */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_gelu_bwd( + const float* __restrict__ dY, + const float* __restrict__ pre_act, + float* __restrict__ dX, + int n) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= n) return; + + float x = pre_act[idx]; + float x2 = x * x; + float x3 = x2 * x; + float sqrt2pi = 0.7978845608f; + float c = 0.044715f; + float u = sqrt2pi * (x + c * x3); + float tanh_u = tanhf(u); + float sech2 = 1.0f - tanh_u * tanh_u; + float du_dx = sqrt2pi * (1.0f + 3.0f * c * x2); + + float gelu_deriv = 0.5f * (1.0f + tanh_u) + 0.5f * x * sech2 * du_dx; + dX[idx] = dY[idx] * gelu_deriv; +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: Softmax backward */ +/* Given forward: s = softmax(logits) */ +/* dLogits[i] = s[i] * (dS[i] - sum_j(s[j] * dS[j])) */ +/* Each block handles one row. */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_softmax_bwd( + const float* __restrict__ scores, /* softmax output from forward */ + const float* __restrict__ dScores, /* upstream gradient */ + float* __restrict__ dLogits, + int cols) +{ + int row = blockIdx.x; + const float* s_r = scores + row * cols; + const float* ds_r = dScores + row * cols; + float* dl_r = dLogits + row * cols; + + extern __shared__ float smem[]; + + /* Compute dot = sum(s * dS). */ + float local_dot = 0.0f; + for (int i = threadIdx.x; i < cols; i += blockDim.x) { + local_dot += s_r[i] * ds_r[i]; + } + smem[threadIdx.x] = local_dot; + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s]; + __syncthreads(); + } + float dot = smem[0]; + + /* dLogits = s * (dS - dot) */ + for (int i = threadIdx.x; i < cols; i += blockDim.x) { + dl_r[i] = s_r[i] * (ds_r[i] - dot); + } +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: Bias gradient reduction */ +/* dBias[j] += sum_i dY[i*cols + j] */ +/* Each block handles one column j. */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_bias_grad_reduce( + const float* __restrict__ dY, + float* __restrict__ dBias, + int rows, int cols) +{ + int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= cols) return; + + float sum = 0.0f; + for (int i = 0; i < rows; i++) { + sum += dY[i * cols + j]; + } + atomicAdd(&dBias[j], sum); +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: Element-wise addition. */ +/* out[i] = a[i] + b[i] */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_add( + const float* __restrict__ a, + const float* __restrict__ b, + float* __restrict__ out, + int n) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = a[idx] + b[idx]; + } +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: Three-way addition. */ +/* out[i] = a[i] + b[i] + c[i] */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_add3( + const float* __restrict__ a, + const float* __restrict__ b, + const float* __restrict__ c, + float* __restrict__ out, + int n) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = a[idx] + b[idx] + c[idx]; + } +} + +/* Forward-declare head transpose kernels from fused_encoder_fwd.cu. + * Since both .cu files are compiled into the same libkernels.so, + * these symbols are available at link time. + * We redeclare the device-visible __global__ functions to launch them. + * NOTE: We cannot forward-declare __global__ functions across compilation + * units, so we duplicate the simple kernels here. */ + +__global__ void kernel_head_split_bwd( + const float* __restrict__ in, + float* __restrict__ out, + int bsC, int numPatches, int nHeads, int headDim) +{ + int total = bsC * numPatches * nHeads * headDim; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + int dModel = nHeads * headDim; + int d = idx % headDim; + int h = (idx / headDim) % nHeads; + int s = (idx / dModel) % numPatches; + int b = idx / (numPatches * dModel); + int in_idx = (b * numPatches + s) * dModel + h * headDim + d; + int out_idx = ((b * nHeads + h) * numPatches + s) * headDim + d; + out[out_idx] = in[in_idx]; +} + +__global__ void kernel_head_merge_bwd( + const float* __restrict__ in, + float* __restrict__ out, + int bsC, int numPatches, int nHeads, int headDim) +{ + int total = bsC * numPatches * nHeads * headDim; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + int dModel = nHeads * headDim; + int d = idx % headDim; + int s = (idx / headDim) % numPatches; + int h = (idx / (numPatches * headDim)) % nHeads; + int b = idx / (nHeads * numPatches * headDim); + int in_idx = ((b * nHeads + h) * numPatches + s) * headDim + d; + int out_idx = (b * numPatches + s) * dModel + h * headDim + d; + out[out_idx] = in[in_idx]; +} + +/* ------------------------------------------------------------------ */ +/* Orchestrator: fused encoder backward */ +/* ------------------------------------------------------------------ */ + +extern "C" { + +cudaError_t fused_encoder_bwd_f32( + void* cublas_handle, + const void** weights, + const void** weight_t, + const void** fwd_bufs, + void** bwd_bufs, + void** grads, + const float* dOutput, + float* dInput, + const float* input, + int totalRows, int dModel, int nHeads, int headDim, int ffnDim, + int bsC, int numPatches, + cudaStream_t stream) +{ + cublasHandle_t h = (cublasHandle_t)cublas_handle; + cublasStatus_t blas_stat; + + /* Forward cache (read-only). */ + const float* normed1 = (const float*)fwd_bufs[FEB_NORMED1]; + const float* ln1Invstd = (const float*)fwd_bufs[FEB_LN1_INVSTD]; + const float* fwd_Qh = (const float*)fwd_bufs[FEB_QH]; + const float* fwd_Kh = (const float*)fwd_bufs[FEB_KH]; + const float* fwd_Vh = (const float*)fwd_bufs[FEB_VH]; + const float* attnScores = (const float*)fwd_bufs[FEB_ATTN_SCORES]; + const float* attnOut = (const float*)fwd_bufs[FEB_ATTN_OUT]; + const float* xRes1 = (const float*)fwd_bufs[FEB_X_RES1]; + const float* normed2 = (const float*)fwd_bufs[FEB_NORMED2]; + const float* ln2Invstd = (const float*)fwd_bufs[FEB_LN2_INVSTD]; + const float* ffn1Pre = (const float*)fwd_bufs[FEB_FFN1_PRE]; + const float* ffn1Out = (const float*)fwd_bufs[FEB_FFN1_OUT]; + + /* Weight transposes (read-only). */ + const float* qWT = (const float*)weight_t[FEWT_QWT]; + const float* kWT = (const float*)weight_t[FEWT_KWT]; + const float* vWT = (const float*)weight_t[FEWT_VWT]; + const float* oWT = (const float*)weight_t[FEWT_OWT]; + const float* ffn1WT = (const float*)weight_t[FEWT_FFN1WT]; + const float* ffn2WT = (const float*)weight_t[FEWT_FFN2WT]; + + /* Weights (for norm scale). */ + const float* norm1W = (const float*)weights[FEW_NORM1W]; + const float* norm2W = (const float*)weights[FEW_NORM2W]; + + /* Backward scratch buffers. */ + float* dFfn1Out = (float*)bwd_bufs[FEBB_DFFN1_OUT]; + float* dFfn1Pre = (float*)bwd_bufs[FEBB_DFFN1_PRE]; + float* dNormed2 = (float*)bwd_bufs[FEBB_DNORMED2]; + float* dXRes1 = (float*)bwd_bufs[FEBB_DX_RES1]; + float* dAttnOut = (float*)bwd_bufs[FEBB_DATTN_OUT]; + float* dAttnOutH = (float*)bwd_bufs[FEBB_DATTN_OUT_H]; + float* dQh = (float*)bwd_bufs[FEBB_DQH]; + float* dKh = (float*)bwd_bufs[FEBB_DKH]; + float* dVh = (float*)bwd_bufs[FEBB_DVH]; + float* dScoresBuf = (float*)bwd_bufs[FEBB_DSCORES]; + float* dQ = (float*)bwd_bufs[FEBB_DQ]; + float* dK = (float*)bwd_bufs[FEBB_DK]; + float* dV = (float*)bwd_bufs[FEBB_DV]; + float* dNormed1 = (float*)bwd_bufs[FEBB_DNORMED1]; + float* temp = (float*)bwd_bufs[FEBB_TEMP]; + + /* Gradient accumulators (accumulated, not zeroed). */ + float* dg_qW = (float*)grads[FEG_DQW]; + float* dg_qB = (float*)grads[FEG_DQB]; + float* dg_kW = (float*)grads[FEG_DKW]; + float* dg_kB = (float*)grads[FEG_DKB]; + float* dg_vW = (float*)grads[FEG_DVW]; + float* dg_vB = (float*)grads[FEG_DVB]; + float* dg_oW = (float*)grads[FEG_DOW]; + float* dg_oB = (float*)grads[FEG_DOB]; + float* dg_ffn1W = (float*)grads[FEG_DFFN1W]; + float* dg_ffn1B = (float*)grads[FEG_DFFN1B]; + float* dg_ffn2W = (float*)grads[FEG_DFFN2W]; + float* dg_ffn2B = (float*)grads[FEG_DFFN2B]; + float* dg_norm1W = (float*)grads[FEG_DNORM1W]; + float* dg_norm1B = (float*)grads[FEG_DNORM1B]; + float* dg_norm2W = (float*)grads[FEG_DNORM2W]; + float* dg_norm2B = (float*)grads[FEG_DNORM2B]; + + int bnh = bsC * nHeads; + int trDm = totalRows * dModel; + int trFf = totalRows * ffnDim; + int block256 = 256; + int elemGridTrDm = (trDm + block256 - 1) / block256; + int elemGridTrFf = (trFf + block256 - 1) / block256; + int totalElems = bsC * numPatches * nHeads * headDim; + int elemGridTotal = (totalElems + block256 - 1) / block256; + + int lnBlock = next_pow2_bwd(dModel); + size_t lnSmem = lnBlock * sizeof(float); + + long long strideQK = (long long)numPatches * headDim; + long long strideScores = (long long)numPatches * numPatches; + float attnScale = 1.0f / sqrtf((float)headDim); + + /* ============================================================ */ + /* The backward proceeds in reverse layer order. */ + /* dOutput is the gradient from the next layer (or loss). */ + /* ============================================================ */ + + /* ------------------------------------------------------------ */ + /* Step 1: Residual 2 backward */ + /* output = proj + ffn2B + xRes1 */ + /* dOutput flows to both FFN2 proj path AND xRes1 (skip) */ + /* ------------------------------------------------------------ */ + /* dOutput is used directly for FFN2 backward and also added to */ + /* the residual path later (Step 5). */ + + /* ------------------------------------------------------------ */ + /* Step 2: FFN2 backward */ + /* proj = ffn1Out @ ffn2W */ + /* dW_ffn2 += ffn1Out^T @ dOutput [ffnDim, dModel] */ + /* dB_ffn2 += sum(dOutput, axis=0) [dModel] */ + /* dFfn1Out = dOutput @ ffn2W^T [totalRows, ffnDim] */ + /* ------------------------------------------------------------ */ + + /* dW_ffn2 += ffn1Out^T @ dOutput */ + blas_stat = sgemm_tn(h, ffnDim, dModel, totalRows, 1.0f, + ffn1Out, dOutput, 1.0f, dg_ffn2W); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* dB_ffn2 += sum(dOutput, axis=0) */ + int biasGrid = (dModel + block256 - 1) / block256; + kernel_bias_grad_reduce<<>>( + dOutput, dg_ffn2B, totalRows, dModel); + + /* dFfn1Out = dOutput @ ffn2W^T */ + blas_stat = sgemm_nn(h, totalRows, ffnDim, dModel, 1.0f, + dOutput, ffn2WT, 0.0f, dFfn1Out); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* ------------------------------------------------------------ */ + /* Step 3: GELU backward */ + /* dFfn1Pre = dFfn1Out * gelu'(ffn1Pre) */ + /* ------------------------------------------------------------ */ + kernel_gelu_bwd<<>>( + dFfn1Out, ffn1Pre, dFfn1Pre, trFf); + + /* ------------------------------------------------------------ */ + /* Step 4: FFN1 backward */ + /* proj = normed2 @ ffn1W */ + /* dW_ffn1 += normed2^T @ dFfn1Pre [dModel, ffnDim] */ + /* dB_ffn1 += sum(dFfn1Pre, axis=0) [ffnDim] */ + /* dNormed2 = dFfn1Pre @ ffn1W^T [totalRows, dModel] */ + /* ------------------------------------------------------------ */ + blas_stat = sgemm_tn(h, dModel, ffnDim, totalRows, 1.0f, + normed2, dFfn1Pre, 1.0f, dg_ffn1W); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + int ffnBiasGrid = (ffnDim + block256 - 1) / block256; + kernel_bias_grad_reduce<<>>( + dFfn1Pre, dg_ffn1B, totalRows, ffnDim); + + blas_stat = sgemm_nn(h, totalRows, dModel, ffnDim, 1.0f, + dFfn1Pre, ffn1WT, 0.0f, dNormed2); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* ------------------------------------------------------------ */ + /* Step 5: LayerNorm2 backward + residual skip */ + /* dXRes1_ln = layernorm_bwd(dNormed2, xRes1, ln2Invstd, norm2W) */ + /* dXRes1 = dXRes1_ln + dOutput (residual skip from step 1) */ + /* ------------------------------------------------------------ */ + kernel_layernorm_bwd<<>>( + dNormed2, xRes1, ln2Invstd, norm2W, + dXRes1, dg_norm2W, dg_norm2B, dModel); + + /* Add residual skip: dXRes1 += dOutput */ + kernel_add<<>>( + dXRes1, dOutput, dXRes1, trDm); + + /* ------------------------------------------------------------ */ + /* Step 6: Output projection backward */ + /* proj = attnOut @ oW */ + /* dW_o += attnOut^T @ dXRes1 [dModel, dModel] */ + /* dB_o += sum(dXRes1, axis=0) [dModel] */ + /* dAttnOut = dXRes1 @ oW^T [totalRows, dModel] */ + /* ------------------------------------------------------------ */ + blas_stat = sgemm_tn(h, dModel, dModel, totalRows, 1.0f, + attnOut, dXRes1, 1.0f, dg_oW); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + kernel_bias_grad_reduce<<>>( + dXRes1, dg_oB, totalRows, dModel); + + blas_stat = sgemm_nn(h, totalRows, dModel, dModel, 1.0f, + dXRes1, oWT, 0.0f, dAttnOut); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* ------------------------------------------------------------ */ + /* Step 7: Head split dAttnOut for multi-head backward */ + /* [totalRows, dModel] -> [bnh, numPatches, headDim] */ + /* ------------------------------------------------------------ */ + kernel_head_split_bwd<<>>( + dAttnOut, dAttnOutH, bsC, numPatches, nHeads, headDim); + + /* ------------------------------------------------------------ */ + /* Step 8: Attention backward */ + /* Forward: scores = softmax(Qh @ Kh^T / sqrt(d)) */ + /* attnOutH = scores @ Vh */ + /* */ + /* 8a: dVh = scores^T @ dAttnOutH [bnh, numPatches, headDim] */ + /* 8b: dScoresRaw = dAttnOutH @ Vh^T [bnh, numPatches, numPatches] */ + /* 8c: dLogits = softmax_bwd(scores, dScoresRaw) * scale */ + /* 8d: dQh = dLogits @ Kh [bnh, numPatches, headDim] */ + /* 8e: dKh = dLogits^T @ Qh [bnh, numPatches, headDim] */ + /* ------------------------------------------------------------ */ + + /* 8a: dVh = scores^T @ dAttnOutH */ + blas_stat = sgemm_tn_batched(h, + numPatches, headDim, numPatches, 1.0f, + attnScores, strideScores, + dAttnOutH, strideQK, + 0.0f, + dVh, strideQK, + bnh); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* 8b: dScoresRaw = dAttnOutH @ Vh^T */ + blas_stat = sgemm_nt_batched(h, + numPatches, numPatches, headDim, 1.0f, + dAttnOutH, strideQK, + fwd_Vh, strideQK, + 0.0f, + dScoresBuf, strideScores, + bnh); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* 8c: Softmax backward + scale */ + int sfBlock = next_pow2_bwd(numPatches); + size_t sfSmem = sfBlock * sizeof(float); + kernel_softmax_bwd<<>>( + attnScores, dScoresBuf, dScoresBuf, numPatches); + + /* Apply attention scale to dLogits. The forward scaled Q@K^T by + * 1/sqrt(d) before softmax, so the backward gradient through the + * scale is just multiply by 1/sqrt(d). */ + { + int scoreElems = bnh * numPatches * numPatches; + int scoreGrid = (scoreElems + block256 - 1) / block256; + /* Inline scale multiply kernel (reuse temp if needed). */ + /* We scale dScoresBuf in-place using a simple kernel. */ + /* For simplicity, fold scale into the cuBLAS calls below by + * using alpha=attnScale. Actually, the scale was already applied + * in the forward (scores = softmax(Q@K^T * scale)). The softmax + * backward gives dLogits = softmax_bwd(dScoresRaw). The chain + * rule through the scale multiply gives dScaled = dLogits * scale. + * We apply this by setting alpha=attnScale in the dQ/dK GEMMs. */ + (void)scoreGrid; + } + + /* 8d: dQh = attnScale * dLogits @ Kh */ + blas_stat = sgemm_nn_batched(h, + numPatches, headDim, numPatches, attnScale, + dScoresBuf, strideScores, + fwd_Kh, strideQK, + 0.0f, + dQh, strideQK, + bnh); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* 8e: dKh = attnScale * dLogits^T @ Qh */ + blas_stat = sgemm_tn_batched(h, + numPatches, headDim, numPatches, attnScale, + dScoresBuf, strideScores, + fwd_Qh, strideQK, + 0.0f, + dKh, strideQK, + bnh); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* ------------------------------------------------------------ */ + /* Step 9: Head merge dQh/dKh/dVh back to flat */ + /* [bnh, numPatches, headDim] -> [totalRows, dModel] */ + /* ------------------------------------------------------------ */ + kernel_head_merge_bwd<<>>(dQh, dQ, bsC, numPatches, nHeads, headDim); + kernel_head_merge_bwd<<>>(dKh, dK, bsC, numPatches, nHeads, headDim); + kernel_head_merge_bwd<<>>(dVh, dV, bsC, numPatches, nHeads, headDim); + + /* ------------------------------------------------------------ */ + /* Step 10: Q/K/V projection backward */ + /* Q = normed1 @ qW + qB */ + /* dW_q += normed1^T @ dQ [dModel, dModel] */ + /* dB_q += sum(dQ, axis=0) [dModel] */ + /* dNormed1_q = dQ @ qW^T [totalRows, dModel] */ + /* (same for K, V; sum the three dNormed1 contributions) */ + /* ------------------------------------------------------------ */ + + /* Q backward */ + blas_stat = sgemm_tn(h, dModel, dModel, totalRows, 1.0f, + normed1, dQ, 1.0f, dg_qW); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + kernel_bias_grad_reduce<<>>( + dQ, dg_qB, totalRows, dModel); + /* dNormed1 = dQ @ qW^T (first contribution) */ + blas_stat = sgemm_nn(h, totalRows, dModel, dModel, 1.0f, + dQ, qWT, 0.0f, dNormed1); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* K backward */ + blas_stat = sgemm_tn(h, dModel, dModel, totalRows, 1.0f, + normed1, dK, 1.0f, dg_kW); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + kernel_bias_grad_reduce<<>>( + dK, dg_kB, totalRows, dModel); + /* dNormed1 += dK @ kW^T (accumulate) */ + blas_stat = sgemm_nn(h, totalRows, dModel, dModel, 1.0f, + dK, kWT, 1.0f, dNormed1); /* beta=1 to accumulate */ + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* V backward */ + blas_stat = sgemm_tn(h, dModel, dModel, totalRows, 1.0f, + normed1, dV, 1.0f, dg_vW); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + kernel_bias_grad_reduce<<>>( + dV, dg_vB, totalRows, dModel); + /* dNormed1 += dV @ vW^T (accumulate) */ + blas_stat = sgemm_nn(h, totalRows, dModel, dModel, 1.0f, + dV, vWT, 1.0f, dNormed1); /* beta=1 to accumulate */ + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* ------------------------------------------------------------ */ + /* Step 11: LayerNorm1 backward + residual skip */ + /* dLN1Input = layernorm_bwd(dNormed1, input, ln1Invstd, norm1W) */ + /* dInput = dLN1Input + dXRes1 (residual skip from step 5) */ + /* ------------------------------------------------------------ */ + /* Use temp for LN1 backward output, then add residual. */ + kernel_layernorm_bwd<<>>( + dNormed1, input, ln1Invstd, norm1W, + temp, dg_norm1W, dg_norm1B, dModel); + + /* dInput = temp + dXRes1 */ + kernel_add<<>>( + temp, dXRes1, dInput, trDm); + + return cudaGetLastError(); +} + +} /* extern "C" */ diff --git a/internal/cuda/kernels/fused_encoder_bwd.go b/internal/cuda/kernels/fused_encoder_bwd.go new file mode 100644 index 0000000..ab1f99c --- /dev/null +++ b/internal/cuda/kernels/fused_encoder_bwd.go @@ -0,0 +1,46 @@ +//go:build cuda + +package kernels + +/* +#cgo LDFLAGS: -L${SRCDIR} -lkernels -lcudart -lcublas -lstdc++ +#include "fused_encoder_bwd.h" +*/ +import "C" + +import ( + "fmt" + "unsafe" +) + +// FusedEncoderBwdF32 computes all gradients for one encoder layer backward pass (CGo path). +func FusedEncoderBwdF32( + cublasHandle unsafe.Pointer, + weights *[FEW_COUNT]unsafe.Pointer, + weightT *[FEWT_COUNT]unsafe.Pointer, + fwdBufs *[FEB_COUNT]unsafe.Pointer, + bwdBufs *[FEBB_COUNT]unsafe.Pointer, + grads *[FEG_COUNT]unsafe.Pointer, + dOutput, dInput, input unsafe.Pointer, + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, + stream unsafe.Pointer, +) error { + err := C.fused_encoder_bwd_f32( + cublasHandle, + (*unsafe.Pointer)(unsafe.Pointer(weights)), + (*unsafe.Pointer)(unsafe.Pointer(weightT)), + (*unsafe.Pointer)(unsafe.Pointer(fwdBufs)), + (*unsafe.Pointer)(unsafe.Pointer(bwdBufs)), + (*unsafe.Pointer)(unsafe.Pointer(grads)), + (*C.float)(dOutput), + (*C.float)(dInput), + (*C.float)(input), + C.int(totalRows), C.int(dModel), C.int(nHeads), C.int(headDim), + C.int(ffnDim), C.int(bsC), C.int(numPatches), + C.cudaStream_t(stream), + ) + if err != 0 { + return fmt.Errorf("fused_encoder_bwd_f32 failed with cuda error %d", int(err)) + } + return nil +} diff --git a/internal/cuda/kernels/fused_encoder_bwd.h b/internal/cuda/kernels/fused_encoder_bwd.h new file mode 100644 index 0000000..7012178 --- /dev/null +++ b/internal/cuda/kernels/fused_encoder_bwd.h @@ -0,0 +1,95 @@ +/* fused_encoder_bwd.h -- Fused PatchTST encoder layer backward pass. + * + * Companion to fused_encoder_fwd.h. Computes gradients for all weights + * and the input gradient in a single host-side orchestrator call. + * Reads from the forward cache (FEB_* buffers) to avoid recomputation. + */ + +#ifndef FUSED_ENCODER_BWD_H +#define FUSED_ENCODER_BWD_H + +#include +#include "fused_encoder_fwd.h" /* FEW_*, FEB_* enums */ + +#ifdef __cplusplus +extern "C" { +#endif + +/* Weight gradient output indices (16 gradient tensors, same order as FEW_*). */ +enum FusedEncoderGrad { + FEG_DQW = 0, FEG_DQB, + FEG_DKW, FEG_DKB, + FEG_DVW, FEG_DVB, + FEG_DOW, FEG_DOB, + FEG_DFFN1W, FEG_DFFN1B, + FEG_DFFN2W, FEG_DFFN2B, + FEG_DNORM1W, FEG_DNORM1B, + FEG_DNORM2W, FEG_DNORM2B, + FEG_COUNT /* = 16 */ +}; + +/* Weight transpose indices (pre-computed by Go for backward efficiency). */ +enum FusedEncoderWeightT { + FEWT_QWT = 0, /* qW^T [dModel, dModel] */ + FEWT_KWT, /* kW^T [dModel, dModel] */ + FEWT_VWT, /* vW^T [dModel, dModel] */ + FEWT_OWT, /* oW^T [dModel, dModel] */ + FEWT_FFN1WT, /* ffn1W^T [ffnDim, dModel] */ + FEWT_FFN2WT, /* ffn2W^T [dModel, ffnDim] */ + FEWT_COUNT /* = 6 */ +}; + +/* Backward scratch buffer indices. */ +enum FusedEncoderBwdBuf { + FEBB_DFFN1_OUT = 0, /* [totalRows, ffnDim] */ + FEBB_DFFN1_PRE, /* [totalRows, ffnDim] */ + FEBB_DNORMED2, /* [totalRows, dModel] */ + FEBB_DX_RES1, /* [totalRows, dModel] */ + FEBB_DATTN_OUT, /* [totalRows, dModel] */ + FEBB_DATTN_OUT_H, /* [bnh, numPatches, headDim] */ + FEBB_DQH, /* [bnh, numPatches, headDim] */ + FEBB_DKH, /* [bnh, numPatches, headDim] */ + FEBB_DVH, /* [bnh, numPatches, headDim] */ + FEBB_DSCORES, /* [bnh, numPatches, numPatches] */ + FEBB_DQ, /* [totalRows, dModel] */ + FEBB_DK, /* [totalRows, dModel] */ + FEBB_DV, /* [totalRows, dModel] */ + FEBB_DNORMED1, /* [totalRows, dModel] */ + FEBB_TEMP, /* [max(totalRows*dModel, totalRows*ffnDim)] scratch */ + FEBB_COUNT /* = 15 */ +}; + +/* fused_encoder_bwd_f32 -- compute gradients for one encoder layer. + * + * Parameters: + * cublas_handle cuBLAS handle with stream already set + * weights FEW_COUNT pointers to layer weights (read-only) + * weight_t FEWT_COUNT pointers to pre-transposed weights (read-only) + * fwd_bufs FEB_COUNT pointers to forward cache (read-only) + * bwd_bufs FEBB_COUNT pointers to backward scratch (read-write) + * grads FEG_COUNT pointers to gradient accumulators (ACCUMULATED, not zeroed) + * dOutput [totalRows, dModel] upstream gradient (read-only) + * dInput [totalRows, dModel] input gradient output + * input [totalRows, dModel] original layer input (for LN1 backward) + * totalRows..numPatches dimension parameters + * stream CUDA stream + */ +cudaError_t fused_encoder_bwd_f32( + void* cublas_handle, + const void** weights, + const void** weight_t, + const void** fwd_bufs, + void** bwd_bufs, + void** grads, + const float* dOutput, + float* dInput, + const float* input, + int totalRows, int dModel, int nHeads, int headDim, int ffnDim, + int bsC, int numPatches, + cudaStream_t stream); + +#ifdef __cplusplus +} +#endif + +#endif /* FUSED_ENCODER_BWD_H */ diff --git a/internal/cuda/kernels/fused_encoder_bwd_purego.go b/internal/cuda/kernels/fused_encoder_bwd_purego.go new file mode 100644 index 0000000..4f14b2a --- /dev/null +++ b/internal/cuda/kernels/fused_encoder_bwd_purego.go @@ -0,0 +1,94 @@ +//go:build !cuda + +package kernels + +import ( + "fmt" + "unsafe" + + "github.com/zerfoo/ztensor/internal/cuda" +) + +// Backward buffer index constants matching C enum FusedEncoderGrad. +const ( + FEG_DQW = 0 + FEG_DQB = 1 + FEG_DKW = 2 + FEG_DKB = 3 + FEG_DVW = 4 + FEG_DVB = 5 + FEG_DOW = 6 + FEG_DOB = 7 + FEG_DFFN1W = 8 + FEG_DFFN1B = 9 + FEG_DFFN2W = 10 + FEG_DFFN2B = 11 + FEG_DNORM1W = 12 + FEG_DNORM1B = 13 + FEG_DNORM2W = 14 + FEG_DNORM2B = 15 + FEG_COUNT = 16 +) + +// Weight transpose indices matching C enum FusedEncoderWeightT. +const ( + FEWT_QWT = 0 + FEWT_KWT = 1 + FEWT_VWT = 2 + FEWT_OWT = 3 + FEWT_FFN1WT = 4 + FEWT_FFN2WT = 5 + FEWT_COUNT = 6 +) + +// Backward scratch buffer indices matching C enum FusedEncoderBwdBuf. +const ( + FEBB_DFFN1_OUT = 0 + FEBB_DFFN1_PRE = 1 + FEBB_DNORMED2 = 2 + FEBB_DX_RES1 = 3 + FEBB_DATTN_OUT = 4 + FEBB_DATTN_OUT_H = 5 + FEBB_DQH = 6 + FEBB_DKH = 7 + FEBB_DVH = 8 + FEBB_DSCORES = 9 + FEBB_DQ = 10 + FEBB_DK = 11 + FEBB_DV = 12 + FEBB_DNORMED1 = 13 + FEBB_TEMP = 14 + FEBB_COUNT = 15 +) + +// FusedEncoderBwdF32 computes all gradients for one encoder layer backward pass. +func FusedEncoderBwdF32( + cublasHandle unsafe.Pointer, + weights *[FEW_COUNT]unsafe.Pointer, + weightT *[FEWT_COUNT]unsafe.Pointer, + fwdBufs *[FEB_COUNT]unsafe.Pointer, + bwdBufs *[FEBB_COUNT]unsafe.Pointer, + grads *[FEG_COUNT]unsafe.Pointer, + dOutput, dInput, input unsafe.Pointer, + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, + stream unsafe.Pointer, +) error { + k := klib() + if k == nil || k.launchFusedEncoderBwdF32 == 0 { + return fmt.Errorf("fused_encoder_bwd_f32 kernel: not available") + } + ret := cuda.Ccall(k.launchFusedEncoderBwdF32, + uintptr(cublasHandle), + uintptr(unsafe.Pointer(weights)), + uintptr(unsafe.Pointer(weightT)), + uintptr(unsafe.Pointer(fwdBufs)), + uintptr(unsafe.Pointer(bwdBufs)), + uintptr(unsafe.Pointer(grads)), + uintptr(dOutput), + uintptr(dInput), + uintptr(input), + uintptr(totalRows), uintptr(dModel), uintptr(nHeads), uintptr(headDim), + uintptr(ffnDim), uintptr(bsC), uintptr(numPatches), + uintptr(stream)) + return checkKernel(ret, "fused_encoder_bwd_f32") +} diff --git a/internal/cuda/kernels/fused_encoder_fwd.cu b/internal/cuda/kernels/fused_encoder_fwd.cu new file mode 100644 index 0000000..73a6c3a --- /dev/null +++ b/internal/cuda/kernels/fused_encoder_fwd.cu @@ -0,0 +1,578 @@ +/* fused_encoder_fwd.cu -- Fused PatchTST encoder layer forward pass. + * + * Host-side orchestrator that replaces ~78 discrete Engine[T] operations + * per encoder layer with a single C function call. Internally launches + * cuBLAS GEMMs for matrix multiplications and custom CUDA sub-kernels + * for LayerNorm, head transpose, GELU, softmax, and residual operations. + * + * Sub-kernel inventory: + * kernel_layernorm_fwd Standard LayerNorm (mean + variance) + * kernel_bias_add Broadcast bias addition along rows + * kernel_head_split [B*S, H*D] -> [B*H, S, D] transpose + * kernel_head_merge [B*H, S, D] -> [B*S, H*D] transpose + * kernel_scaled_softmax_fwd Row-wise scaled softmax + * kernel_bias_gelu_fwd Fused bias add + GELU activation + * kernel_bias_residual_add Fused bias add + residual connection + * + * cuBLAS calls (7 total per layer): + * 3x Sgemm for Q/K/V projections + * 1x SgemmStridedBatched for attention scores (Q @ K^T) + * 1x SgemmStridedBatched for attention output (scores @ V) + * 1x Sgemm for output projection + * 1x Sgemm for FFN1 projection + * 1x Sgemm for FFN2 projection + * + * Compile: nvcc -O3 --use_fast_math -arch=sm_121 -lcublas -c fused_encoder_fwd.cu + */ + +#include "fused_encoder_fwd.h" +#include +#include +#include /* memcpy for bits_to_float */ + +/* ------------------------------------------------------------------ */ +/* Helpers */ +/* ------------------------------------------------------------------ */ + +/* Minimum block size for reduction kernels. */ +static inline int next_pow2(int v) { + int b = 1; + while (b < v && b < 256) b <<= 1; + return b; +} + +/* Row-major C[M,N] = alpha * A[M,K] * B[K,N] + beta * C[M,N] via cuBLAS. + * cuBLAS is column-major; the standard trick swaps A/B and m/n. */ +static inline cublasStatus_t sgemm_nn( + cublasHandle_t h, int M, int N, int K, float alpha, + const float* A, const float* B, float beta, float* C) +{ + return cublasSgemm(h, + CUBLAS_OP_N, CUBLAS_OP_N, + N, M, K, + &alpha, + B, N, /* cuBLAS "A" = B_rm, lda = N */ + A, K, /* cuBLAS "B" = A_rm, ldb = K */ + &beta, + C, N); /* cuBLAS "C", ldc = N */ +} + +/* Row-major batched C[b,M,N] = alpha * A[b,M,K] * B[b,K,N]^T + beta * C[b,M,N]. + * B is [b,N,K] (each batch element is [N,K], transposed to give [K,N]). + * Used for attention: scores = Q @ K^T. */ +static inline cublasStatus_t sgemm_nt_batched( + cublasHandle_t h, int M, int N, int K, float alpha, + const float* A, long long sA, + const float* B, long long sB, + float beta, + float* C, long long sC, + int batch) +{ + /* Row-major C = A * B^T. Column-major: C^T = B * A^T. + * cublas(transa=T, transb=N, m=N, n=M, k=K, + * A_cublas=A_rm[M,K]->cm[K,M] with transa=T -> [M,K], NO wrong. + * + * Easier derivation: + * Row-major X[p,q] in memory = col-major X'[q,p]. + * C_rm[M,N] = A_rm[M,K] * B_rm[N,K]^T + * In col-major: C'[N,M] = (A * B^T)' = B * A' + * B_rm[N,K] -> col-major B'[K,N]. We need "B" in cuBLAS sense -> B'[K,N] with op=N -> [K,N]. + * BUT we need first dim = m=N. So op=T -> [N,K]. Hmm. + * + * Let's just do it step by step: + * cublasSgemm(h, transa, transb, m, n, k, alpha, Acub, lda, Bcub, ldb, beta, Ccub, ldc) + * Ccub[m,n] = op(Acub)[m,k] * op(Bcub)[k,n] + * + * We want C'[N,M] = B_rm_as_cm * A_rm_as_cm + * = B'[K,N]^T * A'[K,M] with B' transposed = [N,K] + * No, C'[N,M] = ?[N,k] * ?[k,M]. k=K. + * First factor [N,K]: take B'[K,N] with CUBLAS_OP_T -> [N,K]. Acub=B_rm, lda=K. transa=T. m=N. + * Second factor [K,M]: take A'[K,M] with CUBLAS_OP_N -> [K,M]. Bcub=A_rm, ldb=K. transb=N. n=M. + * Ccub[N,M] at C_rm, ldc=N. + */ + return cublasSgemmStridedBatched(h, + CUBLAS_OP_T, CUBLAS_OP_N, + N, M, K, + &alpha, + B, K, sB, /* Acub = B_rm, lda=K, transa=T -> [N,K] */ + A, K, sA, /* Bcub = A_rm, ldb=K, transb=N -> [K,M] */ + &beta, + C, N, sC, + batch); +} + +/* Row-major batched C[b,M,N] = alpha * A[b,M,K] * B[b,K,N] + beta * C[b,M,N]. + * Standard NN batched multiply. */ +static inline cublasStatus_t sgemm_nn_batched( + cublasHandle_t h, int M, int N, int K, float alpha, + const float* A, long long sA, + const float* B, long long sB, + float beta, + float* C, long long sC, + int batch) +{ + return cublasSgemmStridedBatched(h, + CUBLAS_OP_N, CUBLAS_OP_N, + N, M, K, + &alpha, + B, N, sB, + A, K, sA, + &beta, + C, N, sC, + batch); +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: Standard LayerNorm forward */ +/* Each block processes one row of length D. */ +/* out[i] = (x[i] - mean) / sqrt(var + eps) * scale[i] + bias[i] */ +/* Also writes invstd_out for backward use. */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_layernorm_fwd( + const float* __restrict__ x, + const float* __restrict__ scale, + const float* __restrict__ bias, + float* __restrict__ out, + float* __restrict__ invstd_out, + int D) +{ + int row = blockIdx.x; + const float* xr = x + row * D; + float* outr = out + row * D; + + extern __shared__ float smem[]; + + /* Phase 1: compute mean. */ + float local_sum = 0.0f; + for (int i = threadIdx.x; i < D; i += blockDim.x) { + local_sum += xr[i]; + } + smem[threadIdx.x] = local_sum; + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s]; + __syncthreads(); + } + float mean = smem[0] / (float)D; + + /* Phase 2: compute variance. */ + float local_var = 0.0f; + for (int i = threadIdx.x; i < D; i += blockDim.x) { + float c = xr[i] - mean; + local_var += c * c; + } + smem[threadIdx.x] = local_var; + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s]; + __syncthreads(); + } + float invstd = rsqrtf(smem[0] / (float)D + 1e-5f); + + /* Store invstd for backward. */ + if (threadIdx.x == 0 && invstd_out != NULL) { + invstd_out[row] = invstd; + } + + /* Phase 3: normalize, scale, and bias. */ + for (int i = threadIdx.x; i < D; i += blockDim.x) { + outr[i] = (xr[i] - mean) * invstd * scale[i] + bias[i]; + } +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: Broadcast bias addition. */ +/* out[i*cols + j] = x[i*cols + j] + bias[j] */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_bias_add( + const float* __restrict__ x, + const float* __restrict__ bias, + float* __restrict__ out, + int n, int cols) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = x[idx] + bias[idx % cols]; + } +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: Head split transpose. */ +/* [bsC*numPatches, nHeads*headDim] -> [bsC*nHeads, numPatches, headDim] */ +/* */ +/* Input layout: in[b*S + s][h*D + d] (b=batch, s=seq, h=head, d=dim) */ +/* Output layout: out[(b*H + h)*S + s][d] */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_head_split( + const float* __restrict__ in, + float* __restrict__ out, + int bsC, int numPatches, int nHeads, int headDim) +{ + int total = bsC * numPatches * nHeads * headDim; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + + int dModel = nHeads * headDim; + /* Decode flat index -> (b, s, h, d) in input layout. */ + int d = idx % headDim; + int h = (idx / headDim) % nHeads; + int s = (idx / dModel) % numPatches; + int b = idx / (numPatches * dModel); + + /* Input: in[(b * numPatches + s) * dModel + h * headDim + d] */ + int in_idx = (b * numPatches + s) * dModel + h * headDim + d; + /* Output: out[((b * nHeads + h) * numPatches + s) * headDim + d] */ + int out_idx = ((b * nHeads + h) * numPatches + s) * headDim + d; + + out[out_idx] = in[in_idx]; +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: Head merge transpose (reverse of head_split). */ +/* [bsC*nHeads, numPatches, headDim] -> [bsC*numPatches, nHeads*headDim] */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_head_merge( + const float* __restrict__ in, + float* __restrict__ out, + int bsC, int numPatches, int nHeads, int headDim) +{ + int total = bsC * numPatches * nHeads * headDim; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + + int dModel = nHeads * headDim; + /* Decode flat index -> (b, h, s, d) in input layout. */ + int d = idx % headDim; + int s = (idx / headDim) % numPatches; + int h = (idx / (numPatches * headDim)) % nHeads; + int b = idx / (nHeads * numPatches * headDim); + + /* Input: in[((b * nHeads + h) * numPatches + s) * headDim + d] */ + int in_idx = ((b * nHeads + h) * numPatches + s) * headDim + d; + /* Output: out[(b * numPatches + s) * dModel + h * headDim + d] */ + int out_idx = (b * numPatches + s) * dModel + h * headDim + d; + + out[out_idx] = in[in_idx]; +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: Row-wise softmax (in-place). */ +/* Each block handles one row of length cols. */ +/* out[i] = exp(in[i] - max) / sum(exp(in[j] - max)) */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_softmax_fwd( + float* __restrict__ data, + int cols) +{ + int row = blockIdx.x; + float* r = data + row * cols; + + extern __shared__ float smem[]; + + /* Find row max. */ + float local_max = -1e30f; + for (int i = threadIdx.x; i < cols; i += blockDim.x) { + float v = r[i]; + if (v > local_max) local_max = v; + } + smem[threadIdx.x] = local_max; + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]); + } + __syncthreads(); + } + float row_max = smem[0]; + + /* Compute exp and sum. */ + float local_sum = 0.0f; + for (int i = threadIdx.x; i < cols; i += blockDim.x) { + float e = expf(r[i] - row_max); + r[i] = e; + local_sum += e; + } + smem[threadIdx.x] = local_sum; + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s]; + __syncthreads(); + } + float inv_sum = 1.0f / smem[0]; + + /* Normalize. */ + for (int i = threadIdx.x; i < cols; i += blockDim.x) { + r[i] *= inv_sum; + } +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: Fused bias add + GELU activation. */ +/* pre_act[i] = x[i] + bias[i % cols] */ +/* out[i] = 0.5 * pre_act * (1 + tanh(sqrt(2/pi) * (pre_act + 0.044715 * pre_act^3))) */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_bias_gelu_fwd( + const float* __restrict__ x, + const float* __restrict__ bias, + float* __restrict__ pre_act_out, + float* __restrict__ out, + int n, int cols) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= n) return; + + float v = x[idx] + bias[idx % cols]; + if (pre_act_out != NULL) pre_act_out[idx] = v; + + /* GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) */ + float v3 = v * v * v; + float inner = 0.7978845608f * (v + 0.044715f * v3); /* sqrt(2/pi) ~ 0.7978845608 */ + out[idx] = 0.5f * v * (1.0f + tanhf(inner)); +} + +/* ------------------------------------------------------------------ */ +/* Sub-kernel: Fused bias add + residual connection. */ +/* out[i] = proj[i] + bias[i % cols] + residual[i] */ +/* ------------------------------------------------------------------ */ + +__global__ void kernel_bias_residual_add( + const float* __restrict__ proj, + const float* __restrict__ bias, + const float* __restrict__ residual, + float* __restrict__ out, + int n, int cols) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= n) { + return; + } + out[idx] = proj[idx] + bias[idx % cols] + residual[idx]; +} + +/* ------------------------------------------------------------------ */ +/* Orchestrator: fused encoder forward */ +/* ------------------------------------------------------------------ */ + +extern "C" { + +long long fused_encoder_fwd_scratch_bytes( + int totalRows, int dModel, int nHeads, int headDim, int ffnDim, + int bsC, int numPatches) +{ + long long bnh = (long long)bsC * nHeads; + long long bytes = 0; + /* FEB_NORMED1 */ bytes += (long long)totalRows * dModel * 4; + /* FEB_LN1_INVSTD */ bytes += (long long)totalRows * 4; + /* FEB_Q */ bytes += (long long)totalRows * dModel * 4; + /* FEB_K */ bytes += (long long)totalRows * dModel * 4; + /* FEB_V */ bytes += (long long)totalRows * dModel * 4; + /* FEB_QH */ bytes += bnh * numPatches * headDim * 4; + /* FEB_KH */ bytes += bnh * numPatches * headDim * 4; + /* FEB_VH */ bytes += bnh * numPatches * headDim * 4; + /* FEB_ATTN_SCORES */ bytes += bnh * numPatches * numPatches * 4; + /* FEB_ATTN_OUT_H */ bytes += bnh * numPatches * headDim * 4; + /* FEB_ATTN_OUT */ bytes += (long long)totalRows * dModel * 4; + /* FEB_X_RES1 */ bytes += (long long)totalRows * dModel * 4; + /* FEB_NORMED2 */ bytes += (long long)totalRows * dModel * 4; + /* FEB_LN2_INVSTD */ bytes += (long long)totalRows * 4; + /* FEB_FFN1_PRE */ bytes += (long long)totalRows * ffnDim * 4; + /* FEB_FFN1_OUT */ bytes += (long long)totalRows * ffnDim * 4; + return bytes; +} + +cudaError_t fused_encoder_fwd_f32( + void* cublas_handle, + const void** weights, + void** bufs, + const float* input, + float* output, + int totalRows, int dModel, int nHeads, int headDim, int ffnDim, + int bsC, int numPatches, + cudaStream_t stream) +{ + cublasHandle_t h = (cublasHandle_t)cublas_handle; + cublasStatus_t blas_stat; + + /* Extract weight pointers. */ + const float* qW = (const float*)weights[FEW_QW]; + const float* qB = (const float*)weights[FEW_QB]; + const float* kW = (const float*)weights[FEW_KW]; + const float* kB = (const float*)weights[FEW_KB]; + const float* vW = (const float*)weights[FEW_VW]; + const float* vB = (const float*)weights[FEW_VB]; + const float* oW = (const float*)weights[FEW_OW]; + const float* oB = (const float*)weights[FEW_OB]; + const float* ffn1W = (const float*)weights[FEW_FFN1W]; + const float* ffn1B = (const float*)weights[FEW_FFN1B]; + const float* ffn2W = (const float*)weights[FEW_FFN2W]; + const float* ffn2B = (const float*)weights[FEW_FFN2B]; + const float* norm1W = (const float*)weights[FEW_NORM1W]; + const float* norm1B = (const float*)weights[FEW_NORM1B]; + const float* norm2W = (const float*)weights[FEW_NORM2W]; + const float* norm2B = (const float*)weights[FEW_NORM2B]; + + /* Extract buffer pointers. */ + float* normed1 = (float*)bufs[FEB_NORMED1]; + float* ln1Invstd = (float*)bufs[FEB_LN1_INVSTD]; + float* Q = (float*)bufs[FEB_Q]; + float* K = (float*)bufs[FEB_K]; + float* V = (float*)bufs[FEB_V]; + float* Qh = (float*)bufs[FEB_QH]; + float* Kh = (float*)bufs[FEB_KH]; + float* Vh = (float*)bufs[FEB_VH]; + float* attnScores = (float*)bufs[FEB_ATTN_SCORES]; + float* attnOutH = (float*)bufs[FEB_ATTN_OUT_H]; + float* attnOut = (float*)bufs[FEB_ATTN_OUT]; + float* xRes1 = (float*)bufs[FEB_X_RES1]; + float* normed2 = (float*)bufs[FEB_NORMED2]; + float* ln2Invstd = (float*)bufs[FEB_LN2_INVSTD]; + float* ffn1Pre = (float*)bufs[FEB_FFN1_PRE]; + float* ffn1Out = (float*)bufs[FEB_FFN1_OUT]; + + int bnh = bsC * nHeads; + int trDm = totalRows * dModel; /* total elements [totalRows, dModel] */ + int trFf = totalRows * ffnDim; /* total elements [totalRows, ffnDim] */ + + /* Common kernel launch params. */ + int block256 = 256; + int elemGridTrDm = (trDm + block256 - 1) / block256; + int elemGridTrFf = (trFf + block256 - 1) / block256; + int totalElems = bsC * numPatches * nHeads * headDim; + int elemGridTotal = (totalElems + block256 - 1) / block256; + + /* LayerNorm block size: next power of 2 up to min(dModel, 256). */ + int lnBlock = next_pow2(dModel); + size_t lnSmem = lnBlock * sizeof(float); + + /* ------------------------------------------------------------ */ + /* Step 1: LayerNorm1 */ + /* ------------------------------------------------------------ */ + kernel_layernorm_fwd<<>>( + input, norm1W, norm1B, normed1, ln1Invstd, dModel); + + /* ------------------------------------------------------------ */ + /* Step 2: Q/K/V projections via cuBLAS + bias add */ + /* Q = normed1 @ qW + qB */ + /* K = normed1 @ kW + kB */ + /* V = normed1 @ vW + vB */ + /* ------------------------------------------------------------ */ + + /* Q = normed1 @ qW (beta=0 to zero-initialize output) */ + blas_stat = sgemm_nn(h, totalRows, dModel, dModel, 1.0f, normed1, qW, 0.0f, Q); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + kernel_bias_add<<>>(Q, qB, Q, trDm, dModel); + + /* K = normed1 @ kW + kB */ + blas_stat = sgemm_nn(h, totalRows, dModel, dModel, 1.0f, normed1, kW, 0.0f, K); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + kernel_bias_add<<>>(K, kB, K, trDm, dModel); + + /* V = normed1 @ vW + vB */ + blas_stat = sgemm_nn(h, totalRows, dModel, dModel, 1.0f, normed1, vW, 0.0f, V); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + kernel_bias_add<<>>(V, vB, V, trDm, dModel); + + /* ------------------------------------------------------------ */ + /* Step 3: Head split transpose for Q, K, V */ + /* [bsC*numPatches, nHeads*headDim] -> [bsC*nHeads, numPatches, headDim] */ + /* ------------------------------------------------------------ */ + kernel_head_split<<>>(Q, Qh, bsC, numPatches, nHeads, headDim); + kernel_head_split<<>>(K, Kh, bsC, numPatches, nHeads, headDim); + kernel_head_split<<>>(V, Vh, bsC, numPatches, nHeads, headDim); + + /* ------------------------------------------------------------ */ + /* Step 4: Attention scores = (Qh @ Kh^T) / sqrt(headDim) */ + /* [bnh, numPatches, headDim] @ [bnh, headDim, numPatches] */ + /* -> [bnh, numPatches, numPatches] */ + /* ------------------------------------------------------------ */ + float attnScale = 1.0f / sqrtf((float)headDim); + long long strideQK = (long long)numPatches * headDim; + long long strideScores = (long long)numPatches * numPatches; + + blas_stat = sgemm_nt_batched(h, + numPatches, numPatches, headDim, + attnScale, + Qh, strideQK, + Kh, strideQK, + 0.0f, + attnScores, strideScores, + bnh); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* ------------------------------------------------------------ */ + /* Step 5: Softmax on attention scores (in-place) */ + /* Each row of [bnh * numPatches, numPatches] */ + /* ------------------------------------------------------------ */ + int sfBlock = next_pow2(numPatches); + size_t sfSmem = sfBlock * sizeof(float); + kernel_softmax_fwd<<>>( + attnScores, numPatches); + + /* ------------------------------------------------------------ */ + /* Step 6: Attention output = scores @ Vh */ + /* [bnh, numPatches, numPatches] @ [bnh, numPatches, headDim] */ + /* -> [bnh, numPatches, headDim] */ + /* ------------------------------------------------------------ */ + blas_stat = sgemm_nn_batched(h, + numPatches, headDim, numPatches, + 1.0f, + attnScores, strideScores, + Vh, strideQK, + 0.0f, + attnOutH, strideQK, + bnh); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + + /* ------------------------------------------------------------ */ + /* Step 7: Head merge transpose */ + /* [bsC*nHeads, numPatches, headDim] -> [bsC*numPatches, nHeads*headDim] */ + /* ------------------------------------------------------------ */ + kernel_head_merge<<>>(attnOutH, attnOut, bsC, numPatches, nHeads, headDim); + + /* ------------------------------------------------------------ */ + /* Step 8: Output projection + bias + residual 1 */ + /* proj = attnOut @ oW */ + /* xRes1 = proj + oB + input (fused bias + residual) */ + /* ------------------------------------------------------------ */ + /* Use output buffer as temporary for projection result. */ + blas_stat = sgemm_nn(h, totalRows, dModel, dModel, 1.0f, attnOut, oW, 0.0f, output); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + kernel_bias_residual_add<<>>( + output, oB, input, xRes1, trDm, dModel); + + /* ------------------------------------------------------------ */ + /* Step 9: LayerNorm2 */ + /* ------------------------------------------------------------ */ + kernel_layernorm_fwd<<>>( + xRes1, norm2W, norm2B, normed2, ln2Invstd, dModel); + + /* ------------------------------------------------------------ */ + /* Step 10: FFN1 + bias + GELU */ + /* ffn1Pre = normed2 @ ffn1W (linear projection) */ + /* ffn1Out = gelu(ffn1Pre + ffn1B) */ + /* ------------------------------------------------------------ */ + blas_stat = sgemm_nn(h, totalRows, ffnDim, dModel, 1.0f, normed2, ffn1W, 0.0f, ffn1Pre); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + kernel_bias_gelu_fwd<<>>( + ffn1Pre, ffn1B, ffn1Pre, ffn1Out, trFf, ffnDim); + + /* ------------------------------------------------------------ */ + /* Step 11: FFN2 + bias + residual 2 */ + /* proj = ffn1Out @ ffn2W */ + /* output = proj + ffn2B + xRes1 */ + /* ------------------------------------------------------------ */ + blas_stat = sgemm_nn(h, totalRows, dModel, ffnDim, 1.0f, ffn1Out, ffn2W, 0.0f, output); + if (blas_stat != CUBLAS_STATUS_SUCCESS) return cudaErrorUnknown; + kernel_bias_residual_add<<>>( + output, ffn2B, xRes1, output, trDm, dModel); + + return cudaGetLastError(); +} + +} /* extern "C" */ diff --git a/internal/cuda/kernels/fused_encoder_fwd.go b/internal/cuda/kernels/fused_encoder_fwd.go new file mode 100644 index 0000000..7fd3387 --- /dev/null +++ b/internal/cuda/kernels/fused_encoder_fwd.go @@ -0,0 +1,48 @@ +//go:build cuda + +package kernels + +/* +#cgo LDFLAGS: -L${SRCDIR} -lkernels -lcudart -lcublas -lstdc++ +#include "fused_encoder_fwd.h" +*/ +import "C" + +import ( + "fmt" + "unsafe" +) + +// FusedEncoderFwdF32 executes one encoder layer forward pass in a single call (CGo path). +func FusedEncoderFwdF32( + cublasHandle unsafe.Pointer, + weights *[FEW_COUNT]unsafe.Pointer, + bufs *[FEB_COUNT]unsafe.Pointer, + input, output unsafe.Pointer, + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, + stream unsafe.Pointer, +) error { + err := C.fused_encoder_fwd_f32( + cublasHandle, + (*unsafe.Pointer)(unsafe.Pointer(weights)), + (*unsafe.Pointer)(unsafe.Pointer(bufs)), + (*C.float)(input), + (*C.float)(output), + C.int(totalRows), C.int(dModel), C.int(nHeads), C.int(headDim), + C.int(ffnDim), C.int(bsC), C.int(numPatches), + C.cudaStream_t(stream), + ) + if err != 0 { + return fmt.Errorf("fused_encoder_fwd_f32 failed with cuda error %d", int(err)) + } + return nil +} + +// FusedEncoderFwdScratchBytes returns the total bytes needed for all FEB_COUNT buffers. +func FusedEncoderFwdScratchBytes( + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, +) int64 { + return int64(C.fused_encoder_fwd_scratch_bytes( + C.int(totalRows), C.int(dModel), C.int(nHeads), C.int(headDim), + C.int(ffnDim), C.int(bsC), C.int(numPatches))) +} diff --git a/internal/cuda/kernels/fused_encoder_fwd.h b/internal/cuda/kernels/fused_encoder_fwd.h new file mode 100644 index 0000000..49b115f --- /dev/null +++ b/internal/cuda/kernels/fused_encoder_fwd.h @@ -0,0 +1,97 @@ +/* fused_encoder_fwd.h -- Fused PatchTST encoder layer forward pass. + * + * A single host-side orchestrator function that replaces ~78 discrete + * Engine operations per encoder layer with ~14 internal sub-operations + * (cuBLAS GEMMs + custom CUDA kernels) launched on the same stream. + * + * Buffer indices for the ptrs[] and cache[] arrays are defined as enums + * so Go bindings can reference the same layout. + */ + +#ifndef FUSED_ENCODER_FWD_H +#define FUSED_ENCODER_FWD_H + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* Weight pointer indices (16 pointers). */ +enum FusedEncoderWeight { + FEW_QW = 0, FEW_QB, + FEW_KW, FEW_KB, + FEW_VW, FEW_VB, + FEW_OW, FEW_OB, + FEW_FFN1W, FEW_FFN1B, + FEW_FFN2W, FEW_FFN2B, + FEW_NORM1W, FEW_NORM1B, + FEW_NORM2W, FEW_NORM2B, + FEW_COUNT /* = 16 */ +}; + +/* Forward cache buffer indices (pre-allocated by Go, written by kernel). + * The backward pass reads these to compute gradients. */ +enum FusedEncoderFwdBuf { + FEB_NORMED1 = 0, /* [totalRows, dModel] */ + FEB_LN1_INVSTD, /* [totalRows] */ + FEB_Q, /* [totalRows, dModel] */ + FEB_K, /* [totalRows, dModel] */ + FEB_V, /* [totalRows, dModel] */ + FEB_QH, /* [bsC*nHeads, numPatches, headDim] */ + FEB_KH, /* [bsC*nHeads, numPatches, headDim] */ + FEB_VH, /* [bsC*nHeads, numPatches, headDim] */ + FEB_ATTN_SCORES, /* [bsC*nHeads, numPatches, numPatches] */ + FEB_ATTN_OUT_H, /* [bsC*nHeads, numPatches, headDim] */ + FEB_ATTN_OUT, /* [totalRows, dModel] */ + FEB_X_RES1, /* [totalRows, dModel] */ + FEB_NORMED2, /* [totalRows, dModel] */ + FEB_LN2_INVSTD, /* [totalRows] */ + FEB_FFN1_PRE, /* [totalRows, ffnDim] */ + FEB_FFN1_OUT, /* [totalRows, ffnDim] */ + FEB_COUNT /* = 16 */ +}; + +/* fused_encoder_fwd_f32 -- execute one encoder layer forward pass. + * + * Parameters: + * cublas_handle cuBLAS handle with stream already set + * weights FEW_COUNT pointers to layer weight device memory + * bufs FEB_COUNT pointers to pre-allocated cache buffers + * input [totalRows, dModel] layer input (device ptr) + * output [totalRows, dModel] layer output (device ptr, may alias input) + * totalRows bsC * numPatches + * dModel embedding dimension + * nHeads number of attention heads + * headDim dModel / nHeads + * ffnDim feed-forward hidden dimension (typically dModel * 4) + * bsC batch_size * channels + * numPatches sequence length (number of patches) + * stream CUDA stream + * + * Returns cudaSuccess on success. + */ +cudaError_t fused_encoder_fwd_f32( + void* cublas_handle, + const void** weights, + void** bufs, + const float* input, + float* output, + int totalRows, int dModel, int nHeads, int headDim, int ffnDim, + int bsC, int numPatches, + cudaStream_t stream); + +/* fused_encoder_fwd_scratch_bytes -- compute minimum buffer sizes. + * + * Returns the total bytes needed across all FEB_COUNT buffers. + * Caller can use this to validate pre-allocation. + */ +long long fused_encoder_fwd_scratch_bytes( + int totalRows, int dModel, int nHeads, int headDim, int ffnDim, + int bsC, int numPatches); + +#ifdef __cplusplus +} +#endif + +#endif /* FUSED_ENCODER_FWD_H */ diff --git a/internal/cuda/kernels/fused_encoder_fwd_purego.go b/internal/cuda/kernels/fused_encoder_fwd_purego.go new file mode 100644 index 0000000..a907b7e --- /dev/null +++ b/internal/cuda/kernels/fused_encoder_fwd_purego.go @@ -0,0 +1,105 @@ +//go:build !cuda + +package kernels + +import ( + "fmt" + "unsafe" + + "github.com/zerfoo/ztensor/internal/cuda" +) + +// Buffer index constants matching the C enum FusedEncoderWeight. +const ( + FEW_QW = 0 + FEW_QB = 1 + FEW_KW = 2 + FEW_KB = 3 + FEW_VW = 4 + FEW_VB = 5 + FEW_OW = 6 + FEW_OB = 7 + FEW_FFN1W = 8 + FEW_FFN1B = 9 + FEW_FFN2W = 10 + FEW_FFN2B = 11 + FEW_NORM1W = 12 + FEW_NORM1B = 13 + FEW_NORM2W = 14 + FEW_NORM2B = 15 + FEW_COUNT = 16 +) + +// Buffer index constants matching the C enum FusedEncoderFwdBuf. +const ( + FEB_NORMED1 = 0 + FEB_LN1_INVSTD = 1 + FEB_Q = 2 + FEB_K = 3 + FEB_V = 4 + FEB_QH = 5 + FEB_KH = 6 + FEB_VH = 7 + FEB_ATTN_SCORES = 8 + FEB_ATTN_OUT_H = 9 + FEB_ATTN_OUT = 10 + FEB_X_RES1 = 11 + FEB_NORMED2 = 12 + FEB_LN2_INVSTD = 13 + FEB_FFN1_PRE = 14 + FEB_FFN1_OUT = 15 + FEB_COUNT = 16 +) + +// FusedEncoderFwdF32 executes one encoder layer forward pass in a single call. +// +// Parameters: +// - cublasHandle: raw cuBLAS handle pointer (from cublas.Handle.Ptr()) +// - weights: [FEW_COUNT]unsafe.Pointer to layer weight device memory +// - bufs: [FEB_COUNT]unsafe.Pointer to pre-allocated cache buffers +// - input, output: [totalRows, dModel] device pointers +// - totalRows..numPatches: dimension parameters +// - stream: CUDA stream pointer +func FusedEncoderFwdF32( + cublasHandle unsafe.Pointer, + weights *[FEW_COUNT]unsafe.Pointer, + bufs *[FEB_COUNT]unsafe.Pointer, + input, output unsafe.Pointer, + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, + stream unsafe.Pointer, +) error { + k := klib() + if k == nil || k.launchFusedEncoderFwdF32 == 0 { + return fmt.Errorf("fused_encoder_fwd_f32 kernel: not available") + } + ret := cuda.Ccall(k.launchFusedEncoderFwdF32, + uintptr(cublasHandle), + uintptr(unsafe.Pointer(weights)), + uintptr(unsafe.Pointer(bufs)), + uintptr(input), + uintptr(output), + uintptr(totalRows), uintptr(dModel), uintptr(nHeads), uintptr(headDim), + uintptr(ffnDim), uintptr(bsC), uintptr(numPatches), + uintptr(stream)) + return checkKernel(ret, "fused_encoder_fwd_f32") +} + +// FusedEncoderFwdScratchBytes returns the total bytes needed for all FEB_COUNT buffers. +func FusedEncoderFwdScratchBytes( + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, +) int64 { + k := klib() + if k == nil || k.launchFusedEncoderFwdScratch == 0 { + return -1 + } + ret := cuda.Ccall(k.launchFusedEncoderFwdScratch, + uintptr(totalRows), uintptr(dModel), uintptr(nHeads), uintptr(headDim), + uintptr(ffnDim), uintptr(bsC), uintptr(numPatches)) + return int64(ret) +} + +// FusedEncoderAvailable returns true if the fused encoder kernels are loaded. +func FusedEncoderAvailable() bool { + k := klib() + return k != nil && k.launchFusedEncoderFwdF32 != 0 +} diff --git a/internal/cuda/kernels/purego.go b/internal/cuda/kernels/purego.go index 5cf49f2..552b143 100644 --- a/internal/cuda/kernels/purego.go +++ b/internal/cuda/kernels/purego.go @@ -180,6 +180,11 @@ type KernelLib struct { launchIQDequantNLF32 uintptr launchIQDequant3SF32 uintptr launchIQDequant2XXSF32 uintptr + + // fused_encoder (PatchTST encoder layer orchestrator) + launchFusedEncoderFwdF32 uintptr + launchFusedEncoderFwdScratch uintptr + launchFusedEncoderBwdF32 uintptr } var ( @@ -349,6 +354,10 @@ func openKernelLib() (*KernelLib, error) { {"iq_dequant_nl_f32", &k.launchIQDequantNLF32}, {"iq_dequant_3s_f32", &k.launchIQDequant3SF32}, {"iq_dequant_2xxs_f32", &k.launchIQDequant2XXSF32}, + // fused_encoder (PatchTST encoder layer orchestrator) + {"fused_encoder_fwd_f32", &k.launchFusedEncoderFwdF32}, + {"fused_encoder_fwd_scratch_bytes", &k.launchFusedEncoderFwdScratch}, + {"fused_encoder_bwd_f32", &k.launchFusedEncoderBwdF32}, } // Optional symbols: missing is non-fatal (kernel not compiled yet). optionalSyms := map[string]bool{ @@ -394,6 +403,9 @@ func openKernelLib() (*KernelLib, error) { "iq_dequant_nl_f32": true, "iq_dequant_3s_f32": true, "iq_dequant_2xxs_f32": true, + "fused_encoder_fwd_f32": true, + "fused_encoder_fwd_scratch_bytes": true, + "fused_encoder_bwd_f32": true, } for _, s := range syms { ptr, dlErr := cuda.Dlsym(lib, s.name) diff --git a/internal/gpuapi/cuda_kernels.go b/internal/gpuapi/cuda_kernels.go index ae0954a..4a3a604 100644 --- a/internal/gpuapi/cuda_kernels.go +++ b/internal/gpuapi/cuda_kernels.go @@ -311,5 +311,26 @@ func (k *CUDAKernels) FusedSoftmaxVMulF32(scores, V, output unsafe.Pointer, scal return kernels.FusedSoftmaxVMulF32(scores, V, output, scale, BH, seqKV, D, streamPtr(s)) } +func (k *CUDAKernels) FusedEncoderFwdF32(cublasHandle unsafe.Pointer, weights, bufs *[16]unsafe.Pointer, input, output unsafe.Pointer, totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, s Stream) error { + return kernels.FusedEncoderFwdF32(cublasHandle, (*[kernels.FEW_COUNT]unsafe.Pointer)(weights), (*[kernels.FEB_COUNT]unsafe.Pointer)(bufs), input, output, totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches, streamPtr(s)) +} + +func (k *CUDAKernels) FusedEncoderBwdF32(cublasHandle unsafe.Pointer, weights, weightT *[16]unsafe.Pointer, fwdBufs *[16]unsafe.Pointer, bwdBufs *[15]unsafe.Pointer, grads *[16]unsafe.Pointer, dOutput, dInput, input unsafe.Pointer, totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, s Stream) error { + // Use unsafe.Pointer intermediary for array size conversions (interface uses + // fixed sizes; C kernel reads only the elements it needs per enum count). + return kernels.FusedEncoderBwdF32(cublasHandle, + (*[kernels.FEW_COUNT]unsafe.Pointer)(unsafe.Pointer(weights)), + (*[kernels.FEWT_COUNT]unsafe.Pointer)(unsafe.Pointer(weightT)), + (*[kernels.FEB_COUNT]unsafe.Pointer)(unsafe.Pointer(fwdBufs)), + (*[kernels.FEBB_COUNT]unsafe.Pointer)(unsafe.Pointer(bwdBufs)), + (*[kernels.FEG_COUNT]unsafe.Pointer)(unsafe.Pointer(grads)), + dOutput, dInput, input, + totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches, streamPtr(s)) +} + +func (k *CUDAKernels) FusedEncoderFwdAvailable() bool { + return kernels.FusedEncoderAvailable() +} + // Compile-time interface assertion. var _ KernelRunner = (*CUDAKernels)(nil) diff --git a/internal/gpuapi/fpga_kernels.go b/internal/gpuapi/fpga_kernels.go index 7567d42..96a6e45 100644 --- a/internal/gpuapi/fpga_kernels.go +++ b/internal/gpuapi/fpga_kernels.go @@ -207,6 +207,18 @@ func (k *FPGAKernels) FusedQKNormRoPEF32(_, _, _, _, _, _ unsafe.Pointer, _ floa return fmt.Errorf("FusedQKNormRoPEF32: not implemented for FPGA") } +func (k *FPGAKernels) FusedEncoderFwdF32(_ unsafe.Pointer, _, _ *[16]unsafe.Pointer, _, _ unsafe.Pointer, _, _, _, _, _, _, _ int, _ Stream) error { + return fmt.Errorf("FusedEncoderFwdF32: not implemented for FPGA") +} + +func (k *FPGAKernels) FusedEncoderBwdF32(_ unsafe.Pointer, _, _ *[16]unsafe.Pointer, _ *[16]unsafe.Pointer, _ *[15]unsafe.Pointer, _ *[16]unsafe.Pointer, _, _, _ unsafe.Pointer, _, _, _, _, _, _, _ int, _ Stream) error { + return fmt.Errorf("FusedEncoderBwdF32: not implemented for FPGA") +} + +func (k *FPGAKernels) FusedEncoderFwdAvailable() bool { + return false +} + func (k *FPGAKernels) ScaledSoftmaxF32(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error { return fmt.Errorf("ScaledSoftmaxF32: not implemented for FPGA") } diff --git a/internal/gpuapi/gpuapi_test.go b/internal/gpuapi/gpuapi_test.go index c4cbc25..96aad92 100644 --- a/internal/gpuapi/gpuapi_test.go +++ b/internal/gpuapi/gpuapi_test.go @@ -225,6 +225,15 @@ func (stubKernelRunner) FusedNormAddF32(_, _, _, _ unsafe.Pointer, _ float32, _, func (stubKernelRunner) FusedQKNormRoPEF32(_, _, _, _, _, _ unsafe.Pointer, _ float32, _, _, _, _ int, _ gpuapi.Stream) error { return nil } +func (stubKernelRunner) FusedEncoderFwdF32(_ unsafe.Pointer, _, _ *[16]unsafe.Pointer, _, _ unsafe.Pointer, _, _, _, _, _, _, _ int, _ gpuapi.Stream) error { + return nil +} +func (stubKernelRunner) FusedEncoderBwdF32(_ unsafe.Pointer, _, _ *[16]unsafe.Pointer, _ *[16]unsafe.Pointer, _ *[15]unsafe.Pointer, _ *[16]unsafe.Pointer, _, _, _ unsafe.Pointer, _, _, _, _, _, _, _ int, _ gpuapi.Stream) error { + return nil +} +func (stubKernelRunner) FusedEncoderFwdAvailable() bool { + return false +} func (stubKernelRunner) ScaledSoftmaxF32(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ gpuapi.Stream) error { return nil } diff --git a/internal/gpuapi/kernels.go b/internal/gpuapi/kernels.go index 4698e9b..32c75a0 100644 --- a/internal/gpuapi/kernels.go +++ b/internal/gpuapi/kernels.go @@ -211,4 +211,19 @@ type KernelRunner interface { // Decode-optimized (seqQ=1): avoids materializing the attention weights tensor. // scores: [BH, seqKV], V: [BH, seqKV, D], output: [BH, D]. FusedSoftmaxVMulF32(scores, V, output unsafe.Pointer, scale float32, BH, seqKV, D int, stream Stream) error + + // FusedEncoderFwdF32 executes one PatchTST encoder layer forward pass. + // Replaces ~78 discrete engine operations with a single orchestrated call. + // cublasHandle: raw cuBLAS handle; weights: 16 weight pointers; + // bufs: 16 cache buffer pointers; input/output: [totalRows, dModel]. + FusedEncoderFwdF32(cublasHandle unsafe.Pointer, weights, bufs *[16]unsafe.Pointer, input, output unsafe.Pointer, totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, stream Stream) error + + // FusedEncoderBwdF32 computes all gradients for one encoder layer backward. + // weights: 16 weight pointers; weightT: 6 transposed weight pointers; + // fwdBufs: 16 forward cache pointers; bwdBufs: 15 scratch pointers; + // grads: 16 gradient accumulator pointers. + FusedEncoderBwdF32(cublasHandle unsafe.Pointer, weights, weightT *[16]unsafe.Pointer, fwdBufs *[16]unsafe.Pointer, bwdBufs *[15]unsafe.Pointer, grads *[16]unsafe.Pointer, dOutput, dInput, input unsafe.Pointer, totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int, stream Stream) error + + // FusedEncoderFwdAvailable returns true if the fused encoder kernel is loaded. + FusedEncoderFwdAvailable() bool } diff --git a/internal/gpuapi/metal_kernels.go b/internal/gpuapi/metal_kernels.go index af3b107..cb3bfeb 100644 --- a/internal/gpuapi/metal_kernels.go +++ b/internal/gpuapi/metal_kernels.go @@ -317,6 +317,18 @@ func (k *MetalKernels) FusedSoftmaxVMulF32(_, _, _ unsafe.Pointer, _ float32, _, return fmt.Errorf("FusedSoftmaxVMulF32: not yet implemented for Metal") } +func (k *MetalKernels) FusedEncoderFwdF32(_ unsafe.Pointer, _, _ *[16]unsafe.Pointer, _, _ unsafe.Pointer, _, _, _, _, _, _, _ int, _ Stream) error { + return fmt.Errorf("FusedEncoderFwdF32: not implemented for Metal") +} + +func (k *MetalKernels) FusedEncoderBwdF32(_ unsafe.Pointer, _, _ *[16]unsafe.Pointer, _ *[16]unsafe.Pointer, _ *[15]unsafe.Pointer, _ *[16]unsafe.Pointer, _, _, _ unsafe.Pointer, _, _, _, _, _, _, _ int, _ Stream) error { + return fmt.Errorf("FusedEncoderBwdF32: not implemented for Metal") +} + +func (k *MetalKernels) FusedEncoderFwdAvailable() bool { + return false +} + // --- Gather --- func (k *MetalKernels) Gather(table, indices, output unsafe.Pointer, N, D, _ int, _ Stream) error { diff --git a/internal/gpuapi/opencl_kernels.go b/internal/gpuapi/opencl_kernels.go index 645c801..723a280 100644 --- a/internal/gpuapi/opencl_kernels.go +++ b/internal/gpuapi/opencl_kernels.go @@ -224,6 +224,18 @@ func (k *OpenCLKernels) FusedQKNormRoPEF32(_, _, _, _, _, _ unsafe.Pointer, _ fl return fmt.Errorf("FusedQKNormRoPEF32: not implemented for OpenCL") } +func (k *OpenCLKernels) FusedEncoderFwdF32(_ unsafe.Pointer, _, _ *[16]unsafe.Pointer, _, _ unsafe.Pointer, _, _, _, _, _, _, _ int, _ Stream) error { + return fmt.Errorf("FusedEncoderFwdF32: not implemented for OpenCL") +} + +func (k *OpenCLKernels) FusedEncoderBwdF32(_ unsafe.Pointer, _, _ *[16]unsafe.Pointer, _ *[16]unsafe.Pointer, _ *[15]unsafe.Pointer, _ *[16]unsafe.Pointer, _, _, _ unsafe.Pointer, _, _, _, _, _, _, _ int, _ Stream) error { + return fmt.Errorf("FusedEncoderBwdF32: not implemented for OpenCL") +} + +func (k *OpenCLKernels) FusedEncoderFwdAvailable() bool { + return false +} + func (k *OpenCLKernels) ScaledSoftmaxF32(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error { return fmt.Errorf("ScaledSoftmaxF32: not implemented for OpenCL") } diff --git a/internal/gpuapi/rocm_kernels.go b/internal/gpuapi/rocm_kernels.go index 11f2565..c2fde30 100644 --- a/internal/gpuapi/rocm_kernels.go +++ b/internal/gpuapi/rocm_kernels.go @@ -214,6 +214,18 @@ func (k *ROCmKernels) FusedQKNormRoPEF32(_, _, _, _, _, _ unsafe.Pointer, _ floa return fmt.Errorf("FusedQKNormRoPEF32: not implemented for ROCm") } +func (k *ROCmKernels) FusedEncoderFwdF32(_ unsafe.Pointer, _, _ *[16]unsafe.Pointer, _, _ unsafe.Pointer, _, _, _, _, _, _, _ int, _ Stream) error { + return fmt.Errorf("FusedEncoderFwdF32: not implemented for ROCm") +} + +func (k *ROCmKernels) FusedEncoderBwdF32(_ unsafe.Pointer, _, _ *[16]unsafe.Pointer, _ *[16]unsafe.Pointer, _ *[15]unsafe.Pointer, _ *[16]unsafe.Pointer, _, _, _ unsafe.Pointer, _, _, _, _, _, _, _ int, _ Stream) error { + return fmt.Errorf("FusedEncoderBwdF32: not implemented for ROCm") +} + +func (k *ROCmKernels) FusedEncoderFwdAvailable() bool { + return false +} + func (k *ROCmKernels) ScaledSoftmaxF32(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error { return fmt.Errorf("ScaledSoftmaxF32: not implemented for ROCm") } diff --git a/internal/gpuapi/sycl_kernels.go b/internal/gpuapi/sycl_kernels.go index f101a53..ecee11e 100644 --- a/internal/gpuapi/sycl_kernels.go +++ b/internal/gpuapi/sycl_kernels.go @@ -209,6 +209,18 @@ func (k *SYCLKernels) FusedQKNormRoPEF32(_, _, _, _, _, _ unsafe.Pointer, _ floa return fmt.Errorf("FusedQKNormRoPEF32: not implemented for SYCL") } +func (k *SYCLKernels) FusedEncoderFwdF32(_ unsafe.Pointer, _, _ *[16]unsafe.Pointer, _, _ unsafe.Pointer, _, _, _, _, _, _, _ int, _ Stream) error { + return fmt.Errorf("FusedEncoderFwdF32: not implemented for SYCL") +} + +func (k *SYCLKernels) FusedEncoderBwdF32(_ unsafe.Pointer, _, _ *[16]unsafe.Pointer, _ *[16]unsafe.Pointer, _ *[15]unsafe.Pointer, _ *[16]unsafe.Pointer, _, _, _ unsafe.Pointer, _, _, _, _, _, _, _ int, _ Stream) error { + return fmt.Errorf("FusedEncoderBwdF32: not implemented for SYCL") +} + +func (k *SYCLKernels) FusedEncoderFwdAvailable() bool { + return false +} + func (k *SYCLKernels) ScaledSoftmaxF32(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, s Stream) error { if !sycl.ScaledSoftmaxF32Available() { return fmt.Errorf("ScaledSoftmaxF32: SYCL kernel not available")