3 changes: 3 additions & 0 deletions .gitignore
@@ -206,3 +206,6 @@ cython_debug/
marimo/_static/
marimo/_lsp/
__marimo__/

# Local dev notes (not for upstream)
_dev_notes/
174 changes: 174 additions & 0 deletions benchmarks/benchmark_batch_invariant_logp.py
@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

"""Benchmark batch-invariant logp: Native vs Triton vs CUDA (SM90 TMA).

All three backends compute ``logits[t, target[t]] - logsumexp(logits[t, :])``
from a materialized ``[N, V]`` logits tensor with a locked, per-row reduction
order (batch-invariant). The comparison here is latency and peak VRAM across a
vocab sweep:

- Native materializes ``log_softmax`` over the full ``[N, V]`` tensor.
- Triton streams the vocab through an online softmax (grid = one program/row).
- CUDA is the Hopper TMA online-softmax kernel (one CTA/row); only present when
  the extension is built with ``KERNEL_ALIGN_FORCE_SM90=1`` on an SM90 device.

Usage:
    python benchmarks/benchmark_batch_invariant_logp.py
    python benchmarks/benchmark_batch_invariant_logp.py --configs "4096,128256;8192,151936"
"""

import argparse

import torch
from tabulate import tabulate

from rl_engine.kernels.ops.pytorch.loss.batch_invariant_logp import NativeBatchInvariantLogpOp
from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import TritonBatchInvariantLogpOp
from rl_engine.platforms.device import device_ctx
from rl_engine.utils.logger import logger


def _maybe_sm90_op():
    """The Hopper TMA op, or None when unavailable (non-Hopper / not built)."""
    from rl_engine.kernels.ops.base import _C, _EXT_AVAILABLE

    if not (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability()[0] == 9
        and _EXT_AVAILABLE
        and hasattr(_C, "batch_invariant_logp_sm90")
    ):
        return None
    from rl_engine.kernels.ops.cuda.loss.batch_invariant_logp import BatchInvariantLogpSM90Op

    return BatchInvariantLogpSM90Op()


# (num_tokens, vocab); vocab kept a multiple of 8 so the bf16 TMA path runs.
DEFAULT_CONFIGS = [
    (4096, 32768),
    (4096, 128256),
    (4096, 151936),
    (8192, 128256),
]


def _make_inputs(num_tokens, vocab, device, dtype):
    logits = torch.randn(num_tokens, vocab, device=device, dtype=dtype)
    target = torch.randint(0, vocab, (num_tokens,), device=device)
    return logits, target


def _time_ms(fn, warmup, iters):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


def _peak_vram_gb(fn, warmup=3, iters=5):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    baseline = torch.cuda.memory_allocated()
    torch.cuda.reset_peak_memory_stats()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (torch.cuda.max_memory_allocated() - baseline) / (1024**3)

Comment on lines +77 to +88


🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

VRAM measurement ignores CLI --warmup/--iters overrides.

_peak_vram_gb hardcodes warmup=3, iters=5 (Line 77) instead of accepting args.warmup/args.iters like _time_ms does, so --warmup/--iters only affect timing, not the memory measurement, even though it is called right alongside the timing runs (Lines 115-116, 128).

Proposed fix
-def _peak_vram_gb(fn, warmup=3, iters=5):
+def _peak_vram_gb(fn, warmup, iters):
     for _ in range(warmup):
         fn()
     torch.cuda.synchronize()
     torch.cuda.empty_cache()
     baseline = torch.cuda.memory_allocated()
     torch.cuda.reset_peak_memory_stats()
     for _ in range(iters):
         fn()
     torch.cuda.synchronize()
     return (torch.cuda.max_memory_allocated() - baseline) / (1024**3)

And update call sites to pass args.warmup, args.iters.

Also applies to: 113-116, 127-128

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/benchmark_batch_invariant_logp.py` around lines 77 - 88, The VRAM
helper currently hardcodes its warmup and iteration counts, so CLI overrides are
ignored for memory measurement. Update `_peak_vram_gb` to accept the same
`warmup` and `iters` inputs used by `_time_ms`, and change the call sites in
`benchmark_batch_invariant_logp` to pass `args.warmup` and `args.iters` so
timing and VRAM runs stay consistent.
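
A torch-free sketch of the failure mode, assuming a stand-in workload that only counts how often it is called. `measure_defaulted` mirrors the current signature and `measure_explicit` the proposed one; both names, and `count_calls`, are hypothetical and not part of the PR.

```python
import argparse


def measure_defaulted(fn, warmup=3, iters=5):
    """Mirrors the current helper: the counts are baked into the signature."""
    for _ in range(warmup + iters):
        fn()


def measure_explicit(fn, warmup, iters):
    """Mirrors the proposed helper: the caller must supply the counts."""
    for _ in range(warmup + iters):
        fn()


def count_calls(measure, *extra):
    """Run `measure` on a counting stub and return the number of invocations."""
    calls = []
    measure(lambda: calls.append(None), *extra)
    return len(calls)


# Stand-in for parse_args() after `--warmup 1 --iters 2`.
args = argparse.Namespace(warmup=1, iters=2)

# Call site that omits the CLI values, as the current VRAM calls do.
ignored = count_calls(measure_defaulted)
# Call site that forwards them, as the timing calls already do.
honored = count_calls(measure_explicit, args.warmup, args.iters)
print(ignored, honored)
```

With the defaulted signature the stub runs 8 times regardless of the flags; with the explicit one it runs 3 times, matching `--warmup 1 --iters 2`. Dropping the defaults also turns a forgotten argument into a `TypeError` at the call site instead of a silent mismatch.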


def run_benchmark(args):
    if device_ctx.device_type not in ["cuda", "xpu", "hip"]:
        raise RuntimeError("batch_invariant_logp benchmark requires a compatible GPU device.")
Comment on lines +63 to +92


🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🌐 Web query:

Does torch.xpu support Event, synchronize, and memory stats APIs equivalent to torch.cuda for Intel GPUs?

💡 Result:

Yes, torch.xpu provides APIs that are functionally equivalent to the corresponding torch.cuda APIs for Intel GPUs [1][2][3]. Specifically, PyTorch implements the following support:

- Synchronization: The torch.xpu.synchronize function is available to wait for all kernels in all streams on an XPU device to complete, mirroring the behavior of torch.cuda.synchronize [4][2][5].
- Events: The torch.xpu.Event class serves as a wrapper around an XPU event, allowing for synchronization markers and monitoring device progress, consistent with torch.cuda.Event [6][2][7].
- Memory Statistics: PyTorch provides a comprehensive set of memory management APIs under the torch.xpu.memory namespace that mirror torch.cuda memory functions [1][2]. These include:
  - torch.xpu.memory.memory_stats and memory_stats_as_nested_dict for allocator statistics [1][2][8].
  - Metrics such as memory_allocated, memory_reserved, max_memory_allocated, and max_memory_reserved [1][2].
  - Utilities like memory_snapshot, mem_get_info, and tools to manage memory fractions [1][2].

These XPU-specific APIs are designed to provide a consistent user experience when migrating code from CUDA to Intel GPU backends [9][3].


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect the benchmark file and nearby device handling.
git ls-files benchmarks/benchmark_batch_invariant_logp.py
echo "---"
cat -n benchmarks/benchmark_batch_invariant_logp.py | sed -n '1,220p'

echo "--- device_ctx references ---"
rg -n "device_ctx|torch\.cuda|torch\.xpu|torch\.hip|xpu" benchmarks/benchmark_batch_invariant_logp.py

Repository: RL-Align/RL-Kernel

Length of output: 7659


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Check PyTorch backend alias behavior from installed docs/source if available in repo.
python3 - <<'PY'
import importlib.util, sys
mods = ["torch"]
for m in mods:
    print(m, importlib.util.find_spec(m) is not None)
PY

# Try to inspect PyTorch API docs via Python introspection if torch is installed.
python3 - <<'PY'
try:
    import torch
    print("torch version:", torch.__version__)
    for name in ["cuda", "xpu"]:
        print(name, "present:", hasattr(torch, name))
        if hasattr(torch, name):
            mod = getattr(torch, name)
            for attr in ["synchronize", "Event", "memory_allocated", "empty_cache", "reset_peak_memory_stats", "max_memory_allocated"]:
                print(name, attr, hasattr(mod, attr))
except Exception as e:
    print("torch import failed:", repr(e))
PY

Repository: RL-Align/RL-Kernel

Length of output: 530


🌐 Web query:

PyTorch Intel XPU backend torch.cuda compatibility layer Event synchronize memory_allocated max_memory_allocated on XPU

💡 Result:

In PyTorch, the XPU backend provides a dedicated namespace (torch.xpu) that includes direct equivalents to many torch.cuda functions, rather than relying on a universal "compatibility layer" that automatically redirects calls [1][2]. Developers should use torch.xpu APIs explicitly when targeting Intel GPUs [3]. Key XPU equivalents for the requested CUDA functions are as follows:

- Synchronization: Use torch.xpu.synchronize(device=None) to wait for all kernels in all streams on a specific XPU device to complete [4][5][2]. This mirrors the behavior of torch.cuda.synchronize().
- Memory Allocation Tracking: The torch.xpu namespace provides torch.xpu.memory.memory_allocated(device=None) and torch.xpu.memory.max_memory_allocated(device=None) [6][7][8]. These functions return the current and peak memory occupied by tensors (in bytes), respectively [6][8].

While there is no transparent automatic compatibility layer that maps torch.cuda to torch.xpu globally, modern development practices often involve using conditional logic or utility wrappers to check the active device and dispatch the appropriate calls (e.g., torch.cuda.synchronize() vs. torch.xpu.synchronize()) [3]. Some third-party frameworks (such as vLLM) have implemented specific internal monkey-patching or abstraction layers to map torch.cuda APIs to torch.xpu equivalents to support features like graph capture on XPU [9].


🌐 Web query:

site:pytorch.org torch.cuda xpu compatibility alias torch.cuda.Event torch.cuda.synchronize xpu

💡 Result:

In PyTorch, the torch.xpu API provides synchronization and event management functionality specifically for Intel GPUs, mirroring the patterns found in torch.cuda [1][2]. For migration purposes, you should replace references to torch.cuda with torch.xpu [1]. The equivalent components are:

- Event Management: Use torch.xpu.Event instead of torch.cuda.Event [3][4][5]. Like its CUDA counterpart, torch.xpu.Event is a synchronization marker used to monitor device progress, measure timing (if enabled), and synchronize XPU streams [3][6]. It supports methods such as record, synchronize, wait, and query [6][5].
- Synchronization: Use torch.xpu.synchronize instead of torch.cuda.synchronize [7][8]. This function waits for all kernels in all streams on the specified XPU device to complete [8]. Additionally, individual streams have their own synchronization methods via torch.xpu.Stream, such as stream.synchronize [9].

For further details on migrating workflows, the official PyTorch documentation on XPU semantics is the recommended resource [1][7].


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Look for other benchmarks or utilities that handle xpu/hip via torch.cuda or torch.xpu.
rg -n "torch\.xpu|device_type.*xpu|device_type.*hip|torch\.cuda\.(Event|synchronize|memory_allocated|max_memory_allocated|empty_cache|reset_peak_memory_stats)" .

Repository: RL-Align/RL-Kernel

Length of output: 7573


Dispatch timing/memory helpers by backend
run_benchmark allows "xpu", but _time_ms and _peak_vram_gb still hardcode torch.cuda. Use torch.xpu for Intel GPUs and keep torch.cuda for CUDA/ROCm so the benchmark works on every allowed device.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/benchmark_batch_invariant_logp.py` around lines 63 - 92, The
timing and memory helpers are hardcoded to torch.cuda even though run_benchmark
supports xpu, so the benchmark will break on Intel GPUs. Update _time_ms and
_peak_vram_gb to dispatch through the active backend in device_ctx.device_type,
using torch.xpu for xpu and torch.cuda for cuda/hip, including synchronize,
Event creation, empty_cache, memory_allocated, reset_peak_memory_stats, and
max_memory_allocated calls.
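
One possible shape for that dispatch, as a minimal sketch rather than the PR's implementation. `backend_for`, `time_ms` and `peak_vram_gb` are hypothetical names; the `torch` module is passed in as `torch_mod` so the sketch can be exercised with a stand-in object on a machine without a GPU. It assumes the installed PyTorch exposes `Event`, `synchronize`, `empty_cache`, `memory_allocated`, `reset_peak_memory_stats` and `max_memory_allocated` directly on `torch.xpu`, and that XPU events implement `elapsed_time`; both depend on the PyTorch version and should be checked against the target build.

```python
def backend_for(torch_mod, device_type):
    """Return the torch sub-module that owns events, sync and memory stats."""
    if device_type == "xpu":
        return torch_mod.xpu
    if device_type in ("cuda", "hip"):
        # ROCm builds of PyTorch expose HIP devices through torch.cuda.
        return torch_mod.cuda
    raise RuntimeError(f"unsupported device type: {device_type!r}")


def time_ms(fn, warmup, iters, backend):
    """Average per-call latency in milliseconds, measured with device events."""
    for _ in range(warmup):
        fn()
    backend.synchronize()
    start = backend.Event(enable_timing=True)
    end = backend.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    backend.synchronize()
    return start.elapsed_time(end) / iters


def peak_vram_gb(fn, warmup, iters, backend):
    """Peak allocation above the post-warmup baseline, in GiB."""
    for _ in range(warmup):
        fn()
    backend.synchronize()
    backend.empty_cache()
    baseline = backend.memory_allocated()
    backend.reset_peak_memory_stats()
    for _ in range(iters):
        fn()
    backend.synchronize()
    return (backend.max_memory_allocated() - baseline) / (1024**3)


# Intended use inside run_benchmark (not executed here):
#   backend = backend_for(torch, device_ctx.device_type)
#   n_fwd = time_ms(lambda: fwd(native), args.warmup, args.iters, backend)
```

Resolving the backend once in `run_benchmark` and threading it through keeps the helpers free of global state and also folds in the `warmup`/`iters` fix from the earlier comment.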


    device = device_ctx.device
    dtype = torch.bfloat16
    native = NativeBatchInvariantLogpOp()
    triton_op = TritonBatchInvariantLogpOp()
    sm90_op = _maybe_sm90_op()

    logger.info(
        f"batch_invariant_logp benchmark on {device} (dtype={dtype}); "
        f"SM90 TMA backend {'enabled' if sm90_op is not None else 'unavailable'}"
    )

    rows = []
    for num_tokens, vocab in args.configs:
        logits, target = _make_inputs(num_tokens, vocab, device, dtype)

        def fwd(op, x=logits, t=target):
            with torch.no_grad():
                op(x, t, validate=False)

        n_fwd = _time_ms(lambda: fwd(native), args.warmup, args.iters)
        t_fwd = _time_ms(lambda: fwd(triton_op), args.warmup, args.iters)
        n_vram = _peak_vram_gb(lambda: fwd(native))
        t_vram = _peak_vram_gb(lambda: fwd(triton_op))

        row = [
            f"{num_tokens}x{vocab}",
            f"{n_fwd:.3f}",
            f"{t_fwd:.3f}",
            f"{n_fwd/t_fwd:.2f}x",
            f"{n_vram*1024:.0f}",
            f"{t_vram*1024:.0f}",
        ]
        if sm90_op is not None:
            s_fwd = _time_ms(lambda: fwd(sm90_op), args.warmup, args.iters)
            s_vram = _peak_vram_gb(lambda: fwd(sm90_op))
            row += [
                f"{s_fwd:.3f}",
                f"{n_fwd/s_fwd:.2f}x",
                f"{t_fwd/s_fwd:.2f}x",
                f"{s_vram*1024:.0f}",
            ]
        rows.append(row)

    headers = [
        "shape (N x V)",
        "native fwd ms",
        "triton fwd ms",
        "fwd speedup",
        "native fwd MB",
        "triton fwd MB",
    ]
    if sm90_op is not None:
        headers += [
            "cuda fwd ms",
            "cuda vs native",
            "cuda vs triton",
            "cuda fwd MB",
        ]
    print(tabulate(rows, headers=headers, tablefmt="github"))


def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--iters", type=int, default=20)
    parser.add_argument("--warmup", type=int, default=5)
    parser.add_argument(
        "--configs",
        type=str,
        default=None,
        help="Semicolon-separated 'tokens,vocab' tuples, e.g. '4096,128256;8192,151936'.",
    )
    args = parser.parse_args()
    if args.configs:
        args.configs = [tuple(int(x) for x in tup.split(",")) for tup in args.configs.split(";")]
    else:
        args.configs = DEFAULT_CONFIGS
    return args


if __name__ == "__main__":
    run_benchmark(parse_args())
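
The "multiple of 8" note on `DEFAULT_CONFIGS` follows from the SM90 entrypoint's requirement that the row stride `V * elem_size` be a multiple of 16 bytes: bf16 is 2 bytes per element, so `V` must be a multiple of 8, and fp32 is 4 bytes, so a multiple of 4 suffices. A small stdlib sketch of that arithmetic applied to a `--configs` string; `parse_configs` and `row_stride_ok` are hypothetical helpers, and the third shape (vocab 50257) is an illustrative unaligned value, not one used by the benchmark.

```python
def parse_configs(spec):
    """Parse 'tokens,vocab;tokens,vocab' into a list of (tokens, vocab) tuples."""
    return [tuple(int(x) for x in pair.split(",")) for pair in spec.split(";")]


def row_stride_ok(vocab, elem_size):
    """True when one row of `vocab` elements spans a multiple of 16 bytes."""
    return (vocab * elem_size) % 16 == 0


configs = parse_configs("4096,128256;8192,151936;4096,50257")
for tokens, vocab in configs:
    print(
        f"{tokens}x{vocab}",
        "bf16 ok:", row_stride_ok(vocab, 2),
        "fp32 ok:", row_stride_ok(vocab, 4),
    )
```

A user-supplied shape that fails this check would be rejected by the `TORCH_CHECK` in the CUDA entrypoint, so validating `--configs` up front (or skipping the SM90 column for such shapes) would give a clearer error than a failure mid-sweep.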
213 changes: 213 additions & 0 deletions csrc/cuda/batch_invariant_logp_kernel_sm90.cu
@@ -0,0 +1,213 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2026 RL-Kernel Contributors

// Hopper (SM90) batch-invariant selected-token log-prob
// logp[n] = logits[n, target[n]] - logsumexp(logits[n, :])

#include "../utils/tma_utils.cuh"
#include <cub/cub.cuh>
#include <cuda_bf16.h>
#include <math_constants.h>
#include <torch/extension.h>

namespace {

// A single TMA box dimension is capped at 256 elements
#define SMEM_TILE 4096
#define TMA_BOX 256
static constexpr int LOADS_PER_TILE = SMEM_TILE / TMA_BOX;

template <typename T> __device__ __forceinline__ float to_float(T x);
template <> __device__ __forceinline__ float to_float<nv_bfloat16>(nv_bfloat16 x) {
return __bfloat162float(x);
}
template <> __device__ __forceinline__ float to_float<float>(float x) { return x; }

template <typename T, int NUM_WARPS>
__global__ void batch_invariant_logp_sm90_kernel(const __grid_constant__ CUtensorMap logits_tmap,
const int *__restrict__ target,
const T *__restrict__ logits_gmem,
float *__restrict__ output_logp,
float *__restrict__ output_lse, int num_tokens,
int vocab_size, int ignore_index) {
constexpr int NUM_THREADS = NUM_WARPS * 32;

// one CTA per row; each warp streams a TMA tile of the row into shared memory
const int tid = threadIdx.x;
const int row_idx = blockIdx.x;

extern __shared__ __align__(1024) char smem[];
const int smem_addr = static_cast<int>(__cvta_generic_to_shared(smem));
T *smem_logits = reinterpret_cast<T *>(smem);
const int tma_mbar_addr = smem_addr + (SMEM_TILE * sizeof(T));

if (tid == 0) {
mbarrier_init(tma_mbar_addr, 1);
asm volatile("fence.mbarrier_init.release.cluster;");
}
__syncthreads();

const int v_aligned = (vocab_size / TMA_BOX) * TMA_BOX;
const int num_tiles = (v_aligned + SMEM_TILE - 1) / SMEM_TILE;
int phase = 0;

using BlockReduce = cub::BlockReduce<float, NUM_THREADS>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float s_tile_max;

float row_max = -CUDART_INF_F;
float row_sum = 0.0f;

for (int step = 0; step < num_tiles; ++step) {
const int col_offset = step * SMEM_TILE;
const int current_tile_size = min(SMEM_TILE, v_aligned - col_offset);

if (tid == 0) {
for (int j = 0; j < LOADS_PER_TILE; ++j) {
const int col = col_offset + j * TMA_BOX;
if (col < v_aligned) {
tma_2d_g2s(smem_addr + j * TMA_BOX * sizeof(T), &logits_tmap, col, row_idx,
tma_mbar_addr);
}
}
mbarrier_arrive_expect_tx(tma_mbar_addr, current_tile_size * sizeof(T));
}
mbarrier_wait(tma_mbar_addr, phase);
phase ^= 1;

// tile max (fixed strided partition -> deterministic)
float tile_max = -CUDART_INF_F;
for (int i = tid; i < current_tile_size; i += NUM_THREADS) {
tile_max = max(tile_max, to_float<T>(smem_logits[i]));
}
float block_tile_max = BlockReduce(temp_storage).Reduce(tile_max, cub::Max());
if (tid == 0)
s_tile_max = block_tile_max;
__syncthreads();

// tile sum of exp(x - tile_max).
float tile_sum = 0.0f;
for (int i = tid; i < current_tile_size; i += NUM_THREADS) {
tile_sum += expf(to_float<T>(smem_logits[i]) - s_tile_max);
}
float block_tile_sum = BlockReduce(temp_storage).Reduce(tile_sum, cub::Sum());

// Online log-sum-exp merge of this tile into the running row state.
if (tid == 0) {
const float new_max = max(row_max, s_tile_max);
row_sum = row_sum * expf(row_max - new_max) + block_tile_sum * expf(s_tile_max - new_max);
row_max = new_max;
}
__syncthreads();
}

// Tail [v_aligned, V): fewer than TMA_BOX elements, read straight from global.
const int tail = vocab_size - v_aligned;
if (tail > 0) {
const int64_t base = (int64_t)row_idx * vocab_size + v_aligned;
float tail_max = -CUDART_INF_F;
for (int i = tid; i < tail; i += NUM_THREADS) {
tail_max = max(tail_max, to_float<T>(logits_gmem[base + i]));
}
float block_tail_max = BlockReduce(temp_storage).Reduce(tail_max, cub::Max());
if (tid == 0)
s_tile_max = block_tail_max;
__syncthreads();

float tail_sum = 0.0f;
for (int i = tid; i < tail; i += NUM_THREADS) {
tail_sum += expf(to_float<T>(logits_gmem[base + i]) - s_tile_max);
}
float block_tail_sum = BlockReduce(temp_storage).Reduce(tail_sum, cub::Sum());
if (tid == 0) {
const float new_max = max(row_max, s_tile_max);
row_sum = row_sum * expf(row_max - new_max) + block_tail_sum * expf(s_tile_max - new_max);
row_max = new_max;
}
__syncthreads();
}

if (tid == 0) {
const float lse = row_max + logf(row_sum);
const int tgt = target[row_idx];
if (tgt == ignore_index) {
output_logp[row_idx] = 0.0f;
} else {
const float tgt_logit = to_float<T>(logits_gmem[(int64_t)row_idx * vocab_size + tgt]);
output_logp[row_idx] = tgt_logit - lse;
}
output_lse[row_idx] = lse;
}
}

template <typename T>
void launch_batch_invariant_logp_sm90(torch::Tensor logits, torch::Tensor target,
torch::Tensor logp, torch::Tensor lse, int N, int V,
int ignore_index) {
constexpr int NUM_WARPS = 4;
CUtensorMap logits_tmap;
// Global [N, V]; TMA box = [1 row, TMA_BOX cols], unswizzled row-major.
init_tensor_map<T>(&logits_tmap, reinterpret_cast<const T *>(logits.data_ptr<T>()), N, V, 1,
TMA_BOX);

const int smem_size = (SMEM_TILE * sizeof(T)) + 16;
batch_invariant_logp_sm90_kernel<T, NUM_WARPS><<<N, NUM_WARPS * 32, smem_size>>>(
logits_tmap, target.data_ptr<int>(), reinterpret_cast<const T *>(logits.data_ptr<T>()),
logp.data_ptr<float>(), lse.data_ptr<float>(), N, V, ignore_index);
}

// at::BFloat16 and nv_bfloat16 are layout-compatible; data_ptr<T> needs the ATen
// type, so bf16 gets its own launch that reinterprets the pointer.
void launch_batch_invariant_logp_sm90_bf16(torch::Tensor logits, torch::Tensor target,
torch::Tensor logp, torch::Tensor lse, int N, int V,
int ignore_index) {
constexpr int NUM_WARPS = 4;
CUtensorMap logits_tmap;
init_tensor_map<nv_bfloat16>(
&logits_tmap, reinterpret_cast<const nv_bfloat16 *>(logits.data_ptr<at::BFloat16>()), N, V,
1, TMA_BOX);

const int smem_size = (SMEM_TILE * sizeof(nv_bfloat16)) + 16;
batch_invariant_logp_sm90_kernel<nv_bfloat16, NUM_WARPS><<<N, NUM_WARPS * 32, smem_size>>>(
logits_tmap, target.data_ptr<int>(),
reinterpret_cast<const nv_bfloat16 *>(logits.data_ptr<at::BFloat16>()),
logp.data_ptr<float>(), lse.data_ptr<float>(), N, V, ignore_index);
}

} // namespace

std::vector<torch::Tensor> batch_invariant_logp_sm90_forward(torch::Tensor logits,
torch::Tensor target,
int64_t ignore_index) {
TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
TORCH_CHECK(logits.dim() == 2, "logits must be 2-D [N, V]");
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
const int N = logits.size(0);
const int V = logits.size(1);
TORCH_CHECK(target.numel() == N, "target must have one id per row: expected ", N, ", got ",
target.numel());

// TMA requires the global row stride (V * elem_size) to be 16-byte aligned.
const int elem_size = static_cast<int>(logits.element_size());
TORCH_CHECK((static_cast<int64_t>(V) * elem_size) % 16 == 0,
"batch_invariant_logp_sm90 requires the vocab row stride (V * elem_size) to be a "
"multiple of 16 bytes; got V=",
V, ", elem_size=", elem_size);

auto opts_f = logits.options().dtype(torch::kFloat);
auto logp = torch::empty({N}, opts_f);
auto lse = torch::empty({N}, opts_f);
auto target_i = target.to(torch::kInt32).contiguous();
Comment on lines +179 to +200


🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify existing extension conventions for CUDA device guards / launch checks.
rg -n -C2 'CUDAGuard|OptionalCUDAGuard|C10_CUDA_KERNEL_LAUNCH_CHECK|CUDA_KERNEL_LAUNCH_CHECK' csrc

Repository: RL-Align/RL-Kernel

Length of output: 565


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect the SM90 batch-invariant logp kernel entrypoint and nearby CUDA launch code.
FILE="csrc/cuda/batch_invariant_logp_kernel_sm90.cu"
wc -l "$FILE"
sed -n '1,260p' "$FILE"

# Find the Python binding/callers for this kernel.
rg -n -C3 'batch_invariant_logp_sm90|_C\.batch_invariant_logp_sm90|batch_invariant_logp_sm90_forward' .

Repository: RL-Align/RL-Kernel

Length of output: 22607


Guard the SM90 entrypoint with a CUDA device guard and same-device target check.

logits can be on a different CUDA device than the current context, so this launch can hit the wrong GPU. The exported binding should set the device from logits and reject non-CUDA or cross-device target tensors; the Python wrapper already normalizes target, but direct _C.batch_invariant_logp_sm90(...) calls do not.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@csrc/cuda/batch_invariant_logp_kernel_sm90.cu` around lines 179 - 200, The
SM90 entrypoint in batch_invariant_logp_sm90_forward currently assumes the
current CUDA context matches logits, and it also only validates target shape,
not device. Add a CUDA device guard based on logits before any allocation or
launch, and add a TORCH_CHECK that target is on the same CUDA device as logits
(and CUDA at all) alongside the existing logits checks. Keep the validation and
setup near batch_invariant_logp_sm90_forward so direct calls through the
exported binding are safe even when the Python wrapper is bypassed.
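
The device guard itself has to live in the C++ entrypoint, but the same preconditions can also be stated on the Python side for callers that reach the binding directly. A minimal sketch, assuming inputs that expose `.is_cuda`, `.device`, `.dim()`, `.numel()` and `.shape` the way `torch.Tensor` does; `check_sm90_inputs` is a hypothetical helper and not part of the PR.

```python
def check_sm90_inputs(logits, target):
    """Raise ValueError unless the inputs meet the entrypoint's assumptions."""
    if not logits.is_cuda:
        raise ValueError("logits must be a CUDA tensor")
    if not target.is_cuda or target.device != logits.device:
        raise ValueError(
            f"target must be on the same CUDA device as logits "
            f"({logits.device}), got {target.device}"
        )
    if logits.dim() != 2:
        raise ValueError("logits must be 2-D [N, V]")
    if target.numel() != logits.shape[0]:
        raise ValueError(
            f"target must have one id per row: expected {logits.shape[0]}, "
            f"got {target.numel()}"
        )
```

A caller could pair this with PyTorch's `torch.cuda.device(logits.device)` context manager around the launch. That is a workaround for callers only; it does not replace the guard and same-device check requested for the exported binding.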


if (logits.scalar_type() == at::kBFloat16) {
launch_batch_invariant_logp_sm90_bf16(logits, target_i, logp, lse, N, V,
static_cast<int>(ignore_index));
} else if (logits.scalar_type() == at::kFloat) {
launch_batch_invariant_logp_sm90<float>(logits, target_i, logp, lse, N, V,
static_cast<int>(ignore_index));
} else {
TORCH_CHECK(false, "batch_invariant_logp_sm90 supports only bfloat16 and float32 logits");
}

return {logp, lse};
}
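
For reference while reviewing the kernel, a pure-Python sketch of the same tile-wise online log-sum-exp: full tiles of up to `SMEM_TILE` elements over the `TMA_BOX`-aligned prefix, then the unaligned tail, each merged into a running `(row_max, row_sum)` pair. The function names are hypothetical, and the `ignore_index=-100` default is an assumption (the kernel takes it as a parameter). Python's sequential `sum` does not reproduce the kernel's strided per-thread partition or the `cub::BlockReduce` order, so this checks the math, not bitwise batch invariance.

```python
import math

SMEM_TILE = 4096  # elements staged per tile, as in the kernel
TMA_BOX = 256     # elements per TMA box; the tiled prefix is a multiple of this


def tile_partial(values):
    """Return (max, sum of exp(v - max)) for one tile."""
    m = max(values)
    return m, sum(math.exp(v - m) for v in values)


def merge(row_max, row_sum, part_max, part_sum):
    """Fold one tile's partial result into the running row state."""
    new_max = max(row_max, part_max)
    new_sum = (row_sum * math.exp(row_max - new_max)
               + part_sum * math.exp(part_max - new_max))
    return new_max, new_sum


def online_logsumexp(row):
    """Tile-wise logsumexp of a non-empty sequence of finite floats."""
    v_aligned = (len(row) // TMA_BOX) * TMA_BOX
    row_max, row_sum = -math.inf, 0.0
    for start in range(0, v_aligned, SMEM_TILE):
        tile = row[start:min(start + SMEM_TILE, v_aligned)]
        row_max, row_sum = merge(row_max, row_sum, *tile_partial(tile))
    if v_aligned < len(row):  # tail [v_aligned, V)
        row_max, row_sum = merge(row_max, row_sum, *tile_partial(row[v_aligned:]))
    return row_max + math.log(row_sum)


def selected_logp(row, target, ignore_index=-100):
    """Return (logp, lse); logp is 0.0 when target equals ignore_index."""
    lse = online_logsumexp(row)
    logp = 0.0 if target == ignore_index else row[target] - lse
    return logp, lse


example_logp, example_lse = selected_logp([0.5, -1.0, 2.0], target=2)
print(example_logp, example_lse)
```

The rescaling in `merge` is what lets each tile be reduced against its own maximum: multiplying by `exp(old_max - new_max)` re-expresses an earlier partial sum relative to the new running maximum, so no pass over the full row is needed.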