From e1fcd53344e672248509ca7df3f2494fea704199 Mon Sep 17 00:00:00 2001
From: EthanZero2Hero <1019419030@qq.com>
Date: Mon, 29 Jun 2026 16:10:37 +0800
Subject: [PATCH] feat(kernels): add batch-invariant RMSNorm

---
 benchmarks/benchmark_rmsnorm.py | 74 +++
 csrc/cuda/rmsnorm.cu | 322 ++++++++++++
 csrc/ops.cpp | 113 ++++
 rl_engine/kernels/ops/base.py | 2 +
 rl_engine/kernels/ops/cuda/norm/rmsnorm.py | 76 +++
 .../ops/pytorch/norm/rmsnorm_native.py | 63 +++
 .../kernels/ops/pytorch/norm/rmsnorm_ref.py | 56 ++
 .../kernels/ops/triton/rmsnorm_triton.py | 98 ++++
 rl_engine/tests/test_rmsnorm.py | 488 ++++++++++++++++++
 setup.py | 1 +
 10 files changed, 1293 insertions(+)
 create mode 100644 benchmarks/benchmark_rmsnorm.py
 create mode 100644 csrc/cuda/rmsnorm.cu
 create mode 100644 rl_engine/kernels/ops/cuda/norm/rmsnorm.py
 create mode 100644 rl_engine/kernels/ops/pytorch/norm/rmsnorm_native.py
 create mode 100644 rl_engine/kernels/ops/pytorch/norm/rmsnorm_ref.py
 create mode 100644 rl_engine/kernels/ops/triton/rmsnorm_triton.py
 create mode 100644 rl_engine/tests/test_rmsnorm.py

diff --git a/benchmarks/benchmark_rmsnorm.py b/benchmarks/benchmark_rmsnorm.py
new file mode 100644
index 0000000..f811119
--- /dev/null
+++ b/benchmarks/benchmark_rmsnorm.py
@@ -0,0 +1,79 @@
+import argparse
+import time
+
+import torch
+
+from rl_engine.kernels.ops.pytorch.norm.rmsnorm_ref import rmsnorm_ref_custom
+from rl_engine.kernels.ops.triton.rmsnorm_triton import rmsnorm_triton
+
+try:
+    from rl_engine.kernels.ops.cuda.norm.rmsnorm import rmsnorm_cuda
+
+    HAS_CUDA_EXT = True
+except Exception:
+    HAS_CUDA_EXT = False
+
+
+def bench(fn, x, w, dy, warmup=20, iters=100):
+    """Return the mean milliseconds per forward+backward call of ``fn(x, w)``."""
+    for _ in range(warmup):
+        x.grad = None
+        w.grad = None
+        y = fn(x, w)
+        y.backward(dy)
+    # Drain the warmup launches so they are not billed to the timed region.
+    torch.cuda.synchronize()
+
+    # perf_counter is monotonic and high resolution; time.time can jump with clock changes.
+    start = time.perf_counter()
+    for _ in range(iters):
+        x.grad = None
+        w.grad = None
+        y = fn(x, w)
+        y.backward(dy)
+    # CUDA launches are asynchronous: wait for the GPU before reading the clock.
+    torch.cuda.synchronize()
+    return (time.perf_counter() - start) * 1000.0 / iters
+
+
+def main():
+    """Parse --T/--H/--dtype, then time the PyTorch, Triton and (if importable) CUDA RMSNorm."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--T", type=int, default=1024)
+    parser.add_argument("--H", type=int, default=4096)
+    parser.add_argument("--dtype", choices=["fp16", "bf16"], default="bf16")
+    args = parser.parse_args()
+
+    dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16
+    device = "cuda"
+    T, H = args.T, args.H
+
+    torch.manual_seed(0)
+    x_base = torch.randn((T, H), device=device, dtype=dtype) * 0.2
+    w_base = torch.randn((H,), device=device, dtype=dtype) * 0.2
+    dy = torch.randn((T, H), device=device, dtype=dtype) * 0.2
+
+    def make_inputs():
+        return (
+            x_base.detach().clone().requires_grad_(True),
+            w_base.detach().clone().requires_grad_(True),
+        )
+
+    x, w = make_inputs()
+    t_ref = bench(lambda a, b: rmsnorm_ref_custom(a, b), x, w, dy)
+    print(f"pytorch ref : {t_ref:.4f} ms")
+
+    x, w = make_inputs()
+    t_tri = bench(lambda a, b: rmsnorm_triton(a, b), x, w, dy)
+    print(f"triton : {t_tri:.4f} ms | speedup vs ref: {t_ref / t_tri:.2f}x")
+
+    if HAS_CUDA_EXT:
+        x, w = make_inputs()
+        t_cuda = bench(lambda a, b: rmsnorm_cuda(a, b), x, w, dy)
+        print(f"cuda : {t_cuda:.4f} ms | speedup vs ref: {t_ref / t_cuda:.2f}x")
+    else:
+        print("cuda : skipped, extension is not built")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/csrc/cuda/rmsnorm.cu b/csrc/cuda/rmsnorm.cu
new file mode 100644
index 0000000..9eeb55b
--- /dev/null
+++ b/csrc/cuda/rmsnorm.cu
@@ -0,0 +1,322 @@
+#include
+#include
+#include
+#include
+#include
+#include
+
+template
+__device__ __forceinline__ float load_as_float(const scalar_t* ptr) {
+    return static_cast(*ptr);
+}
+
+template <>
+__device__ __forceinline__ float load_as_float(const at::Half* ptr) {
+    const __half* p = reinterpret_cast(ptr);
+    return __half2float(*p);
+}
+
+template <>
+__device__ __forceinline__ float load_as_float(const at::BFloat16* ptr) {
+    const __nv_bfloat16* p = reinterpret_cast(ptr);
+    return __bfloat162float(*p);
+}
+
+
+template
+__device__ __forceinline__ void 
store_from_float(scalar_t* ptr, float v) { + *ptr = static_cast(v); +} + +template <> +__device__ __forceinline__ void store_from_float(at::Half* ptr, float v) { + __half* p = reinterpret_cast<__half*>(ptr); + *p = __float2half(v); +} + +template <> +__device__ __forceinline__ void store_from_float(at::BFloat16* ptr, float v) { + __nv_bfloat16* p = reinterpret_cast<__nv_bfloat16*>(ptr); + *p = __float2bfloat16(v); +} + + +__device__ __forceinline__ float block_reduce_sum(float v) { + extern __shared__ float smem[]; + int tid = threadIdx.x; + + smem[tid] = v; + __syncthreads(); + + for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) { + if (tid < stride) { + smem[tid] += smem[tid + stride]; + } + __syncthreads(); + } + + return smem[0]; +} + + +static int choose_threads(int H) { + if (H <= 64) return 64; + if (H <= 128) return 128; + if (H <= 256) return 256; + return 512; +} + + +template +__global__ void rmsnorm_fwd_kernel( + const scalar_t* __restrict__ x, + const weight_t* __restrict__ weight, + scalar_t* __restrict__ y, + float* __restrict__ rstd, + int T, + int H, + float eps +) { + int row = blockIdx.x; + int tid = threadIdx.x; + + const scalar_t* x_row = x + row * H; + scalar_t* y_row = y + row * H; + + float local_sum = 0.0f; + + // 计算 sum(x^2),每个 thread 负责若干列。 + for (int col = tid; col < H; col += blockDim.x) { + float xv = load_as_float(x_row + col); + local_sum += xv * xv; + } + + // 固定 block reduction。 + float sum = block_reduce_sum(local_sum); + + float row_rstd = rsqrtf(sum / static_cast(H) + eps); + + if (tid == 0) { + rstd[row] = row_rstd; + } + + __syncthreads(); + + // 写出 y = x * rstd * weight。 + for (int col = tid; col < H; col += blockDim.x) { + float xv = load_as_float(x_row + col); + float wv = load_as_float(weight + col); + float out = xv * row_rstd * wv; + store_from_float(y_row + col, out); + } +} + + +template +__global__ void rmsnorm_bwd_dx_kernel( + const scalar_t* __restrict__ dy, + const scalar_t* __restrict__ x, + const 
weight_t* __restrict__ weight, + const float* __restrict__ rstd, + scalar_t* __restrict__ dx, + int T, + int H +) { + int row = blockIdx.x; + int tid = threadIdx.x; + + const scalar_t* dy_row = dy + row * H; + const scalar_t* x_row = x + row * H; + scalar_t* dx_row = dx + row * H; + + float local_dot = 0.0f; + + for (int col = tid; col < H; col += blockDim.x) { + float dyv = load_as_float(dy_row + col); + float xv = load_as_float(x_row + col); + float wv = load_as_float(weight + col); + local_dot += dyv * wv * xv; + } + + float dot = block_reduce_sum(local_dot); + + float r = rstd[row]; + float coeff = dot * r * r * r / static_cast(H); + + for (int col = tid; col < H; col += blockDim.x) { + float dyv = load_as_float(dy_row + col); + float xv = load_as_float(x_row + col); + float wv = load_as_float(weight + col); + + float out = r * dyv * wv - xv * coeff; + store_from_float(dx_row + col, out); + } +} + + +template +__global__ void rmsnorm_partial_dw_kernel( + const scalar_t* __restrict__ dy, + const scalar_t* __restrict__ x, + const float* __restrict__ rstd, + const bool* __restrict__ mask, + float* __restrict__ partial_dw, + int T, + int H +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = T * H; + + if (idx >= total) return; + + int row = idx / H; + + if (!mask[row]) { + partial_dw[idx] = 0.0f; + return; + } + + float dyv = load_as_float(dy + idx); + float xv = load_as_float(x + idx); + float r = rstd[row]; + + partial_dw[idx] = dyv * xv * r; +} + + +__global__ void rmsnorm_reduce_dw_kernel( + const float* __restrict__ partial_dw, + float* __restrict__ dw, + int T, + int H +) { + int h = blockIdx.x; + int tid = threadIdx.x; + + float local_sum = 0.0f; + + for (int t = tid; t < T; t += blockDim.x) { + local_sum += partial_dw[t * H + h]; + } + + float sum = block_reduce_sum(local_sum); + + if (tid == 0) { + dw[h] = sum; + } +} + + +void rmsnorm_forward_cuda( + torch::Tensor x, + torch::Tensor weight, + torch::Tensor y, + torch::Tensor rstd, + 
double eps +) { + int T = x.size(0); + int H = x.size(1); + int threads = choose_threads(H); + size_t smem = threads * sizeof(float); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "rmsnorm_forward_cuda", [&] { + using x_t = scalar_t; + AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, weight.scalar_type(), "rmsnorm_forward_weight_cuda", [&] { + using w_t = scalar_t; + rmsnorm_fwd_kernel<<>>( + x.data_ptr(), + weight.data_ptr(), + y.data_ptr(), + rstd.data_ptr(), + T, + H, + static_cast(eps) + ); + }); + }); +} + + +void rmsnorm_backward_dx_cuda( + torch::Tensor dy, + torch::Tensor x, + torch::Tensor weight, + torch::Tensor rstd, + torch::Tensor dx +) { + int T = x.size(0); + int H = x.size(1); + int threads = choose_threads(H); + size_t smem = threads * sizeof(float); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "rmsnorm_backward_dx_cuda", [&] { + using x_t = scalar_t; + AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, weight.scalar_type(), "rmsnorm_backward_dx_weight_cuda", [&] { + using w_t = scalar_t; + rmsnorm_bwd_dx_kernel<<>>( + dy.data_ptr(), + x.data_ptr(), + weight.data_ptr(), + rstd.data_ptr(), + dx.data_ptr(), + T, + H + ); + }); + }); +} + + +void rmsnorm_backward_partial_dw_cuda( + torch::Tensor dy, + torch::Tensor x, + torch::Tensor rstd, + torch::Tensor mask, + torch::Tensor partial_dw +) { + int T = x.size(0); + int H = x.size(1); + int total = T * H; + + int threads = 256; + int blocks = (total + threads - 1) / threads; + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "rmsnorm_partial_dw_cuda", [&] { + rmsnorm_partial_dw_kernel<<>>( + dy.data_ptr(), + x.data_ptr(), + rstd.data_ptr(), + mask.data_ptr(), + partial_dw.data_ptr(), + T, + H + ); + }); +} + + +void 
rmsnorm_backward_reduce_dw_cuda( + torch::Tensor partial_dw, + torch::Tensor dw +) { + int T = partial_dw.size(0); + int H = partial_dw.size(1); + + int threads = choose_threads(T); + size_t smem = threads * sizeof(float); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + rmsnorm_reduce_dw_kernel<<>>( + partial_dw.data_ptr(), + dw.data_ptr(), + T, + H + ); +} diff --git a/csrc/ops.cpp b/csrc/ops.cpp index f48e4f9..2963b44 100644 --- a/csrc/ops.cpp +++ b/csrc/ops.cpp @@ -25,6 +25,114 @@ torch::Tensor fused_logp_forward_online_fp32(torch::Tensor logits, torch::Tensor torch::Tensor fused_logp_forward_online_indexed_out(torch::Tensor logits, torch::Tensor token_ids, torch::Tensor row_indices, torch::Tensor output); torch::Tensor fused_logp_forward_online_indexed_fp32(torch::Tensor logits, torch::Tensor token_ids, torch::Tensor row_indices); +// RMSNorm Declarations & Wrappers + +void rmsnorm_forward_cuda( + torch::Tensor x, + torch::Tensor weight, + torch::Tensor y, + torch::Tensor rstd, + double eps); + +void rmsnorm_backward_dx_cuda( + torch::Tensor dy, + torch::Tensor x, + torch::Tensor weight, + torch::Tensor rstd, + torch::Tensor dx); + +void rmsnorm_backward_partial_dw_cuda( + torch::Tensor dy, + torch::Tensor x, + torch::Tensor rstd, + torch::Tensor mask, + torch::Tensor partial_dw); + +void rmsnorm_backward_reduce_dw_cuda( + torch::Tensor partial_dw, + torch::Tensor dw); + +static void rmsnorm_check_input(const torch::Tensor& x, const char* name) { + TORCH_CHECK(x.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(x.is_contiguous(), name, " must be contiguous"); +} + +std::vector rmsnorm_forward( + torch::Tensor x, + torch::Tensor weight, + double eps) +{ + rmsnorm_check_input(x, "x"); + rmsnorm_check_input(weight, "weight"); + + TORCH_CHECK(x.dim() == 2, "x must be 2D [T, H]"); + TORCH_CHECK(weight.dim() == 1, "weight must be 1D [H]"); + TORCH_CHECK(x.size(1) == weight.size(0), "x.size(1) must equal weight.size(0)"); + + auto T = 
x.size(0); + auto y = torch::empty_like(x); + auto rstd = torch::empty({T}, x.options().dtype(torch::kFloat32)); + + rmsnorm_forward_cuda(x, weight, y, rstd, eps); + + return {y, rstd}; +} + +torch::Tensor rmsnorm_backward_dx( + torch::Tensor dy, + torch::Tensor x, + torch::Tensor weight, + torch::Tensor rstd) +{ + rmsnorm_check_input(dy, "dy"); + rmsnorm_check_input(x, "x"); + rmsnorm_check_input(weight, "weight"); + rmsnorm_check_input(rstd, "rstd"); + + TORCH_CHECK(dy.sizes() == x.sizes(), "dy and x must have same shape"); + TORCH_CHECK(x.dim() == 2, "x must be 2D [T, H]"); + TORCH_CHECK(weight.dim() == 1, "weight must be 1D [H]"); + TORCH_CHECK(rstd.dim() == 1, "rstd must be 1D [T]"); + TORCH_CHECK(rstd.size(0) == x.size(0), "rstd.size(0) must equal x.size(0)"); + + auto dx = torch::empty_like(x); + + rmsnorm_backward_dx_cuda(dy, x, weight, rstd, dx); + + return dx; +} + +torch::Tensor rmsnorm_backward_dw( + torch::Tensor dy, + torch::Tensor x, + torch::Tensor rstd, + torch::Tensor mask) +{ + rmsnorm_check_input(dy, "dy"); + rmsnorm_check_input(x, "x"); + rmsnorm_check_input(rstd, "rstd"); + rmsnorm_check_input(mask, "mask"); + + TORCH_CHECK(dy.sizes() == x.sizes(), "dy and x must have same shape"); + TORCH_CHECK(x.dim() == 2, "x must be 2D [T, H]"); + TORCH_CHECK(rstd.dim() == 1, "rstd must be 1D [T]"); + TORCH_CHECK(mask.dim() == 1, "mask must be 1D [T]"); + TORCH_CHECK(mask.scalar_type() == torch::kBool, "mask must be bool"); + TORCH_CHECK(rstd.size(0) == x.size(0), "rstd.size(0) must equal x.size(0)"); + TORCH_CHECK(mask.size(0) == x.size(0), "mask.size(0) must equal x.size(0)"); + + auto T = x.size(0); + auto H = x.size(1); + + auto partial_dw = torch::empty({T, H}, x.options().dtype(torch::kFloat32)); + auto dw = torch::empty({H}, x.options().dtype(torch::kFloat32)); + + rmsnorm_backward_partial_dw_cuda(dy, x, rstd, mask, partial_dw); + rmsnorm_backward_reduce_dw_cuda(partial_dw, dw); + + return dw; +} + // Prefix-Shared Attention Declarations & Wrappers 
void prefix_shared_attention_forward( @@ -95,5 +203,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // registry Prefix-Shared Attention m.def("prefix_shared_attention", &prefix_shared_attention, "Prefix-Shared Fused Attention for GRPO"); + + // registry RMSNorm + m.def("rmsnorm_forward", &rmsnorm_forward, "Batch-invariant RMSNorm forward CUDA"); + m.def("rmsnorm_backward_dx", &rmsnorm_backward_dx, "Batch-invariant RMSNorm backward dx CUDA"); + m.def("rmsnorm_backward_dw", &rmsnorm_backward_dw, "Deterministic RMSNorm backward dweight CUDA"); #endif } diff --git a/rl_engine/kernels/ops/base.py b/rl_engine/kernels/ops/base.py index 7643e96..ead5c30 100644 --- a/rl_engine/kernels/ops/base.py +++ b/rl_engine/kernels/ops/base.py @@ -3,6 +3,8 @@ from typing import Any +import torch # noqa: F401 # Load torch shared libraries before importing the binary extension. + from rl_engine.utils.logger import logger _C: Any = None diff --git a/rl_engine/kernels/ops/cuda/norm/rmsnorm.py b/rl_engine/kernels/ops/cuda/norm/rmsnorm.py new file mode 100644 index 0000000..980deda --- /dev/null +++ b/rl_engine/kernels/ops/cuda/norm/rmsnorm.py @@ -0,0 +1,76 @@ +import torch + +from rl_engine.kernels.ops.base import _C, _EXT_AVAILABLE + + +class RMSNormCuda(torch.autograd.Function): + """ + PyTorch autograd wrapper for CUDA RMSNorm. 
+ """ + + @staticmethod + def forward(ctx, x, weight, mask=None, eps=1e-6): + """ + Forward: + y = x * rsqrt(mean(x^2) + eps) * weight + + Input: + x: [T, H], fp16/bf16/fp32 CUDA tensor + weight: [H], fp16/bf16/fp32 CUDA tensor + mask: [T], bool CUDA tensor + eps: float + + Output: + y: [T, H] + """ + assert x.is_cuda, "x must be CUDA tensor" + assert weight.is_cuda, "weight must be CUDA tensor" + assert x.is_contiguous(), "x must be contiguous" + assert weight.is_contiguous(), "weight must be contiguous" + assert x.dim() == 2, "x must be [T, H]" + assert weight.dim() == 1, "weight must be [H]" + assert x.shape[1] == weight.shape[0], "hidden size mismatch" + assert _EXT_AVAILABLE and hasattr( + _C, "rmsnorm_forward" + ), "RMSNorm CUDA extension is unavailable. Please rebuild with rmsnorm.cu." + + if mask is None: + mask = torch.ones((x.shape[0],), device=x.device, dtype=torch.bool) + else: + assert mask.is_cuda, "mask must be CUDA tensor" + assert mask.is_contiguous(), "mask must be contiguous" + assert mask.dtype == torch.bool, "mask must be bool" + assert mask.dim() == 1, "mask must be [T]" + assert mask.shape[0] == x.shape[0], "mask length mismatch" + + y, rstd = _C.rmsnorm_forward(x, weight, float(eps)) + + ctx.save_for_backward(x, weight, rstd, mask) + ctx.eps = eps + + return y + + @staticmethod + def backward(ctx, grad_out): + """ + Backward: + dx = CUDA row-wise deterministic kernel + dw = CUDA two-pass deterministic kernel + """ + x, weight, rstd, mask = ctx.saved_tensors + dy = grad_out.contiguous() + + dx = _C.rmsnorm_backward_dx(dy, x, weight, rstd) + + dw = _C.rmsnorm_backward_dw(dy, x, rstd, mask) + + return dx, dw, None, None + + +def rmsnorm_cuda(x, weight, eps=1e-6, mask=None): + """ + use: + y = rmsnorm_cuda(x, weight) + y = rmsnorm_cuda(x, weight, mask=mask) + """ + return RMSNormCuda.apply(x, weight, mask, eps) diff --git a/rl_engine/kernels/ops/pytorch/norm/rmsnorm_native.py b/rl_engine/kernels/ops/pytorch/norm/rmsnorm_native.py new file mode 
100644
index 0000000..c818c75
--- /dev/null
+++ b/rl_engine/kernels/ops/pytorch/norm/rmsnorm_native.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+import torch
+
+
+class NativeRMSNormOp:
+    """
+    Pure PyTorch RMSNorm reference that can run on CPU.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(
+        self,
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        *,
+        eps: float = 1e-6,
+    ) -> torch.Tensor:
+        return self.forward(x, weight, eps=eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        *,
+        eps: float = 1e-6,
+    ) -> torch.Tensor:
+        """
+        Canonical dtype path: accumulate in fp32, then cast output back to x.dtype.
+        """
+        return self._rms_norm(x, weight, eps=eps, output_dtype=x.dtype)
+
+    def forward_fp32(
+        self,
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        *,
+        eps: float = 1e-6,
+    ) -> torch.Tensor:
+        """Ground truth path: accumulate in fp32 and keep fp32 output."""
+        return self._rms_norm(x, weight, eps=eps, output_dtype=torch.float32)
+
+    @staticmethod
+    def _rms_norm(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        *,
+        eps: float,
+        output_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        if weight.dim() != 1 or weight.shape[0] != x.shape[-1]:
+            raise ValueError(
+                f"weight must be 1-D of size x.shape[-1]={x.shape[-1]}, "
+                f"got tuple(weight.shape)={tuple(weight.shape)}"
+            )
+
+        x_f = x.float()
+        var = x_f.pow(2).mean(dim=-1, keepdim=True)
+        normed = x_f * torch.rsqrt(var + eps)
+        out = normed * weight.float()
+        return out.to(output_dtype)
diff --git a/rl_engine/kernels/ops/pytorch/norm/rmsnorm_ref.py b/rl_engine/kernels/ops/pytorch/norm/rmsnorm_ref.py
new file mode 100644
index 0000000..2d3a298
--- /dev/null
+++ b/rl_engine/kernels/ops/pytorch/norm/rmsnorm_ref.py
@@ -0,0 +1,57 @@
+import torch
+
+
+def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    """PyTorch RMSNorm reference.
+
+    x: [T, H], fp16/bf16/fp32 CUDA tensor
+    weight: [H], fp16/bf16/fp32 CUDA tensor
+    return: [T, H]
+
+    Accumulation is done in FP32 for numerical stability.
+    """
+    x_fp32 = x.float()
+    w_fp32 = weight.float()
+    var = x_fp32.pow(2).mean(dim=-1, keepdim=True)
+    rstd = torch.rsqrt(var + eps)
+    y = x_fp32 * rstd * w_fp32
+    return y.to(x.dtype)
+
+
+class RMSNormRef(torch.autograd.Function):
+    """Eager RMSNorm with fp32 math and a hand-written backward (the PyTorch baseline)."""
+
+    @staticmethod
+    def forward(ctx, x, weight, eps: float = 1e-6):
+        x_fp32 = x.float()
+        w_fp32 = weight.float()
+
+        # rstd[t] = 1 / sqrt(mean_h(x[t, h]^2) + eps)
+        var = x_fp32.pow(2).mean(dim=-1, keepdim=True)
+        rstd = torch.rsqrt(var + eps)
+        y = x_fp32 * rstd * w_fp32
+
+        ctx.save_for_backward(x, weight, rstd)
+        ctx.eps = eps
+        return y.to(x.dtype)
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        x, weight, rstd = ctx.saved_tensors
+        x_fp32 = x.float()
+        w_fp32 = weight.float()
+        go_fp32 = grad_out.float()
+        rstd_fp32 = rstd.float()
+        H = x.shape[-1]
+
+        gw = go_fp32 * w_fp32
+        dot = (gw * x_fp32).sum(dim=-1, keepdim=True)
+        dx = rstd_fp32 * gw - x_fp32 * (rstd_fp32**3) * dot / H
+
+        # dw_i = sum over every leading (token) dim of g * x * r; flatten so N-D x works too.
+        dw = (go_fp32 * x_fp32 * rstd_fp32).reshape(-1, H).sum(dim=0)
+        return dx.to(x.dtype), dw.to(weight.dtype), None
+
+
+def rmsnorm_ref_custom(x, weight, eps: float = 1e-6):
+    return RMSNormRef.apply(x, weight, eps)
diff --git a/rl_engine/kernels/ops/triton/rmsnorm_triton.py b/rl_engine/kernels/ops/triton/rmsnorm_triton.py
new file mode 100644
index 0000000..382fb56
--- /dev/null
+++ b/rl_engine/kernels/ops/triton/rmsnorm_triton.py
@@ -0,0 +1,109 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _rmsnorm_fwd_kernel(
+    X, W, Y, RSTD, T, H: tl.constexpr, EPS: tl.constexpr, BLOCK_H: tl.constexpr
+):
+    # One program per row: fp32 sum of squares -> rstd -> y = x * rstd * w.
+    row = tl.program_id(0).to(tl.int64)  # 64-bit so row * H cannot wrap for large T * H
+    offs = tl.arange(0, BLOCK_H)
+    mask = offs < H
+
+    x = tl.load(X + row * H + offs, mask=mask, other=0.0).to(tl.float32)
+    w = tl.load(W + offs, mask=mask, other=0.0).to(tl.float32)
+
+    ss = tl.sum(x * x, axis=0)
+    rstd = tl.rsqrt(ss / H + EPS)
+    y = x * rstd * w
+
+    tl.store(Y + row * H + offs, y, mask=mask)
+    tl.store(RSTD + row, rstd)
+
+
+@triton.jit
+def _rmsnorm_bwd_dx_kernel(
+    DY, X, W, RSTD, DX, PARTIAL_DW, T, H: tl.constexpr, BLOCK_H: tl.constexpr
+):
+    # One program per row: writes dx and that row's fp32 contribution to dw.
+    row = tl.program_id(0).to(tl.int64)  # 64-bit so row * H cannot wrap for large T * H
+    offs = tl.arange(0, BLOCK_H)
+    mask = offs < H
+
+    dy = tl.load(DY + row * H + offs, mask=mask, other=0.0).to(tl.float32)
+    x = tl.load(X + row * H + offs, mask=mask, other=0.0).to(tl.float32)
+    w = tl.load(W + offs, mask=mask, other=0.0).to(tl.float32)
+    rstd = tl.load(RSTD + row).to(tl.float32)
+
+    gw = dy * w
+    dot = tl.sum(gw * x, axis=0)
+    dx = rstd * gw - x * rstd * rstd * rstd * dot / H
+
+    pdw = dy * x * rstd
+
+    tl.store(DX + row * H + offs, dx, mask=mask)
+    tl.store(PARTIAL_DW + row * H + offs, pdw, mask=mask)
+
+
+@triton.jit
+def _rmsnorm_bwd_dw_kernel(PARTIAL_DW, DW, T: tl.constexpr, H: tl.constexpr, BLOCK_T: tl.constexpr):
+    # One program per column: reduce that column of PARTIAL_DW over all T rows.
+    col = tl.program_id(0)
+    offs_t = tl.arange(0, BLOCK_T)
+    mask = offs_t < T
+
+    rows = offs_t.to(tl.int64)  # 64-bit so rows * H cannot wrap for large T * H
+    vals = tl.load(PARTIAL_DW + rows * H + col, mask=mask, other=0.0).to(tl.float32)
+    acc = tl.sum(vals, axis=0)
+    tl.store(DW + col, acc)
+
+
+class RMSNormTriton(torch.autograd.Function):
+    """Autograd wrapper around the Triton RMSNorm kernels for 2-D CUDA inputs."""
+
+    @staticmethod
+    def forward(ctx, x, weight, eps: float = 1e-6):
+        assert x.is_cuda and weight.is_cuda
+        assert x.dim() == 2 and weight.dim() == 1
+        T, H = x.shape
+        assert weight.numel() == H
+        # The kernels address element (row, col) as ``row * H + col``: require dense rows.
+        x = x.contiguous()
+        weight = weight.contiguous()
+
+        y = torch.empty_like(x)
+        rstd = torch.empty((T,), device=x.device, dtype=torch.float32)
+
+        block_h = triton.next_power_of_2(H)
+        assert block_h <= 131072, "H too large for this simple Triton kernel"
+
+        _rmsnorm_fwd_kernel[(T,)](x, weight, y, rstd, T, H, eps, BLOCK_H=block_h)
+        ctx.save_for_backward(x, weight, rstd)
+        ctx.H = H
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        x, weight, rstd = ctx.saved_tensors
+        # Upstream gradients can be strided or expanded (e.g. from ``.sum()``).
+        grad_out = grad_out.contiguous()
+        T, H = x.shape
+        dx = torch.empty_like(x)
+        partial_dw = torch.empty((T, H), device=x.device, dtype=torch.float32)
+        dw = torch.empty((H,), 
device=x.device, dtype=torch.float32)
+
+        block_h = triton.next_power_of_2(H)
+        block_t = triton.next_power_of_2(T)
+        assert block_t <= 131072, "T too large for this simple single-program dw reduction"
+
+        _rmsnorm_bwd_dx_kernel[(T,)](
+            grad_out, x, weight, rstd, dx, partial_dw, T, H, BLOCK_H=block_h
+        )
+        _rmsnorm_bwd_dw_kernel[(H,)](partial_dw, dw, T, H, BLOCK_T=block_t)
+        return dx, dw.to(weight.dtype), None
+
+
+def rmsnorm_triton(x, weight, eps: float = 1e-6):
+    return RMSNormTriton.apply(x, weight, eps)
diff --git a/rl_engine/tests/test_rmsnorm.py b/rl_engine/tests/test_rmsnorm.py
new file mode 100644
index 0000000..b934e01
--- /dev/null
+++ b/rl_engine/tests/test_rmsnorm.py
@@ -0,0 +1,509 @@
+from __future__ import annotations
+
+import argparse
+import math
+import sys
+from pathlib import Path
+
+import torch
+
+ROOT = Path(__file__).resolve().parents[2]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from rl_engine.kernels.ops.pytorch.norm.rmsnorm_native import NativeRMSNormOp  # noqa: E402
+from rl_engine.kernels.ops.pytorch.norm.rmsnorm_ref import rmsnorm_ref_custom  # noqa: E402
+from rl_engine.kernels.ops.triton.rmsnorm_triton import rmsnorm_triton  # noqa: E402
+
+try:
+    from rl_engine.kernels.ops.cuda.norm.rmsnorm import rmsnorm_cuda
+
+    HAS_CUDA_EXT = True
+except Exception as exc:
+    CUDA_IMPORT_ERROR = exc
+    HAS_CUDA_EXT = False
+
+
+def parse_dtype(name: str) -> torch.dtype:
+    """Translate "fp32"/"fp16"/"bf16" into the torch dtype; raise ValueError for other names."""
+    if name == "fp32":
+        return torch.float32
+    if name == "fp16":
+        return torch.float16
+    if name == "bf16":
+        return torch.bfloat16
+    raise ValueError(f"Unsupported dtype: {name}")
+
+
+def get_impl(name: str):
+    """Return the RMSNorm callable named by ``name``; "cuda" raises if its import failed."""
+    if name == "pytorch":
+        return rmsnorm_ref_custom
+    if name == "triton":
+        return rmsnorm_triton
+    if name == "cuda":
+        if not HAS_CUDA_EXT:
+            raise RuntimeError(f"CUDA extension import failed: {CUDA_IMPORT_ERROR}")
+        return rmsnorm_cuda
+    raise ValueError(f"Unknown impl: {name}")
+
+
+def max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
+    """Largest elementwise absolute difference, computed in fp32."""
+    return (a.float() - b.float()).abs().max().item()
+
+
+def max_rel_diff(a: torch.Tensor, b: torch.Tensor) -> float:
+    """Largest elementwise |a - b| / max(|b|, 1e-6), computed in fp32."""
+    a_f = a.float()
+    b_f = b.float()
+    denom = b_f.abs().clamp_min(1e-6)
+    return ((a_f - b_f).abs() / denom).max().item()
+
+
+def tolerances(dtype: torch.dtype) -> tuple[float, float]:
+    """Return the (atol, rtol) pair used for ``dtype``; raise ValueError for other dtypes."""
+    if dtype == torch.float32:
+        return 2e-5, 2e-5
+    if dtype == torch.float16:
+        return 3e-3, 3e-3
+    if dtype == torch.bfloat16:
+        return 2e-2, 2e-2
+    raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+def assert_close_with_report(
+    name: str,
+    actual: torch.Tensor,
+    expected: torch.Tensor,
+    *,
+    atol: float,
+    rtol: float,
+) -> None:
+    """Print the abs/rel error of ``actual`` vs ``expected``; raise AssertionError if not close."""
+    abs_diff = max_abs_diff(actual, expected)
+    rel_diff = max_rel_diff(actual, expected)
+    print(
+        f" {name:<8} "
+        f"abs_diff={abs_diff:.6e}, "
+        f"rel_diff={rel_diff:.6e}, "
+        f"atol={atol}, rtol={rtol}"
+    )
+    if not torch.allclose(actual.float(), expected.float(), atol=atol, rtol=rtol):
+        raise AssertionError(
+            f"{name} mismatch: abs_diff={abs_diff}, rel_diff={rel_diff}, "
+            f"atol={atol}, rtol={rtol}"
+        )
+
+
+def assert_exact(name: str, actual: torch.Tensor, expected: torch.Tensor) -> None:
+    """Raise AssertionError unless the tensors match elementwise with zero tolerance."""
+    if not torch.allclose(actual, expected, atol=0.0, rtol=0.0):
+        diff = max_abs_diff(actual, expected)
+        raise AssertionError(f"{name} mismatch, max diff = {diff}")
+
+
+def native_rmsnorm_forward(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Eager fp32 RMSNorm; return (y cast to x.dtype, rstd with its last dim squeezed)."""
+    x_f = x.float()
+    w_f = weight.float()
+    var = x_f.pow(2).mean(dim=-1, keepdim=True)
+    rstd = torch.rsqrt(var + eps)
+    y = x_f * rstd * w_f
+    return y.to(x.dtype), rstd.squeeze(-1)
+
+
+def native_dw(
+    x: torch.Tensor,
+    dy: torch.Tensor,
+    rstd: torch.Tensor,
+    mask: torch.Tensor,
+) -> torch.Tensor:
+    """Reference weight gradient: sum over rows of dy * x * rstd, zeroing rows where mask is False."""
+    contrib = dy.float() * x.float() * rstd.float().unsqueeze(-1)
+    return contrib.masked_fill(~mask[:, None], 0.0).sum(dim=0)
+
+
+def run_forward_backward(fn, x: torch.Tensor, w: torch.Tensor, dy: torch.Tensor):
+    """Run ``fn`` and its backward on fresh contiguous clones; return detached (y, dx, dw)."""
+    x_req = x.detach().clone().contiguous().requires_grad_(True)
+    w_req = w.detach().clone().contiguous().requires_grad_(True)
+    dy_req = dy.detach().clone().contiguous()
+    y = fn(x_req, w_req)
+    y.backward(dy_req)
+    return y.detach(), x_req.grad.detach(), w_req.grad.detach()
+
+
+def run_cuda_dw(
+    x: torch.Tensor,
+    dy: torch.Tensor,
+    weight: torch.Tensor,
+    mask: torch.Tensor,
+    eps: float,
+) -> torch.Tensor:
+    """Run the CUDA op with ``mask`` on fresh clones and return the resulting weight gradient."""
+    x_req = x.detach().clone().contiguous().requires_grad_(True)
+    w_req = weight.detach().clone().contiguous().requires_grad_(True)
+    y = rmsnorm_cuda(x_req, w_req, eps=eps, mask=mask)
+    y.backward(dy.detach().clone().contiguous())
+    return w_req.grad.detach()
+
+
+def build_padded_layout(
+    x_real: torch.Tensor,
+    dy_real: torch.Tensor,
+    total_rows: int,
+    real_positions: list[int],
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Scatter the real rows into random [total_rows, H] buffers; return (x_pad, dy_pad, mask)."""
+    device = x_real.device
+    dtype = x_real.dtype
+    t_real, hidden = x_real.shape
+
+    assert len(real_positions) == t_real
+    assert total_rows >= t_real
+
+    x_pad = torch.randn((total_rows, hidden), device=device, dtype=torch.float32).to(dtype)
+    dy_pad = torch.randn((total_rows, hidden), device=device, dtype=torch.float32).to(dtype)
+    mask = torch.zeros((total_rows,), device=device, dtype=torch.bool)
+
+    for src_t, dst_t in enumerate(real_positions):
+        x_pad[dst_t] = x_real[src_t]
+        dy_pad[dst_t] = dy_real[src_t]
+        mask[dst_t] = True
+
+    return x_pad, dy_pad, mask
+
+
+def test_correctness_case(
+    *,
+    impl_name: str,
+    dtype: torch.dtype,
+    total_rows: int,
+    hidden: int,
+    eps: float,
+) -> None:
+    """Compare y/dx/dw of one implementation against the fp32 CPU reference within tolerance."""
+    torch.manual_seed(0)
+    native = NativeRMSNormOp()
+
+    x_cpu = torch.randn(total_rows, hidden, device="cpu", dtype=torch.float32)
+    w_cpu = torch.randn(hidden, device="cpu", dtype=torch.float32)
+    dy_cpu = torch.randn(total_rows, hidden, device="cpu", dtype=torch.float32)
+
+    x_ref = x_cpu.to(dtype).float().detach().requires_grad_(True)
+    w_ref = w_cpu.to(dtype).float().detach().requires_grad_(True)
+    dy_ref = dy_cpu.to(dtype).float()
+    y_ref = native.forward_fp32(x_ref, w_ref, eps=eps)
+    y_ref.backward(dy_ref)
+
+    x_gpu = x_cpu.to(device="cuda", dtype=dtype).detach().requires_grad_(True)
+    w_gpu = w_cpu.to(device="cuda", dtype=dtype).detach().requires_grad_(True)
+    dy_gpu = dy_cpu.to(device="cuda", dtype=dtype)
+
+    fn = get_impl(impl_name)
+    y_gpu = fn(x_gpu, w_gpu, eps=eps)
+    y_gpu.backward(dy_gpu)
+
+    got_y = y_gpu.detach().cpu()
+    got_dx = x_gpu.grad.detach().cpu()
+    got_dw = w_gpu.grad.detach().cpu()
+
+    atol, rtol = tolerances(dtype)
+    dw_scale = max(1.0, math.sqrt(total_rows) / 4.0)
+
+    print(f"[case] correctness impl={impl_name}, T={total_rows}, H={hidden}, dtype={dtype}")
+    assert_close_with_report("y", got_y, y_ref.detach(), atol=atol, rtol=rtol)
+    assert_close_with_report("dx", got_dx, x_ref.grad.detach(), atol=atol, rtol=rtol)
+    assert_close_with_report(
+        "dw",
+        got_dw,
+        w_ref.grad.detach(),
+        atol=atol * dw_scale,
+        rtol=rtol * dw_scale,
+    )
+    print(" passed\n")
+
+
+def test_deterministic_case(
+    *,
+    impl_name: str,
+    dtype: torch.dtype,
+    total_rows: int,
+    hidden: int,
+    repeat: int,
+) -> None:
+    """Check that repeated runs on identical inputs reproduce y, dx and dw exactly."""
+    torch.manual_seed(0)
+    fn = get_impl(impl_name)
+    x = torch.randn(total_rows, hidden, device="cuda", dtype=dtype)
+    w = torch.randn(hidden, device="cuda", dtype=dtype)
+    dy = torch.randn(total_rows, hidden, device="cuda", dtype=dtype)
+
+    y0, dx0, dw0 = run_forward_backward(fn, x, w, dy)
+    torch.cuda.synchronize()
+
+    for i in range(repeat):
+        y, dx, dw = run_forward_backward(fn, x, w, dy)
+        torch.cuda.synchronize()
+        assert_exact(f"{impl_name} y repeat={i}", y0, y)
+        assert_exact(f"{impl_name} dx repeat={i}", dx0, dx)
+        assert_exact(f"{impl_name} dw repeat={i}", dw0, dw)
+
+    print(
+        f"[PASS] deterministic impl={impl_name}, T={total_rows}, "
+        f"H={hidden}, dtype={dtype}, repeat={repeat}"
+    )
+
+
+def test_forward_dx_batch_position_case(
+    *,
+    impl_name: str,
+    dtype: torch.dtype,
+    hidden: int,
+) -> None:
+    """Check one row's y/dx are identical whether run alone or at various positions in a batch."""
+    torch.manual_seed(1)
+    fn = get_impl(impl_name)
+
+    target_x = torch.randn(1, hidden, device="cuda", dtype=dtype)
+    target_dy = torch.randn(1, hidden, device="cuda", dtype=dtype)
+    weight = torch.randn(hidden, device="cuda", dtype=dtype)
+
+    y_single, dx_single, _ = run_forward_backward(fn, target_x, weight, target_dy)
+
+    placements = [(16, 0), (16, 7), (64, 63)]
+    for total_rows, row_id in placements:
+        x = torch.randn(total_rows, hidden, device="cuda", dtype=dtype)
+        dy = torch.randn(total_rows, hidden, device="cuda", dtype=dtype)
+        x[row_id : row_id + 1] = target_x
+        dy[row_id : row_id + 1] = target_dy
+        y, dx, _ = run_forward_backward(fn, x, weight, dy)
+        assert_exact(f"{impl_name} y row={row_id}", y_single[0], y[row_id])
+        assert_exact(f"{impl_name} dx row={row_id}", dx_single[0], dx[row_id])
+
+    print(f"[PASS] batch-position invariant impl={impl_name}, H={hidden}, dtype={dtype}")
+
+
+def test_forward_dx_padding_layout_case(
+    *,
+    impl_name: str,
+    dtype: torch.dtype,
+    hidden: int,
+) -> None:
+    """Check y/dx of the valid rows do not depend on where they sit among filler rows."""
+    torch.manual_seed(2)
+    fn = get_impl(impl_name)
+
+    valid_rows = 4
+    total_rows = 16
+    positions = [1, 5, 9, 14]
+
+    valid_x = torch.randn(valid_rows, hidden, device="cuda", dtype=dtype)
+    valid_dy = torch.randn(valid_rows, hidden, device="cuda", dtype=dtype)
+    weight = torch.randn(hidden, device="cuda", dtype=dtype)
+
+    x_a = torch.randn(total_rows, hidden, device="cuda", dtype=dtype)
+    dy_a = torch.randn(total_rows, hidden, device="cuda", dtype=dtype)
+    x_a[0:valid_rows] = valid_x
+    dy_a[0:valid_rows] = valid_dy
+    y_a, dx_a, _ = run_forward_backward(fn, x_a, weight, dy_a)
+
+    x_b = torch.randn(total_rows, hidden, device="cuda", dtype=dtype)
+    dy_b = torch.randn(total_rows, hidden, device="cuda", dtype=dtype)
+    for i, pos in enumerate(positions):
+        x_b[pos] = valid_x[i]
+        dy_b[pos] = valid_dy[i]
+    y_b, dx_b, _ = run_forward_backward(fn, x_b, weight, dy_b)
+
+    for i, pos in enumerate(positions):
+        assert_exact(f"{impl_name} padding y row={i}", y_a[i], y_b[pos])
+        assert_exact(f"{impl_name} padding dx row={i}", dx_a[i], dx_b[pos])
+
+    print(f"[PASS] padding-layout invariant impl={impl_name}, H={hidden}, dtype={dtype}")
+
+
+def test_dw_padding_layout_case(
+    *,
+    dtype: torch.dtype,
+    total_rows: int,
+    hidden: int,
+    eps: float,
+    strict_bitwise: bool,
+) -> None:
+    """Check CUDA dw is unchanged when masked-out filler rows are interleaved with real rows."""
+    if not HAS_CUDA_EXT:
+        raise RuntimeError(f"CUDA extension import failed: {CUDA_IMPORT_ERROR}")
+
+    torch.manual_seed(0)
+    x_real = torch.randn((total_rows, hidden), device="cuda", dtype=torch.float32).to(dtype)
+    dy_real = torch.randn((total_rows, hidden), device="cuda", dtype=torch.float32).to(dtype)
+    weight = torch.randn((hidden,), device="cuda", dtype=torch.float32).to(dtype)
+
+    x1 = x_real.clone()
+    dy1 = dy_real.clone()
+    mask1 = torch.ones((total_rows,), device="cuda", dtype=torch.bool)
+
+    x2, dy2, mask2 = build_padded_layout(
+        x_real=x_real,
+        dy_real=dy_real,
+        total_rows=2 * total_rows,
+        real_positions=[2 * i + 1 for i in range(total_rows)],
+    )
+
+    _, rstd1 = native_rmsnorm_forward(x1, weight, eps)
+    _, rstd2 = native_rmsnorm_forward(x2, weight, eps)
+    dw1 = run_cuda_dw(x1, dy1, weight, mask1, eps)
+    dw2 = run_cuda_dw(x2, dy2, weight, mask2, eps)
+    ref_dw1 = native_dw(x1, dy1, rstd1, mask1)
+    ref_dw2 = native_dw(x2, dy2, rstd2, mask2)
+
+    atol, rtol = tolerances(dtype)
+    print(f"[case] dw padding invariant dtype={dtype}, T={total_rows}, H={hidden}")
+    print(" max |cuda dw1 - cuda dw2|:", max_abs_diff(dw1, dw2))
+    print(" max |ref dw1 - ref dw2|:", max_abs_diff(ref_dw1, ref_dw2))
+    print(" max |cuda dw1 - ref dw1|:", max_abs_diff(dw1, ref_dw1))
+    print(" max |cuda dw2 - ref dw2|:", max_abs_diff(dw2, ref_dw2))
+    assert_close_with_report("dw12", dw1, dw2, atol=atol, rtol=rtol)
+    assert_close_with_report("dw1ref", dw1, ref_dw1, atol=atol, rtol=rtol)
+    assert_close_with_report("dw2ref", dw2, ref_dw2, atol=atol, rtol=rtol)
+    if strict_bitwise:
+        assert_exact("cuda dw1 - cuda dw2", dw1.float(), dw2.float())
+
+    x3, dy3, mask3 = build_padded_layout(
+        x_real=x_real,
+        dy_real=dy_real,
+        total_rows=2 * total_rows + 1,
+        real_positions=[2 * i for i in range(total_rows)],
+    )
+    _, rstd3 = native_rmsnorm_forward(x3, weight, eps)
+    dw3 = run_cuda_dw(x3, dy3, weight, mask3, eps)
+    ref_dw3 = native_dw(x3, dy3, rstd3, mask3)
+
+    print(" max |cuda dw2 - cuda dw3|:", max_abs_diff(dw2, dw3))
+    print(" max |cuda dw3 - ref dw3|:", max_abs_diff(dw3, ref_dw3))
+    assert_close_with_report("dw23", dw2, dw3, atol=atol, rtol=rtol)
+    assert_close_with_report("dw3ref", dw3, ref_dw3, atol=atol, rtol=rtol)
+    if strict_bitwise:
+        assert_exact("cuda dw2 - cuda dw3", dw2.float(), dw3.float())
+
+    print("[PASS] backward dw is invariant under masked random-padding layout.")
+
+
+def impls_from_arg(impl: str, *, include_pytorch: bool) -> list[str]:
+    """Expand ``impl`` into a list of names; "all" means triton and cuda, plus pytorch if asked."""
+    if impl != "all":
+        return [impl]
+    impls = ["triton", "cuda"]
+    if include_pytorch:
+        impls.insert(0, "pytorch")
+    return impls
+
+
+def kernel_impls_from_arg(impl: str) -> list[str]:
+    """Like impls_from_arg without the PyTorch baseline; rejects ``impl == "pytorch"``."""
+    if impl == "pytorch":
+        raise ValueError(
+            "correctness suite compares kernels against PyTorch; " "use cuda, triton, or all"
+        )
+    return impls_from_arg(impl, include_pytorch=False)
+
+
+def dtypes_from_arg(dtype: str) -> list[torch.dtype]:
+    """Expand the dtype option into a list of torch dtypes ("all" -> fp32, fp16, bf16)."""
+    if dtype == "all":
+        return [torch.float32, torch.float16, torch.bfloat16]
+    return [parse_dtype(dtype)]
+
+
+def hidden_sizes_from_args(hidden: int, sweep_hidden: bool) -> list[int]:
+    """Return [hidden], or a de-duplicated sweep around 64/128/256 followed by ``hidden``."""
+    if not sweep_hidden:
+        return [hidden]
+    sizes = [63, 64, 65, 127, 128, 129, 255, 256, 257, hidden]
+    return list(dict.fromkeys(sizes))
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--suite",
+        choices=["all", "correctness", "deterministic", "batch", "dw"],
+        default="all",
+    )
+    parser.add_argument("--impl", choices=["pytorch", "triton", "cuda", "all"], default="cuda")
+    parser.add_argument("--dtype", choices=["fp32", "fp16", "bf16", "all"], default="bf16")
+    parser.add_argument("--T", type=int, default=128)
+    parser.add_argument("--H", type=int, default=4096)
+    parser.add_argument("--eps", 
type=float, default=1e-6) + parser.add_argument("--repeat", type=int, default=50) + parser.add_argument("--sweep-hidden", action="store_true") + parser.add_argument("--strict-bitwise", action="store_true") + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + + run_all = args.suite == "all" + + if run_all or args.suite == "correctness": + shapes = [ + (1, 128), + (2, 256), + (8, 768), + (16, 1024), + (32, 2048), + (64, 4096), + (128, 4096), + ] + for impl in kernel_impls_from_arg(args.impl): + for dtype in dtypes_from_arg(args.dtype): + for total_rows, hidden in shapes: + test_correctness_case( + impl_name=impl, + dtype=dtype, + total_rows=total_rows, + hidden=hidden, + eps=args.eps, + ) + print("All native-vs-kernel tests passed.") + + if run_all or args.suite == "deterministic": + for impl in impls_from_arg(args.impl, include_pytorch=True): + for dtype in dtypes_from_arg(args.dtype): + test_deterministic_case( + impl_name=impl, + dtype=dtype, + total_rows=args.T, + hidden=args.H, + repeat=args.repeat, + ) + print("all deterministic tests passed") + + if run_all or args.suite == "batch": + for impl in impls_from_arg(args.impl, include_pytorch=True): + for dtype in dtypes_from_arg(args.dtype): + for hidden in hidden_sizes_from_args(args.H, args.sweep_hidden): + test_forward_dx_batch_position_case(impl_name=impl, dtype=dtype, hidden=hidden) + test_forward_dx_padding_layout_case(impl_name=impl, dtype=dtype, hidden=hidden) + print("all batch-invariant tests passed") + + if run_all or args.suite == "dw": + for dtype in dtypes_from_arg(args.dtype): + test_dw_padding_layout_case( + dtype=dtype, + total_rows=args.T, + hidden=args.H, + eps=args.eps, + strict_bitwise=args.strict_bitwise, + ) + + print("all RMSNorm tests passed") + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index d5ddb89..75ac4f0 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ def get_extensions(): "csrc/ops.cpp", 
"csrc/fused_logp_kernel.cu", "csrc/cuda/attention/prefix_shared_attention.cu", + "csrc/cuda/rmsnorm.cu", ] cc_major, cc_minor = torch.cuda.get_device_capability()