Skip to content

[WS1][kernels] Batch-invariant logprob (CUDA)#204

Open
KJLdefeated wants to merge 5 commits into
mainfrom
feat/batch-invariant-logp-cuda
Open

[WS1][kernels] Batch-invariant logprob (CUDA)#204
KJLdefeated wants to merge 5 commits into
mainfrom
feat/batch-invariant-logp-cuda

Conversation

@KJLdefeated

@KJLdefeated KJLdefeated commented Jul 2, 2026

Copy link
Copy Markdown
Collaborator

[WS1][kernels] SM90 Batch-Invariant logprob CUDA Kernel

Summary

Adds a Hopper (SM90) TMA CUDA kernel for the batch-invariant selected-token log-probability operator introduced in #199 plus a Native/Triton/CUDA benchmark. The operator computes, from already materialized logits:

logp[t] = logits[t, target_ids[t]] - logsumexp(logits[t, :])

The kernel streams the vocab through shared memory with TMA bulk-tensor copies and folds it into a running online log-sum-exp, so it never materializes the [N, V] softmax and its extra forward memory is ~0. It is registered as the
top-priority CUDA backend on Hopper and falls back cleanly everywhere else.

Path Status
CUDA SM90 forward (csrc/cuda/batch_invariant_logp_kernel_sm90.cu) TMA online-softmax, bf16 + fp32, one CTA per row. New.
Python op (BatchInvariantLogpSM90Op) Autograd wrapper; TMA forward + tile-wise backward. New.
Backward Reuses #199's Triton tile-wise backward via the forward-saved lse (no [N, V] materialization).
Registry Hardware-gated: inserted at the front of the CUDA batch_invariant_logp list only when the SM90 symbol is compiled on cc_major == 9.
Fallback Non-bf16/fp32, unaligned vocab row stride, or non-Hopper → Triton, else native.
Benchmark benchmarks/benchmark_batch_invariant_logp.py compares Native/Triton/CUDA. New.

This builds on the open branch for #199 (Native + Triton). Reviewed after #199 merges.

Implementation

  • New kernel csrc/cuda/batch_invariant_logp_kernel_sm90.cu.
  • Binding in csrc/ops.cpp and build source in setup.py's sm90_srcs; stub in rl_engine/_C.pyi.
  • Python op rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py:
    • _BatchInvariantLogpSM90Function forward calls the compiled kernel;
    • backward reuses Triton _batch_invariant_logp_bwd_kernel.
  • Registry rl_engine/kernels/registry.py.
  • Docs/benchmark: updated docs/operators/batch-invariant-logp.md; added the benchmark script.

Validation environment

Item Value
GPU NVIDIA H200 (Hopper, SM90, compute capability 9.0)
Driver 550.127.08 (CUDA driver 12.4)
CUDA toolkit / nvcc 12.8, V12.8.93 (conda nvidia channel)
PyTorch 2.11.0+cu128
Python 3.11.15
Host compiler gcc 12.5.0
Extension KERNEL_ALIGN_FORCE_SM90=1, _C.batch_invariant_logp_sm90 available
Build vars TORCH_CUDA_ARCH_LIST=9.0a, KERNEL_ALIGN_DEV_RPATH=1

Correctness / Tests

Unit coverage

python -m pytest tests/test_batch_invariant_logp.py -q -rs

Result: 67 passed (14 new SM90-kernel tests: correctness fp32/bf16, large and unaligned vocab, single-token, 3-D, bitwise batch-invariance, backward, ignore-index, fp16 fallback; plus the registry-dispatch test now accepting the CUDA op).

python -m pytest tests/test_batch_invariant_logp.py -k CudaSM90 -q

Result: 14 passed. Registry selects BatchInvariantLogpSM90Op on the H200:

Detected TMA-capable architecture (SM90); ...
Successfully linked to precompiled _C.batch_invariant_logp_sm90 kernel.

Vocab sweep (bf16, vs log_softmax + gather)

Correct on both TMA-aligned and unaligned (tail-path) vocab sizes:

V 256 300 400 1024 50257 128256 151936
max abs err 0 0 4.8e-7 4.8e-7 9.5e-7 9.5e-7 9.5e-7

Benchmarks

python benchmarks/benchmark_batch_invariant_logp.py — bf16, H200, 20 iters + 5 warmup. Memory is peak extra above baseline (MB).

Forward

shape (N x V) native ms triton ms cuda ms cuda vs native cuda vs triton native MB triton MB cuda MB
4096x32768 1.341 0.148 0.090 14.8x 1.64x 1536 0 0
4096x128256 4.964 0.566 0.323 15.4x 1.75x 6012 0 0
4096x151936 5.908 0.669 0.384 15.4x 1.74x 7122 0 0
8192x128256 9.904 1.056 0.597 16.6x 1.77x 12024 0 0

Forward + backward

shape (N x V) native ms triton ms cuda ms cuda vs native cuda vs triton native MB triton MB cuda MB
4096x32768 3.549 0.463 0.407 8.7x 1.14x 1792 512 512
4096x128256 13.173 1.750 1.510 8.7x 1.16x 7014 2004 2004
4096x151936 15.632 2.072 1.788 8.7x 1.16x 8310 2376 2376
8192x128256 26.211 3.416 2.955 8.9x 1.16x 14028 4008 4008

Files

  • csrc/cuda/batch_invariant_logp_kernel_sm90.cu (new)
  • rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py (new)
  • benchmarks/benchmark_batch_invariant_logp.py (new)
  • csrc/ops.cpp, setup.py, rl_engine/_C.pyi, rl_engine/kernels/registry.py
  • tests/test_batch_invariant_logp.py, docs/operators/batch-invariant-logp.md

Summary by CodeRabbit

  • New Features

    • Added a new selected-log-probability operator with support for multiple backends, including a high-performance CUDA option on compatible hardware.
    • Added a benchmark script to compare speed and memory usage across implementations.
  • Bug Fixes

    • Improved handling of ignored targets and input validation across implementations.
    • Added fallback behavior when the fastest backend is unavailable.
  • Documentation

    • Added user-facing docs, examples, benchmark results, and navigation links for the new operator.
  • Tests

    • Added extensive coverage for correctness, gradients, batch invariance, edge cases, and backend dispatch.

hihaluemen and others added 5 commits June 28, 2026 20:06
Implements batch_invariant_logp for selected-token log probabilities from materialized logits with row-local, batch-invariant semantics.

- PyTorch NativeBatchInvariantLogpOp: FP32 row-wise reference with ignore_index handling and target validation.

- Triton TritonBatchInvariantLogpOp: online-softmax forward with fixed vocab tiling and tile-wise backward using saved per-row lse.

- Registry dispatch, PyTorch/Triton tests, and operator docs.
@coderabbitai

coderabbitai Bot commented Jul 2, 2026

Copy link
Copy Markdown

Review Change Stack

📝 Walkthrough

Walkthrough

Adds a new batch_invariant_logp operator computing per-row selected log-probabilities with batch-invariance guarantees, implemented as native PyTorch, Triton, and CUDA SM90 TMA backends. Wires the operator into the kernel registry, build system, includes a benchmark script, documentation, and a large test suite.

Changes

Batch-invariant logp operator

Layer / File(s) Summary
Native PyTorch reference implementation
rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py
NativeBatchInvariantLogpOp computes selected log-probabilities via explicit FP32 logsumexp with shape validation and ignore_index masking.
Triton forward/backward kernels
rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py
Adds forward/backward Triton kernels with online log-sum-exp reduction and an autograd wrapper TritonBatchInvariantLogpOp with validation.
CUDA SM90 TMA kernel and bindings
csrc/cuda/batch_invariant_logp_kernel_sm90.cu, csrc/ops.cpp, rl_engine/_C.pyi, setup.py
Adds the TMA-based CUDA kernel, forward entrypoint, pybind and type-stub bindings, and adds the source to the SM90 build.
CUDA SM90 Python wrapper and fallback
rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py
Adds SM90 support detection, fallback to Triton/native, an autograd Function calling the compiled kernel, and BatchInvariantLogpSM90Op dispatch wrapper.
Kernel registry wiring
rl_engine/kernels/registry.py
Adds OpBackend entries and priority maps for cuda/rocm/cpu, plus hardware probing to prefer SM90 when available.
Test suite
tests/test_batch_invariant_logp.py
Adds correctness, batch-invariance, gradient, validation, fallback, and registry dispatch tests across all backends.
Benchmark script
benchmarks/benchmark_batch_invariant_logp.py
Adds a script comparing latency and peak VRAM across native/Triton/SM90 backends with CLI configuration.
Operator documentation
docs/operators/batch-invariant-logp.md, docs/operators/README.md, docs/.nav.yml
Adds documentation describing the operator contract, backends, benchmarks, and usage, linked in navigation.

Estimated code review effort: 4 (Complex) | ~60 minutes

Sequence Diagram(s)

sequenceDiagram
  participant Caller
  participant KernelRegistry
  participant BatchInvariantLogpSM90Op
  participant _BatchInvariantLogpSM90Function
  participant TritonBatchInvariantLogpOp
  participant NativeBatchInvariantLogpOp

  Caller->>KernelRegistry: get_op("batch_invariant_logp")
  KernelRegistry-->>Caller: selected backend op
  Caller->>BatchInvariantLogpSM90Op: __call__(logits, target_ids, ignore_index)
  BatchInvariantLogpSM90Op->>BatchInvariantLogpSM90Op: check shape, dtype, alignment
  alt SM90 supported
    BatchInvariantLogpSM90Op->>_BatchInvariantLogpSM90Function: apply()
    _BatchInvariantLogpSM90Function->>_BatchInvariantLogpSM90Function: call compiled SM90 kernel forward
    _BatchInvariantLogpSM90Function-->>BatchInvariantLogpSM90Op: logp, lse
  else unsupported input
    BatchInvariantLogpSM90Op->>TritonBatchInvariantLogpOp: apply() (fallback)
    alt Triton unavailable
      TritonBatchInvariantLogpOp->>NativeBatchInvariantLogpOp: apply() (fallback)
      NativeBatchInvariantLogpOp-->>BatchInvariantLogpSM90Op: logp
    else Triton available
      TritonBatchInvariantLogpOp-->>BatchInvariantLogpSM90Op: logp
    end
  end
  BatchInvariantLogpSM90Op-->>Caller: selected log-probability
Loading

Possibly related PRs

  • RL-Align/RL-Kernel#91: Both PRs adjust SM90/TMA kernel prioritization in rl_engine/kernels/registry.py based on compiled extension symbol presence, and share the same setup.py SM90 build-gating pattern.

Suggested labels: needs-gpu-ci

Suggested reviewers: inaniloquentee, Flink-ddd

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title is concise and accurately reflects the main change: a batch-invariant logprob CUDA backend.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/batch-invariant-logp-cuda
⚔️ Resolve merge conflicts
  • Resolve merge conflict in branch feat/batch-invariant-logp-cuda

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 9

🧹 Nitpick comments (3)
tests/test_batch_invariant_logp.py (2)

45-56: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Narrow the broad exception catch.

Catching bare Exception to detect kernel availability masks unrelated failures (e.g., a broken _C module import due to an ABI mismatch) as "SM90 not available," silently skipping tests instead of surfacing the real problem.

🔧 Suggested narrowing
     try:
         from rl_engine.kernels.ops.base import _C, _EXT_AVAILABLE
-    except Exception:
+    except ImportError:
         return False
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_batch_invariant_logp.py` around lines 45 - 56, The
_sm90_kernel_available helper is swallowing unrelated import failures by
catching a broad Exception when importing rl_engine.kernels.ops.base. Narrow
this to the specific import error(s) expected for missing kernel bindings, and
keep the availability check in _sm90_kernel_available focused on genuine “not
installed/not built” cases so ABI or module breakages still surface during
tests.

Source: Linters/SAST tools


831-1020: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Missing validation-error test coverage for the SM90 backend.

Every other backend has a dedicated validation test class (TestValidation for native at lines 294-324, TestTritonValidation/TestTritonCPUValidation at lines 788-824), but there is no equivalent for BatchInvariantLogpSM90Op (e.g., rejecting 1D logits, shape mismatches, or out-of-range targets). Given the PR objective explicitly calls out "expanded tests covering correctness, backward, ignore-index, alignment, and fallback behavior," validation parity across backends would close a gap.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_batch_invariant_logp.py` around lines 831 - 1020, Add a SM90
validation test class for BatchInvariantLogpSM90Op to match the existing backend
coverage in TestValidation and TestTritonValidation/TestTritonCPUValidation. In
tests/test_batch_invariant_logp.py, create cases that assert invalid inputs are
rejected for the BatchInvariantLogpSM90Op path, such as 1D logits, mismatched
logits/target shapes, and targets outside the vocabulary range. Keep the tests
grouped near the other TestCudaSM90* classes and use _get_op() so the new
coverage clearly targets the SM90 backend.
rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py (1)

22-35: 🩺 Stability & Availability | 🔵 Trivial | ⚡ Quick win

Narrow the fallback exception handling.

Only the import path should fall back for missing Triton. Catching every Exception around the backward kernel can mask real kernel failures and silently take the full [N, V] PyTorch fallback.

Suggested shape
-        try:
+        try:
             import triton
 
             from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import (
                 _BLOCK_V,
                 _batch_invariant_logp_bwd_kernel,
             )
+        except (ImportError, ModuleNotFoundError):  # pragma: no cover - Triton missing
+            ...
+        else:
 
             grid = (num_tokens, triton.cdiv(vocab_size, _BLOCK_V))
             _batch_invariant_logp_bwd_kernel[grid](
                 ...
             )
-        except Exception:  # pragma: no cover - Triton missing
+        # PyTorch fallback for the missing-Triton branch only.

Also applies to: 67-89

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py` around lines 22 -
35, The _fallback_op helper currently catches every Exception, which can hide
real Triton/kernel failures and incorrectly route to the PyTorch fallback.
Narrow the exception handling in batch_invariant_logp so only Triton
import/module-missing cases fall back to NativeBatchInvariantLogpOp, and let
actual kernel errors propagate; update the same logic wherever the fallback is
duplicated in the backward path referenced by TritonBatchInvariantLogpOp and
NativeBatchInvariantLogpOp.

Source: Linters/SAST tools

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@benchmarks/benchmark_batch_invariant_logp.py`:
- Around line 77-88: The VRAM helper currently hardcodes its warmup and
iteration counts, so CLI overrides are ignored for memory measurement. Update
`_peak_vram_gb` to accept the same `warmup` and `iters` inputs used by
`_time_ms`, and change the call sites in `benchmark_batch_invariant_logp` to
pass `args.warmup` and `args.iters` so timing and VRAM runs stay consistent.
- Around line 63-92: The timing and memory helpers are hardcoded to torch.cuda
even though run_benchmark supports xpu, so the benchmark will break on Intel
GPUs. Update _time_ms and _peak_vram_gb to dispatch through the active backend
in device_ctx.device_type, using torch.xpu for xpu and torch.cuda for cuda/hip,
including synchronize, Event creation, empty_cache, memory_allocated,
reset_peak_memory_stats, and max_memory_allocated calls.

In `@csrc/cuda/batch_invariant_logp_kernel_sm90.cu`:
- Around line 179-200: The SM90 entrypoint in batch_invariant_logp_sm90_forward
currently assumes the current CUDA context matches logits, and it also only
validates target shape, not device. Add a CUDA device guard based on logits
before any allocation or launch, and add a TORCH_CHECK that target is on the
same CUDA device as logits (and CUDA at all) alongside the existing logits
checks. Keep the validation and setup near batch_invariant_logp_sm90_forward so
direct calls through the exported binding are safe even when the Python wrapper
is bypassed.

In `@docs/operators/batch-invariant-logp.md`:
- Around line 85-100: The “Forward + backward” measurements in the
batch-invariant logp docs are not reproducible from the current benchmark path.
Update the benchmark source used by run_benchmark so it actually exercises and
times a backward pass (including VRAM tracking), or else remove/qualify this
table as coming from a different benchmark variant. Use the existing benchmark
identifiers in benchmark_batch_invariant_logp.py and its run_benchmark flow to
keep the documentation aligned with the script.

In `@rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py`:
- Around line 114-130: The SM90 wrapper currently disables target validation by
default in BatchInvariantLogp.__call__ and apply, which can allow unsafe CUDA
gathers for bad non-ignored targets. Update the default in these methods so
validate matches the native operator’s behavior (true by default), and keep an
explicit opt-out path for callers that already validated targets. Preserve the
existing ignore_index handling and ensure the wrapper forwards the validate flag
consistently through BatchInvariantLogp.
- Around line 12-19: The `_sm90_supported` guard only checks tensor device type
and alignment, so tensors from non-Hopper CUDA devices can still route into the
SM90 TMA path. Update `_sm90_supported` to also verify the input logits’ CUDA
capability (cc_major == 9) on the tensor’s current device, alongside the
existing dtype and stride-multiple-of-16 checks, so cached registry entries
correctly fall back for non-SM90 GPUs.

In `@rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py`:
- Around line 115-124: The shape validator in _validate_shapes currently allows
non-integral target_ids and empty vocabularies to slip through. Update
_validate_shapes in batch_invariant_logp.py to first reject target_ids unless
they are an integer/long type before any cast happens in apply(), and add an
explicit check that logits.size(-1) is greater than zero before proceeding. Keep
the existing logits/target_ids shape checks in place after these new
validations.

In `@rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py`:
- Around line 202-212: The Triton batch-invariant logp path defaults validation
off, which makes it inconsistent with the native backend and can hide invalid
target IDs. Update the default in the BatchInvariantLogp call/apply flow so
validation is enabled by default, matching the native implementation; adjust the
default value on the validate parameter in the relevant method(s) and ensure
BatchInvariantLogp.apply preserves that default behavior.
- Around line 219-245: `_BatchInvariantLogpFunction.apply` currently validates
shapes and index ranges but not the dtype of `target_ids`, so non-integer inputs
can be silently converted later. Add an explicit integer-type check in the
validation path of the batch invariant logp function before
`target_ids.reshape(-1)` is used, and raise a clear `ValueError` for float,
bool, complex, or other non-integer dtypes. Keep the existing shape and range
checks in place and ensure the new check is near the `validate` block in
`batch_invariant_logp.py`.

---

Nitpick comments:
In `@rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py`:
- Around line 22-35: The _fallback_op helper currently catches every Exception,
which can hide real Triton/kernel failures and incorrectly route to the PyTorch
fallback. Narrow the exception handling in batch_invariant_logp so only Triton
import/module-missing cases fall back to NativeBatchInvariantLogpOp, and let
actual kernel errors propagate; update the same logic wherever the fallback is
duplicated in the backward path referenced by TritonBatchInvariantLogpOp and
NativeBatchInvariantLogpOp.

In `@tests/test_batch_invariant_logp.py`:
- Around line 45-56: The _sm90_kernel_available helper is swallowing unrelated
import failures by catching a broad Exception when importing
rl_engine.kernels.ops.base. Narrow this to the specific import error(s) expected
for missing kernel bindings, and keep the availability check in
_sm90_kernel_available focused on genuine “not installed/not built” cases so ABI
or module breakages still surface during tests.
- Around line 831-1020: Add a SM90 validation test class for
BatchInvariantLogpSM90Op to match the existing backend coverage in
TestValidation and TestTritonValidation/TestTritonCPUValidation. In
tests/test_batch_invariant_logp.py, create cases that assert invalid inputs are
rejected for the BatchInvariantLogpSM90Op path, such as 1D logits, mismatched
logits/target shapes, and targets outside the vocabulary range. Keep the tests
grouped near the other TestCudaSM90* classes and use _get_op() so the new
coverage clearly targets the SM90 backend.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b1bd5092-d130-4cca-b342-e7dd0294cf08

📥 Commits

Reviewing files that changed from the base of the PR and between aee6d3b and 603a7e1.

📒 Files selected for processing (14)
  • .gitignore
  • benchmarks/benchmark_batch_invariant_logp.py
  • csrc/cuda/batch_invariant_logp_kernel_sm90.cu
  • csrc/ops.cpp
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/batch-invariant-logp.md
  • rl_engine/_C.pyi
  • rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py
  • rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py
  • rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py
  • rl_engine/kernels/registry.py
  • setup.py
  • tests/test_batch_invariant_logp.py

Comment on lines +63 to +92
def _time_ms(fn, warmup, iters):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(iters):
fn()
end.record()
torch.cuda.synchronize()
return start.elapsed_time(end) / iters


def _peak_vram_gb(fn, warmup=3, iters=5):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
torch.cuda.empty_cache()
baseline = torch.cuda.memory_allocated()
torch.cuda.reset_peak_memory_stats()
for _ in range(iters):
fn()
torch.cuda.synchronize()
return (torch.cuda.max_memory_allocated() - baseline) / (1024**3)


def run_benchmark(args):
if device_ctx.device_type not in ["cuda", "xpu", "hip"]:
raise RuntimeError("batch_invariant_logp benchmark requires a compatible GPU device.")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🌐 Web query:

Does torch.xpu support Event, synchronize, and memory stats APIs equivalent to torch.cuda for Intel GPUs?

💡 Result:

Yes, torch.xpu provides APIs that are functionally equivalent to the corresponding torch.cuda APIs for Intel GPUs [1][2][3]. Specifically, PyTorch implements the following support: - Synchronization: The torch.xpu.synchronize function is available to wait for all kernels in all streams on an XPU device to complete, mirroring the behavior of torch.cuda.synchronize [4][2][5]. - Events: The torch.xpu.Event class serves as a wrapper around an XPU event, allowing for synchronization markers and monitoring device progress, consistent with torch.cuda.Event [6][2][7]. - Memory Statistics: PyTorch provides a comprehensive set of memory management APIs under the torch.xpu.memory namespace that mirror torch.cuda memory functions [1][2]. These include: - torch.xpu.memory.memory_stats and memory_stats_as_nested_dict for allocator statistics [1][2][8]. - Metrics such as memory_allocated, memory_reserved, max_memory_allocated, and max_memory_reserved [1][2]. - Utilities like memory_snapshot, mem_get_info, and tools to manage memory fractions [1][2]. These XPU-specific APIs are designed to provide a consistent user experience when migrating code from CUDA to Intel GPU backends [9][3].

Citations:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect the benchmark file and nearby device handling.
git ls-files benchmarks/benchmark_batch_invariant_logp.py
echo "---"
cat -n benchmarks/benchmark_batch_invariant_logp.py | sed -n '1,220p'

echo "--- device_ctx references ---"
rg -n "device_ctx|torch\.cuda|torch\.xpu|torch\.hip|xpu" benchmarks/benchmark_batch_invariant_logp.py

Repository: RL-Align/RL-Kernel

Length of output: 7659


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Check PyTorch backend alias behavior from installed docs/source if available in repo.
python3 - <<'PY'
import importlib.util, sys
mods = ["torch"]
for m in mods:
    print(m, importlib.util.find_spec(m) is not None)
PY

# Try to inspect PyTorch API docs via Python introspection if torch is installed.
python3 - <<'PY'
try:
    import torch
    print("torch version:", torch.__version__)
    for name in ["cuda", "xpu"]:
        print(name, "present:", hasattr(torch, name))
        if hasattr(torch, name):
            mod = getattr(torch, name)
            for attr in ["synchronize", "Event", "memory_allocated", "empty_cache", "reset_peak_memory_stats", "max_memory_allocated"]:
                print(name, attr, hasattr(mod, attr))
except Exception as e:
    print("torch import failed:", repr(e))
PY

Repository: RL-Align/RL-Kernel

Length of output: 530


🌐 Web query:

PyTorch Intel XPU backend torch.cuda compatibility layer Event synchronize memory_allocated max_memory_allocated on XPU

💡 Result:

In PyTorch, the XPU backend provides a dedicated namespace (torch.xpu) that includes direct equivalents to many torch.cuda functions, rather than relying on a universal "compatibility layer" that automatically redirects calls [1][2]. Developers should use torch.xpu APIs explicitly when targeting Intel GPUs [3]. Key XPU equivalents for the requested CUDA functions are as follows: - Synchronization: Use torch.xpu.synchronize(device=None) to wait for all kernels in all streams on a specific XPU device to complete [4][5][2]. This mirrors the behavior of torch.cuda.synchronize(). - Memory Allocation Tracking: The torch.xpu namespace provides torch.xpu.memory.memory_allocated(device=None) and torch.xpu.memory.max_memory_allocated(device=None) [6][7][8]. These functions return the current and peak memory occupied by tensors (in bytes), respectively [6][8]. While there is no transparent automatic compatibility layer that maps torch.cuda to torch.xpu globally, modern development practices often involve using conditional logic or utility wrappers to check the active device and dispatch the appropriate calls (e.g., torch.cuda.synchronize() vs. torch.xpu.synchronize()) [3]. Some third-party frameworks (such as vLLM) have implemented specific internal monkey-patching or abstraction layers to map torch.cuda APIs to torch.xpu equivalents to support features like graph capture on XPU [9].

Citations:


🌐 Web query:

site:pytorch.org torch.cuda xpu compatibility alias torch.cuda.Event torch.cuda.synchronize xpu

💡 Result:

In PyTorch, the torch.xpu API provides synchronization and event management functionality specifically for Intel GPUs, mirroring the patterns found in torch.cuda [1][2]. For migration purposes, you should replace references to torch.cuda with torch.xpu [1]. The equivalent components are: - Event Management: Use torch.xpu.Event instead of torch.cuda.Event [3][4][5]. Like its CUDA counterpart, torch.xpu.Event is a synchronization marker used to monitor device progress, measure timing (if enabled), and synchronize XPU streams [3][6]. It supports methods such as record, synchronize, wait, and query [6][5]. - Synchronization: Use torch.xpu.synchronize instead of torch.cuda.synchronize [7][8]. This function waits for all kernels in all streams on the specified XPU device to complete [8]. Additionally, individual streams have their own synchronization methods via torch.xpu.Stream, such as stream.synchronize [9]. For further details on migrating workflows, the official PyTorch documentation on XPU semantics is the recommended resource [1][7].

Citations:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Look for other benchmarks or utilities that handle xpu/hip via torch.cuda or torch.xpu.
rg -n "torch\.xpu|device_type.*xpu|device_type.*hip|torch\.cuda\.(Event|synchronize|memory_allocated|max_memory_allocated|empty_cache|reset_peak_memory_stats)" .

Repository: RL-Align/RL-Kernel

Length of output: 7573


Dispatch timing/memory helpers by backend
run_benchmark allows "xpu", but _time_ms and _peak_vram_gb still hardcode torch.cuda. Use torch.xpu for Intel GPUs and keep torch.cuda for CUDA/ROCm so the benchmark works on every allowed device.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/benchmark_batch_invariant_logp.py` around lines 63 - 92, The
timing and memory helpers are hardcoded to torch.cuda even though run_benchmark
supports xpu, so the benchmark will break on Intel GPUs. Update _time_ms and
_peak_vram_gb to dispatch through the active backend in device_ctx.device_type,
using torch.xpu for xpu and torch.cuda for cuda/hip, including synchronize,
Event creation, empty_cache, memory_allocated, reset_peak_memory_stats, and
max_memory_allocated calls.

Comment on lines +77 to +88
def _peak_vram_gb(fn, warmup=3, iters=5):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
torch.cuda.empty_cache()
baseline = torch.cuda.memory_allocated()
torch.cuda.reset_peak_memory_stats()
for _ in range(iters):
fn()
torch.cuda.synchronize()
return (torch.cuda.max_memory_allocated() - baseline) / (1024**3)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

VRAM measurement ignores CLI --warmup/--iters overrides.

_peak_vram_gb hardcodes warmup=3, iters=5 (Line 77) instead of accepting args.warmup/args.iters like _time_ms does, so --warmup/--iters only affect timing, not the memory measurement, even though it is called right alongside the timing runs (Lines 115-116, 128).

Proposed fix
-def _peak_vram_gb(fn, warmup=3, iters=5):
+def _peak_vram_gb(fn, warmup, iters):
     for _ in range(warmup):
         fn()
     torch.cuda.synchronize()
     torch.cuda.empty_cache()
     baseline = torch.cuda.memory_allocated()
     torch.cuda.reset_peak_memory_stats()
     for _ in range(iters):
         fn()
     torch.cuda.synchronize()
     return (torch.cuda.max_memory_allocated() - baseline) / (1024**3)

And update call sites to pass args.warmup, args.iters.

Also applies to: 113-116, 127-128

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/benchmark_batch_invariant_logp.py` around lines 77 - 88, The VRAM
helper currently hardcodes its warmup and iteration counts, so CLI overrides are
ignored for memory measurement. Update `_peak_vram_gb` to accept the same
`warmup` and `iters` inputs used by `_time_ms`, and change the call sites in
`benchmark_batch_invariant_logp` to pass `args.warmup` and `args.iters` so
timing and VRAM runs stay consistent.

Comment on lines +179 to +200
std::vector<torch::Tensor> batch_invariant_logp_sm90_forward(torch::Tensor logits,
torch::Tensor target,
int64_t ignore_index) {
TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
TORCH_CHECK(logits.dim() == 2, "logits must be 2-D [N, V]");
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
const int N = logits.size(0);
const int V = logits.size(1);
TORCH_CHECK(target.numel() == N, "target must have one id per row: expected ", N, ", got ",
target.numel());

// TMA requires the global row stride (V * elem_size) to be 16-byte aligned.
const int elem_size = static_cast<int>(logits.element_size());
TORCH_CHECK((static_cast<int64_t>(V) * elem_size) % 16 == 0,
"batch_invariant_logp_sm90 requires the vocab row stride (V * elem_size) to be a "
"multiple of 16 bytes; got V=",
V, ", elem_size=", elem_size);

auto opts_f = logits.options().dtype(torch::kFloat);
auto logp = torch::empty({N}, opts_f);
auto lse = torch::empty({N}, opts_f);
auto target_i = target.to(torch::kInt32).contiguous();

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify existing extension conventions for CUDA device guards / launch checks.
rg -n -C2 'CUDAGuard|OptionalCUDAGuard|C10_CUDA_KERNEL_LAUNCH_CHECK|CUDA_KERNEL_LAUNCH_CHECK' csrc

Repository: RL-Align/RL-Kernel

Length of output: 565


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect the SM90 batch-invariant logp kernel entrypoint and nearby CUDA launch code.
FILE="csrc/cuda/batch_invariant_logp_kernel_sm90.cu"
wc -l "$FILE"
sed -n '1,260p' "$FILE"

# Find the Python binding/callers for this kernel.
rg -n -C3 'batch_invariant_logp_sm90|_C\.batch_invariant_logp_sm90|batch_invariant_logp_sm90_forward' .

Repository: RL-Align/RL-Kernel

Length of output: 22607


Guard the SM90 entrypoint with a CUDA device guard and same-device target check.

logits can be on a different CUDA device than the current context, so this launch can hit the wrong GPU. The exported binding should set the device from logits and reject non-CUDA or cross-device target tensors; the Python wrapper already normalizes target, but direct _C.batch_invariant_logp_sm90(...) calls do not.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@csrc/cuda/batch_invariant_logp_kernel_sm90.cu` around lines 179 - 200, The
SM90 entrypoint in batch_invariant_logp_sm90_forward currently assumes the
current CUDA context matches logits, and it also only validates target shape,
not device. Add a CUDA device guard based on logits before any allocation or
launch, and add a TORCH_CHECK that target is on the same CUDA device as logits
(and CUDA at all) alongside the existing logits checks. Keep the validation and
setup near batch_invariant_logp_sm90_forward so direct calls through the
exported binding are safe even when the Python wrapper is bypassed.

Comment on lines +85 to +100
**Forward + backward**

| shape (N x V) | native ms | triton ms | cuda ms | cuda vs native | cuda vs triton | native MB | triton MB | cuda MB |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 4096 x 32768 | 3.549 | 0.463 | 0.407 | 8.7x | 1.14x | 1792 | 512 | 512 |
| 4096 x 128256 | 13.173 | 1.750 | 1.510 | 8.7x | 1.16x | 7014 | 2004 | 2004 |
| 4096 x 151936 | 15.632 | 2.072 | 1.788 | 8.7x | 1.16x | 8310 | 2376 | 2376 |
| 8192 x 128256 | 26.211 | 3.416 | 2.955 | 8.9x | 1.16x | 14028 | 4008 | 4008 |

- Forward: ~1.7x vs Triton, ~15x vs native, with ~0 extra VRAM — the vocab is
reduced to per-row scalars, so no `[N, V]` intermediate is materialized.
- Forward + backward: ~1.15x vs Triton, ~8.8x vs native, with memory equal to
Triton. The backward's `[N, V]` cost is `grad_logits` itself (one gradient per
input logit, unavoidable for any backend); the streamed backends avoid native's
extra `[N, V]` `softmax` / `log_softmax` intermediates by recomputing from the
saved per-row `lse`.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

"Forward + backward" table isn't reproducible from the benchmark script.

The doc claims measured "Forward + backward" latency/VRAM numbers (Lines 85-92), but benchmarks/benchmark_batch_invariant_logp.py's run_benchmark only exercises the forward pass under torch.no_grad() — there is no .backward() call or backward timing/VRAM path in the script. Either the script needs a backward benchmark added, or this table should be removed/clarified as sourced from an out-of-tree variant.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/operators/batch-invariant-logp.md` around lines 85 - 100, The “Forward +
backward” measurements in the batch-invariant logp docs are not reproducible
from the current benchmark path. Update the benchmark source used by
run_benchmark so it actually exercises and times a backward pass (including VRAM
tracking), or else remove/qualify this table as coming from a different
benchmark variant. Use the existing benchmark identifiers in
benchmark_batch_invariant_logp.py and its run_benchmark flow to keep the
documentation aligned with the script.

Comment on lines +12 to +19
def _sm90_supported(logits: torch.Tensor) -> bool:
"""Whether the TMA forward can run these logits directly.
bf16/fp32 only, and the TMA descriptor needs the vocab row stride
(``V * element_size``) to be a multiple of 16 bytes.
"""
if not logits.is_cuda or logits.dtype not in (torch.bfloat16, torch.float32):
return False
return (logits.size(-1) * logits.element_size()) % 16 == 0

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Include the logits device capability in _sm90_supported.

The wrapper can be used directly, and cached registry instances can receive tensors from a different CUDA device than the one probed during registry initialization. Without a per-input cc_major == 9 check here, aligned fp32/bf16 tensors on non-Hopper GPUs can still enter the SM90 TMA kernel instead of falling back.

Suggested fix
 def _sm90_supported(logits: torch.Tensor) -> bool:
@@
     if not logits.is_cuda or logits.dtype not in (torch.bfloat16, torch.float32):
         return False
+    cc_major, _ = torch.cuda.get_device_capability(logits.device)
+    if cc_major != 9:
+        return False
     return (logits.size(-1) * logits.element_size()) % 16 == 0
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _sm90_supported(logits: torch.Tensor) -> bool:
"""Whether the TMA forward can run these logits directly.
bf16/fp32 only, and the TMA descriptor needs the vocab row stride
(``V * element_size``) to be a multiple of 16 bytes.
"""
if not logits.is_cuda or logits.dtype not in (torch.bfloat16, torch.float32):
return False
return (logits.size(-1) * logits.element_size()) % 16 == 0
def _sm90_supported(logits: torch.Tensor) -> bool:
"""Whether the TMA forward can run these logits directly.
bf16/fp32 only, and the TMA descriptor needs the vocab row stride
(``V * element_size``) to be a multiple of 16 bytes.
"""
if not logits.is_cuda or logits.dtype not in (torch.bfloat16, torch.float32):
return False
cc_major, _ = torch.cuda.get_device_capability(logits.device)
if cc_major != 9:
return False
return (logits.size(-1) * logits.element_size()) % 16 == 0
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py` around lines 12 -
19, The `_sm90_supported` guard only checks tensor device type and alignment, so
tensors from non-Hopper CUDA devices can still route into the SM90 TMA path.
Update `_sm90_supported` to also verify the input logits’ CUDA capability
(cc_major == 9) on the tensor’s current device, alongside the existing dtype and
stride-multiple-of-16 checks, so cached registry entries correctly fall back for
non-SM90 GPUs.

Comment on lines +114 to +130
def __call__(
self,
logits: torch.Tensor,
target_ids: torch.Tensor,
ignore_index: int = -100,
*,
validate: bool = False,
) -> torch.Tensor:
return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate)

def apply(
self,
logits: torch.Tensor,
target_ids: torch.Tensor,
ignore_index: int = -100,
*,
validate: bool = False,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔒 Security & Privacy | 🟠 Major | ⚡ Quick win

Keep target validation on by default for the SM90 wrapper.

The native operator defaults to validate=True, but this wrapper defaults to False before calling an unchecked CUDA gather. A bad non-ignored target can read outside [0, V) by default; make opt-out explicit for callers that have already validated targets.

Suggested fix
     def __call__(
@@
-        validate: bool = False,
+        validate: bool = True,
@@
     def apply(
@@
-        validate: bool = False,
+        validate: bool = True,

Also applies to: 145-157

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py` around lines 114 -
130, The SM90 wrapper currently disables target validation by default in
BatchInvariantLogp.__call__ and apply, which can allow unsafe CUDA gathers for
bad non-ignored targets. Update the default in these methods so validate matches
the native operator’s behavior (true by default), and keep an explicit opt-out
path for callers that already validated targets. Preserve the existing
ignore_index handling and ensure the wrapper forwards the validate flag
consistently through BatchInvariantLogp.

Comment on lines +115 to +124
def _validate_shapes(logits: torch.Tensor, target_ids: torch.Tensor) -> None:
if logits.dim() < 2:
raise ValueError(
f"logits must be at least 2-D ([*lead, V]), got shape {tuple(logits.shape)}"
)
if logits.shape[:-1] != target_ids.shape:
raise ValueError(
f"logits leading shape {tuple(logits.shape[:-1])} must match "
f"target_ids shape {tuple(target_ids.shape)}"
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Locate the file and inspect the relevant section plus nearby usage.
git ls-files 'rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py'
echo '---'
wc -l rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py
echo '--- lines 1-220 ---'
cat -n rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py | sed -n '1,220p'

Repository: RL-Align/RL-Kernel

Length of output: 5615


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Probe PyTorch behavior for the edge cases mentioned in the review comment.
python3 - <<'PY'
import torch

cases = [
    ("bool target", torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([True, False])),
    ("float target", torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([0.0, 1.0])),
    ("int target", torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([0, 1])),
]

for name, logits, target in cases:
    print(f"\nCASE: {name}")
    try:
        gathered = logits.log_softmax(-1).gather(-1, target.long().unsqueeze(-1)).squeeze(-1)
        print("gather with target.long():", gathered)
    except Exception as e:
        print("gather error:", type(e).__name__, e)

# Empty vocab edge
logits = torch.empty(2, 0)
target = torch.empty(2, dtype=torch.long)
print("\nCASE: empty vocab")
try:
    x = logits.log_softmax(-1)
    print("log_softmax:", x)
except Exception as e:
    print("log_softmax error:", type(e).__name__, e)

try:
    x = logits.gather(-1, target.unsqueeze(-1))
    print("gather:", x)
except Exception as e:
    print("gather error:", type(e).__name__, e)
PY

Repository: RL-Align/RL-Kernel

Length of output: 519


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Find any direct call sites or existing validation that might already reject these inputs.
rg -n "target_ids.*long|\.long\(\).*target_ids|_validate_shapes\(" rl_engine/kernels/ops/pytorch/loss -S

Repository: RL-Align/RL-Kernel

Length of output: 694


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo '== file =='
wc -l rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py
echo '--- relevant section ---'
sed -n '1,220p' rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py | nl -ba | sed -n '1,220p'

echo '== search for related casting/validation =='
rg -n "_validate_shapes|target_ids.*long|\\.long\\(\\)|gather\\(" rl_engine/kernels/ops/pytorch/loss -S || true

echo '== torch edge-case probe =='
python3 - <<'PY'
import torch

def probe(logits, target_ids, label):
    print(f"\nCASE: {label}")
    print("logits.shape=", tuple(logits.shape), "target_ids.shape=", tuple(target_ids.shape), "target_ids.dtype=", target_ids.dtype)
    try:
        t = target_ids.long()
        print("target_ids.long() dtype=", t.dtype, "values=", t)
    except Exception as e:
        print("target_ids.long() error:", type(e).__name__, e)
        return
    try:
        ls = logits.log_softmax(-1)
        print("log_softmax shape=", tuple(ls.shape))
    except Exception as e:
        print("log_softmax error:", type(e).__name__, e)
        ls = None
    if ls is not None:
        try:
            gathered = ls.gather(-1, t.unsqueeze(-1)).squeeze(-1)
            print("gather result shape=", tuple(gathered.shape), "values=", gathered)
        except Exception as e:
            print("gather error:", type(e).__name__, e)

probe(torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([True, False]), "bool target")
probe(torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([0.0, 1.0]), "float target")
probe(torch.empty(2, 0), torch.empty(2, dtype=torch.long), "empty vocab")
PY

Repository: RL-Align/RL-Kernel

Length of output: 295


Reject non-integral targets and empty vocab. target_ids is cast to long in apply(), so float/bool inputs are silently truncated before any range checks; validate the dtype first. Also reject logits.size(-1) == 0 up front, since the later reduction/indexing path fails on empty vocab.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py` around lines 115
- 124, The shape validator in _validate_shapes currently allows non-integral
target_ids and empty vocabularies to slip through. Update _validate_shapes in
batch_invariant_logp.py to first reject target_ids unless they are an
integer/long type before any cast happens in apply(), and add an explicit check
that logits.size(-1) is greater than zero before proceeding. Keep the existing
logits/target_ids shape checks in place after these new validations.

Comment on lines +202 to +212
validate: bool = False,
) -> torch.Tensor:
return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate)

def apply(
self,
logits: torch.Tensor,
target_ids: torch.Tensor,
ignore_index: int = -100,
*,
validate: bool = False,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Keep validation enabled by default for backend consistency.

Native defaults to validate=True, but Triton defaults to False; invalid negative/out-of-range targets then produce silent bogus logprobs instead of the documented validation failure.

Proposed default alignment
         ignore_index: int = -100,
         *,
-        validate: bool = False,
+        validate: bool = True,
@@
         ignore_index: int = -100,
         *,
-        validate: bool = False,
+        validate: bool = True,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
validate: bool = False,
) -> torch.Tensor:
return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate)
def apply(
self,
logits: torch.Tensor,
target_ids: torch.Tensor,
ignore_index: int = -100,
*,
validate: bool = False,
validate: bool = True,
) -> torch.Tensor:
return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate)
def apply(
self,
logits: torch.Tensor,
target_ids: torch.Tensor,
ignore_index: int = -100,
*,
validate: bool = True,
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py` around lines 202 -
212, The Triton batch-invariant logp path defaults validation off, which makes
it inconsistent with the native backend and can hide invalid target IDs. Update
the default in the BatchInvariantLogp call/apply flow so validation is enabled
by default, matching the native implementation; adjust the default value on the
validate parameter in the relevant method(s) and ensure BatchInvariantLogp.apply
preserves that default behavior.

Comment on lines +219 to +245
if logits.dim() < 2:
raise ValueError(
f"logits must be at least 2-D ([*lead, V]), got shape "
f"{tuple(logits.shape)}"
)
if logits.shape[:-1] != target_ids.shape:
raise ValueError(
f"logits leading shape {tuple(logits.shape[:-1])} must match "
f"target_ids shape {tuple(target_ids.shape)}"
)

if validate:
vocab_size = logits.size(-1)
target_flat = target_ids.reshape(-1)
valid_targets = target_flat[target_flat != ignore_index]
if valid_targets.numel() > 0 and (
(valid_targets < 0).any() or (valid_targets >= vocab_size).any()
):
bad = valid_targets[
(valid_targets < 0) | (valid_targets >= vocab_size)
]
raise ValueError(
f"target_ids contains values outside [0, {vocab_size}): "
f"{bad.tolist()}"
)

return _BatchInvariantLogpFunction.apply(logits, target_ids, ignore_index)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Map the file and nearby symbols first
ast-grep outline rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py --view expanded

# Inspect the relevant section around the reported lines and the cast mentioned in the comment
cat -n rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py | sed -n '100,170p'
cat -n rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py | sed -n '180,260p'

# Search for any existing dtype / empty-vocab validation in this module and adjacent loss implementations
rg -n "target_ids.*dtype|bool|is_floating_point|is_complex|vocab|empty|size\\(-1\\)|ignore_index|ValueError\\(" rl_engine/kernels/ops/triton/loss -g '*.py'

Repository: RL-Align/RL-Kernel

Length of output: 15041


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect the sibling implementation that already performs target validation.
cat -n rl_engine/kernels/ops/triton/loss/linear_logp.py | sed -n '170,215p'

# Check whether this module already has any call sites that rely on non-integer targets or empty vocab.
rg -n "TritonBatchInvariantLogpOp|batch_invariant_logp|linear_logp|target_ids" rl_engine -g '*.py'

Repository: RL-Align/RL-Kernel

Length of output: 15615


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Read the native and CUDA batch-invariant logp implementations around validation.
cat -n rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py | sed -n '1,140p'
echo '---'
cat -n rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py | sed -n '1,190p'

# Probe the Triton wrapper semantics in code only: locate the exact cast and any empty-vocab assumptions.
python3 - <<'PY'
from pathlib import Path
p = Path("rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py")
text = p.read_text()
for needle in [
    "target_ids.reshape(-1).to(",
    "vocab_size = logits.size(-1)",
    "triton.cdiv(vocab_size, _BLOCK_V)",
    "if validate:",
]:
    print(f"\n=== {needle} ===")
    idx = text.find(needle)
    print(text[max(0, idx-200):idx+500] if idx != -1 else "not found")
PY

Repository: RL-Align/RL-Kernel

Length of output: 15373


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Read the Triton forward kernel and any tests covering batch_invariant_logp edge cases.
cat -n rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py | sed -n '1,120p'
echo '---'
fd -a 'batch_invariant_logp' . | sed -n '1,50p'
echo '---'
rg -n "batch_invariant_logp|TritonBatchInvariantLogpOp|NativeBatchInvariantLogpOp|empty vocab|dtype|float|bool" . -g '*test*.py' -g '*tests*.py'

Repository: RL-Align/RL-Kernel

Length of output: 32428


🏁 Script executed:

#!/bin/bash
set -euo pipefail

python3 - <<'PY'
import torch

samples = [
    torch.tensor([1.2, -3.7, 4.0], dtype=torch.float32),
    torch.tensor([True, False, True], dtype=torch.bool),
    torch.tensor([1+2j, 3+0j], dtype=torch.complex64),
]
for t in samples:
    print("src", t.dtype, t.tolist())
    try:
        out = t.to(dtype=torch.int64)
        print("dst", out.dtype, out.tolist())
    except Exception as e:
        print("err", type(e).__name__, str(e))

print("--- empty vocab shape probe ---")
logits = torch.empty(2, 0, dtype=torch.float32)
target = torch.zeros(2, dtype=torch.long)
print("logits shape", tuple(logits.shape), "target shape", tuple(target.shape))
try:
    row_max = logits.reshape(-1, logits.size(-1)).float().max(dim=-1).values
    print("native row_max", row_max)
except Exception as e:
    print("native err", type(e).__name__, str(e))
PY

# Check whether tests already cover empty-vocab or dtype coercion.
rg -n "empty vocab|vocab.*0|dtype.*target|float.*target|bool.*target|complex.*target|target_ids.*float|target_ids.*bool" tests/test_batch_invariant_logp.py rl_engine/kernels/ops -g '*.py'

Repository: RL-Align/RL-Kernel

Length of output: 2205


Require an integer dtype check for target_ids.
target_ids.reshape(-1).to(dtype=torch.int64) silently truncates floats and accepts bool/complex inputs, so invalid targets can become wrong indices instead of failing fast.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py` around lines 219 -
245, `_BatchInvariantLogpFunction.apply` currently validates shapes and index
ranges but not the dtype of `target_ids`, so non-integer inputs can be silently
converted later. Add an explicit integer-type check in the validation path of
the batch invariant logp function before `target_ids.reshape(-1)` is used, and
raise a clear `ValueError` for float, bool, complex, or other non-integer
dtypes. Keep the existing shape and range checks in place and ensure the new
check is near the `validate` block in `batch_invariant_logp.py`.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants