Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions benchmarks/benchmark_rmsnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import argparse
import time

import torch

from rl_engine.kernels.ops.pytorch.norm.rmsnorm_ref import rmsnorm_ref_custom
from rl_engine.kernels.ops.triton.rmsnorm_triton import rmsnorm_triton

try:
from rl_engine.kernels.ops.cuda.norm.rmsnorm import rmsnorm_cuda

HAS_CUDA_EXT = True
except Exception:
HAS_CUDA_EXT = False
Comment on lines +9 to +14

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '\n== benchmarks/benchmark_rmsnorm.py ==\n'
wc -l benchmarks/benchmark_rmsnorm.py
cat -n benchmarks/benchmark_rmsnorm.py

printf '\n== referenced rmsnorm module(s) ==\n'
fd -a 'rmsnorm.py' .
fd -a 'rmsnorm' rl_engine benchmarks | sed -n '1,120p'

printf '\n== search for HAS_CUDA_EXT / device guards ==\n'
rg -n "HAS_CUDA_EXT|device *= *\"cuda\"|cuda.is_available|rmsnorm_cuda" benchmarks rl_engine

Repository: RL-Align/RL-Kernel

Length of output: 10618


🏁 Script executed:

#!/bin/bash
set -euo pipefail

wc -l rl_engine/kernels/ops/cuda/norm/rmsnorm.py
cat -n rl_engine/kernels/ops/cuda/norm/rmsnorm.py

Repository: RL-Align/RL-Kernel

Length of output: 3104


🏁 Script executed:

#!/bin/bash
set -euo pipefail

wc -l rl_engine/kernels/ops/base.py
cat -n rl_engine/kernels/ops/base.py

Repository: RL-Align/RL-Kernel

Length of output: 810


Narrow the import guard
Catching Exception here hides real import-time failures in rl_engine.kernels.ops.cuda.norm.rmsnorm and makes them look like “extension not built.” Use ImportError so broken load/ABI issues surface instead of being skipped.

🧰 Tools
🪛 Ruff (0.15.18)

[warning] 13-13: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/benchmark_rmsnorm.py` around lines 9 - 14, The import guard around
rmsnorm_cuda is too broad and masks real module load failures as a missing
extension. Update the try/except in benchmark_rmsnorm.py to catch ImportError
only, keeping HAS_CUDA_EXT false only when the CUDA extension is actually
unavailable while allowing other import-time issues from
rl_engine.kernels.ops.cuda.norm.rmsnorm to surface.

Source: Linters/SAST tools



def bench(fn, x, w, dy, warmup=20, iters=100):
for _ in range(warmup):
x.grad = None
w.grad = None
y = fn(x, w)
y.backward(dy)
torch.cuda.synchronize()

start = time.time()
for _ in range(iters):
x.grad = None
w.grad = None
y = fn(x, w)
y.backward(dy)
torch.cuda.synchronize()
return (time.time() - start) * 1000.0 / iters


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--T", type=int, default=1024)
parser.add_argument("--H", type=int, default=4096)
parser.add_argument("--dtype", choices=["fp16", "bf16"], default="bf16")
args = parser.parse_args()

dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16
device = "cuda"
T, H = args.T, args.H

torch.manual_seed(0)
x_base = torch.randn((T, H), device=device, dtype=dtype) * 0.2
w_base = torch.randn((H,), device=device, dtype=dtype) * 0.2
dy = torch.randn((T, H), device=device, dtype=dtype) * 0.2
Comment on lines +42 to +49

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win

Fail fast when CUDA is unavailable.

device = "cuda" is unconditional, so on a CPU-only PyTorch build this dies during tensor creation with a less helpful runtime error. A short availability check would make the benchmark failure mode explicit.

Proposed fix
     dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16
+    if not torch.cuda.is_available():
+        raise SystemExit("benchmark_rmsnorm.py requires a CUDA-capable PyTorch runtime")
     device = "cuda"
     T, H = args.T, args.H
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16
device = "cuda"
T, H = args.T, args.H
torch.manual_seed(0)
x_base = torch.randn((T, H), device=device, dtype=dtype) * 0.2
w_base = torch.randn((H,), device=device, dtype=dtype) * 0.2
dy = torch.randn((T, H), device=device, dtype=dtype) * 0.2
dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16
if not torch.cuda.is_available():
raise SystemExit("benchmark_rmsnorm.py requires a CUDA-capable PyTorch runtime")
device = "cuda"
T, H = args.T, args.H
torch.manual_seed(0)
x_base = torch.randn((T, H), device=device, dtype=dtype) * 0.2
w_base = torch.randn((H,), device=device, dtype=dtype) * 0.2
dy = torch.randn((T, H), device=device, dtype=dtype) * 0.2
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/benchmark_rmsnorm.py` around lines 42 - 49, The benchmark setup in
benchmark_rmsnorm.py unconditionally assigns device to CUDA, which causes a
confusing tensor creation failure on CPU-only builds. Add an explicit CUDA
availability check before creating x_base, w_base, and dy, and fail fast with a
clear error message if CUDA is not available; keep the check near the existing
device assignment so the benchmark’s initialization path stays easy to find.


def make_inputs():
return (
x_base.detach().clone().requires_grad_(True),
w_base.detach().clone().requires_grad_(True),
)

x, w = make_inputs()
t_ref = bench(lambda a, b: rmsnorm_ref_custom(a, b), x, w, dy)
print(f"pytorch ref : {t_ref:.4f} ms")

x, w = make_inputs()
t_tri = bench(lambda a, b: rmsnorm_triton(a, b), x, w, dy)
print(f"triton : {t_tri:.4f} ms | speedup vs ref: {t_ref / t_tri:.2f}x")

if HAS_CUDA_EXT:
x, w = make_inputs()
t_cuda = bench(lambda a, b: rmsnorm_cuda(a, b), x, w, dy)
print(f"cuda : {t_cuda:.4f} ms | speedup vs ref: {t_ref / t_cuda:.2f}x")
else:
print("cuda : skipped, extension is not built")


if __name__ == "__main__":
main()
322 changes: 322 additions & 0 deletions csrc/cuda/rmsnorm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

template <typename scalar_t>
__device__ __forceinline__ float load_as_float(const scalar_t* ptr) {
return static_cast<float>(*ptr);
}

template <>
__device__ __forceinline__ float load_as_float<at::Half>(const at::Half* ptr) {
const __half* p = reinterpret_cast<const __half*>(ptr);
return __half2float(*p);
}

template <>
__device__ __forceinline__ float load_as_float<at::BFloat16>(const at::BFloat16* ptr) {
const __nv_bfloat16* p = reinterpret_cast<const __nv_bfloat16*>(ptr);
return __bfloat162float(*p);
}


template <typename scalar_t>
__device__ __forceinline__ void store_from_float(scalar_t* ptr, float v) {
*ptr = static_cast<scalar_t>(v);
}

template <>
__device__ __forceinline__ void store_from_float<at::Half>(at::Half* ptr, float v) {
__half* p = reinterpret_cast<__half*>(ptr);
*p = __float2half(v);
}

template <>
__device__ __forceinline__ void store_from_float<at::BFloat16>(at::BFloat16* ptr, float v) {
__nv_bfloat16* p = reinterpret_cast<__nv_bfloat16*>(ptr);
*p = __float2bfloat16(v);
}


__device__ __forceinline__ float block_reduce_sum(float v) {
extern __shared__ float smem[];
int tid = threadIdx.x;

smem[tid] = v;
__syncthreads();

for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
smem[tid] += smem[tid + stride];
}
__syncthreads();
}

return smem[0];
}


static int choose_threads(int H) {
if (H <= 64) return 64;
if (H <= 128) return 128;
if (H <= 256) return 256;
return 512;
}


template <typename scalar_t, typename weight_t>
__global__ void rmsnorm_fwd_kernel(
const scalar_t* __restrict__ x,
const weight_t* __restrict__ weight,
scalar_t* __restrict__ y,
float* __restrict__ rstd,
int T,
int H,
float eps
) {
int row = blockIdx.x;
int tid = threadIdx.x;

const scalar_t* x_row = x + row * H;
scalar_t* y_row = y + row * H;

float local_sum = 0.0f;

// 计算 sum(x^2),每个 thread 负责若干列。
for (int col = tid; col < H; col += blockDim.x) {
float xv = load_as_float<scalar_t>(x_row + col);
local_sum += xv * xv;
}

// 固定 block reduction。
float sum = block_reduce_sum(local_sum);

float row_rstd = rsqrtf(sum / static_cast<float>(H) + eps);

if (tid == 0) {
rstd[row] = row_rstd;
}

__syncthreads();

// 写出 y = x * rstd * weight。
for (int col = tid; col < H; col += blockDim.x) {
float xv = load_as_float<scalar_t>(x_row + col);
float wv = load_as_float<weight_t>(weight + col);
float out = xv * row_rstd * wv;
store_from_float<scalar_t>(y_row + col, out);
}
}


template <typename scalar_t, typename weight_t>
__global__ void rmsnorm_bwd_dx_kernel(
const scalar_t* __restrict__ dy,
const scalar_t* __restrict__ x,
const weight_t* __restrict__ weight,
const float* __restrict__ rstd,
scalar_t* __restrict__ dx,
int T,
int H
) {
int row = blockIdx.x;
int tid = threadIdx.x;

const scalar_t* dy_row = dy + row * H;
const scalar_t* x_row = x + row * H;
scalar_t* dx_row = dx + row * H;

float local_dot = 0.0f;

for (int col = tid; col < H; col += blockDim.x) {
float dyv = load_as_float<scalar_t>(dy_row + col);
float xv = load_as_float<scalar_t>(x_row + col);
float wv = load_as_float<weight_t>(weight + col);
local_dot += dyv * wv * xv;
}

float dot = block_reduce_sum(local_dot);

float r = rstd[row];
float coeff = dot * r * r * r / static_cast<float>(H);

for (int col = tid; col < H; col += blockDim.x) {
float dyv = load_as_float<scalar_t>(dy_row + col);
float xv = load_as_float<scalar_t>(x_row + col);
float wv = load_as_float<weight_t>(weight + col);

float out = r * dyv * wv - xv * coeff;
store_from_float<scalar_t>(dx_row + col, out);
}
}


template <typename scalar_t>
__global__ void rmsnorm_partial_dw_kernel(
const scalar_t* __restrict__ dy,
const scalar_t* __restrict__ x,
const float* __restrict__ rstd,
const bool* __restrict__ mask,
float* __restrict__ partial_dw,
int T,
int H
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = T * H;

if (idx >= total) return;

int row = idx / H;

if (!mask[row]) {
partial_dw[idx] = 0.0f;
return;
}

float dyv = load_as_float<scalar_t>(dy + idx);
float xv = load_as_float<scalar_t>(x + idx);
float r = rstd[row];

partial_dw[idx] = dyv * xv * r;
}


__global__ void rmsnorm_reduce_dw_kernel(
const float* __restrict__ partial_dw,
float* __restrict__ dw,
int T,
int H
) {
int h = blockIdx.x;
int tid = threadIdx.x;

float local_sum = 0.0f;

for (int t = tid; t < T; t += blockDim.x) {
local_sum += partial_dw[t * H + h];
}

float sum = block_reduce_sum(local_sum);

if (tid == 0) {
dw[h] = sum;
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
}


void rmsnorm_forward_cuda(
torch::Tensor x,
torch::Tensor weight,
torch::Tensor y,
torch::Tensor rstd,
double eps
) {
int T = x.size(0);
int H = x.size(1);
int threads = choose_threads(H);
size_t smem = threads * sizeof(float);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "rmsnorm_forward_cuda", [&] {
using x_t = scalar_t;
AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, weight.scalar_type(), "rmsnorm_forward_weight_cuda", [&] {
using w_t = scalar_t;
rmsnorm_fwd_kernel<x_t, w_t><<<T, threads, smem, stream>>>(
x.data_ptr<x_t>(),
weight.data_ptr<w_t>(),
y.data_ptr<x_t>(),
rstd.data_ptr<float>(),
T,
H,
static_cast<float>(eps)
);
Comment on lines +222 to +236

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify CUDA launchers consistently guard devices and check launch errors.
rg -n -C2 'CUDAGuard|C10_CUDA_KERNEL_LAUNCH_CHECK|getCurrentCUDAStream|<<<' csrc

Repository: RL-Align/RL-Kernel

Length of output: 6414


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect rmsnorm CUDA file structure and nearby context.
ast-grep outline csrc/cuda/rmsnorm.cu --view expanded || true

printf '\n--- includes and top of file ---\n'
sed -n '1,120p' csrc/cuda/rmsnorm.cu

printf '\n--- forward/backward launch regions ---\n'
sed -n '200,340p' csrc/cuda/rmsnorm.cu

printf '\n--- guard/check usage in the file ---\n'
rg -n 'CUDAGuard|CUDA_CHECK|C10_CUDA_KERNEL_LAUNCH_CHECK|getCurrentCUDAStream' csrc/cuda/rmsnorm.cu

Repository: RL-Align/RL-Kernel

Length of output: 6949


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect similar CUDA wrappers for the preferred launch pattern.
sed -n '1,120p' csrc/cuda/attention/prefix_shared_attention.cu
printf '\n---\n'
sed -n '1,140p' csrc/fused_logp_kernel.cu

Repository: RL-Align/RL-Kernel

Length of output: 9349


Add device guards and launch checks to all RMSNorm CUDA entry points.
csrc/cuda/rmsnorm.cu:222-236, 254-268, 288-299, 314-321 use at::cuda::getCurrentCUDAStream() directly; guard the tensor device first and add C10_CUDA_KERNEL_LAUNCH_CHECK() after each kernel launch.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@csrc/cuda/rmsnorm.cu` around lines 222 - 236, The RMSNorm CUDA entry points
currently assume the active device matches the input tensors and do not verify
kernel launch success. In the forward path and the other RMSNorm CUDA functions
referenced by the same kernel-launch pattern, add an explicit device guard based
on the input tensor device before calling at::cuda::getCurrentCUDAStream(), and
place C10_CUDA_KERNEL_LAUNCH_CHECK() immediately after each <<<...>>> launch in
the rmsnorm_fwd_kernel and related RMSNorm kernels so failures are surfaced
consistently.

});
});
}


void rmsnorm_backward_dx_cuda(
torch::Tensor dy,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor rstd,
torch::Tensor dx
) {
int T = x.size(0);
int H = x.size(1);
int threads = choose_threads(H);
size_t smem = threads * sizeof(float);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "rmsnorm_backward_dx_cuda", [&] {
using x_t = scalar_t;
AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, weight.scalar_type(), "rmsnorm_backward_dx_weight_cuda", [&] {
using w_t = scalar_t;
rmsnorm_bwd_dx_kernel<x_t, w_t><<<T, threads, smem, stream>>>(
dy.data_ptr<x_t>(),
x.data_ptr<x_t>(),
weight.data_ptr<w_t>(),
rstd.data_ptr<float>(),
dx.data_ptr<x_t>(),
T,
H
);
});
});
}


void rmsnorm_backward_partial_dw_cuda(
torch::Tensor dy,
torch::Tensor x,
torch::Tensor rstd,
torch::Tensor mask,
torch::Tensor partial_dw
) {
int T = x.size(0);
int H = x.size(1);
int total = T * H;

int threads = 256;
int blocks = (total + threads - 1) / threads;

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "rmsnorm_partial_dw_cuda", [&] {
rmsnorm_partial_dw_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
dy.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
rstd.data_ptr<float>(),
mask.data_ptr<bool>(),
partial_dw.data_ptr<float>(),
T,
H
);
});
}


void rmsnorm_backward_reduce_dw_cuda(
torch::Tensor partial_dw,
torch::Tensor dw
) {
int T = partial_dw.size(0);
int H = partial_dw.size(1);

int threads = choose_threads(T);
size_t smem = threads * sizeof(float);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

rmsnorm_reduce_dw_kernel<<<H, threads, smem, stream>>>(
partial_dw.data_ptr<float>(),
dw.data_ptr<float>(),
T,
H
);
}
Loading
Loading