[WS1][kernels] Batch-invariant logprob (CUDA)#204
Conversation
Implements batch_invariant_logp for selected-token log probabilities from materialized logits with row-local, batch-invariant semantics. - PyTorch NativeBatchInvariantLogpOp: FP32 row-wise reference with ignore_index handling and target validation. - Triton TritonBatchInvariantLogpOp: online-softmax forward with fixed vocab tiling and tile-wise backward using saved per-row lse. - Registry dispatch, PyTorch/Triton tests, and operator docs.
📝 WalkthroughWalkthroughAdds a new ChangesBatch-invariant logp operator
Estimated code review effort: 4 (Complex) | ~60 minutes Sequence Diagram(s)sequenceDiagram
participant Caller
participant KernelRegistry
participant BatchInvariantLogpSM90Op
participant _BatchInvariantLogpSM90Function
participant TritonBatchInvariantLogpOp
participant NativeBatchInvariantLogpOp
Caller->>KernelRegistry: get_op("batch_invariant_logp")
KernelRegistry-->>Caller: selected backend op
Caller->>BatchInvariantLogpSM90Op: __call__(logits, target_ids, ignore_index)
BatchInvariantLogpSM90Op->>BatchInvariantLogpSM90Op: check shape, dtype, alignment
alt SM90 supported
BatchInvariantLogpSM90Op->>_BatchInvariantLogpSM90Function: apply()
_BatchInvariantLogpSM90Function->>_BatchInvariantLogpSM90Function: call compiled SM90 kernel forward
_BatchInvariantLogpSM90Function-->>BatchInvariantLogpSM90Op: logp, lse
else unsupported input
BatchInvariantLogpSM90Op->>TritonBatchInvariantLogpOp: apply() (fallback)
alt Triton unavailable
TritonBatchInvariantLogpOp->>NativeBatchInvariantLogpOp: apply() (fallback)
NativeBatchInvariantLogpOp-->>BatchInvariantLogpSM90Op: logp
else Triton available
TritonBatchInvariantLogpOp-->>BatchInvariantLogpSM90Op: logp
end
end
BatchInvariantLogpSM90Op-->>Caller: selected log-probability
Possibly related PRs
Suggested labels: Suggested reviewers: 🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
⚔️ Resolve merge conflicts
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 9
🧹 Nitpick comments (3)
tests/test_batch_invariant_logp.py (2)
45-56: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low valueNarrow the broad exception catch.
Catching bare
Exceptionto detect kernel availability masks unrelated failures (e.g., a broken_Cmodule import due to an ABI mismatch) as "SM90 not available," silently skipping tests instead of surfacing the real problem.🔧 Suggested narrowing
try: from rl_engine.kernels.ops.base import _C, _EXT_AVAILABLE - except Exception: + except ImportError: return False🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_batch_invariant_logp.py` around lines 45 - 56, The _sm90_kernel_available helper is swallowing unrelated import failures by catching a broad Exception when importing rl_engine.kernels.ops.base. Narrow this to the specific import error(s) expected for missing kernel bindings, and keep the availability check in _sm90_kernel_available focused on genuine “not installed/not built” cases so ABI or module breakages still surface during tests.Source: Linters/SAST tools
831-1020: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winMissing validation-error test coverage for the SM90 backend.
Every other backend has a dedicated validation test class (
TestValidationfor native at lines 294-324,TestTritonValidation/TestTritonCPUValidationat lines 788-824), but there is no equivalent forBatchInvariantLogpSM90Op(e.g., rejecting 1D logits, shape mismatches, or out-of-range targets). Given the PR objective explicitly calls out "expanded tests covering correctness, backward, ignore-index, alignment, and fallback behavior," validation parity across backends would close a gap.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_batch_invariant_logp.py` around lines 831 - 1020, Add a SM90 validation test class for BatchInvariantLogpSM90Op to match the existing backend coverage in TestValidation and TestTritonValidation/TestTritonCPUValidation. In tests/test_batch_invariant_logp.py, create cases that assert invalid inputs are rejected for the BatchInvariantLogpSM90Op path, such as 1D logits, mismatched logits/target shapes, and targets outside the vocabulary range. Keep the tests grouped near the other TestCudaSM90* classes and use _get_op() so the new coverage clearly targets the SM90 backend.rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py (1)
22-35: 🩺 Stability & Availability | 🔵 Trivial | ⚡ Quick winNarrow the fallback exception handling.
Only the import path should fall back for missing Triton. Catching every
Exceptionaround the backward kernel can mask real kernel failures and silently take the full[N, V]PyTorch fallback.Suggested shape
- try: + try: import triton from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( _BLOCK_V, _batch_invariant_logp_bwd_kernel, ) + except (ImportError, ModuleNotFoundError): # pragma: no cover - Triton missing + ... + else: grid = (num_tokens, triton.cdiv(vocab_size, _BLOCK_V)) _batch_invariant_logp_bwd_kernel[grid]( ... ) - except Exception: # pragma: no cover - Triton missing + # PyTorch fallback for the missing-Triton branch only.Also applies to: 67-89
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py` around lines 22 - 35, The _fallback_op helper currently catches every Exception, which can hide real Triton/kernel failures and incorrectly route to the PyTorch fallback. Narrow the exception handling in batch_invariant_logp so only Triton import/module-missing cases fall back to NativeBatchInvariantLogpOp, and let actual kernel errors propagate; update the same logic wherever the fallback is duplicated in the backward path referenced by TritonBatchInvariantLogpOp and NativeBatchInvariantLogpOp.Source: Linters/SAST tools
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@benchmarks/benchmark_batch_invariant_logp.py`:
- Around line 77-88: The VRAM helper currently hardcodes its warmup and
iteration counts, so CLI overrides are ignored for memory measurement. Update
`_peak_vram_gb` to accept the same `warmup` and `iters` inputs used by
`_time_ms`, and change the call sites in `benchmark_batch_invariant_logp` to
pass `args.warmup` and `args.iters` so timing and VRAM runs stay consistent.
- Around line 63-92: The timing and memory helpers are hardcoded to torch.cuda
even though run_benchmark supports xpu, so the benchmark will break on Intel
GPUs. Update _time_ms and _peak_vram_gb to dispatch through the active backend
in device_ctx.device_type, using torch.xpu for xpu and torch.cuda for cuda/hip,
including synchronize, Event creation, empty_cache, memory_allocated,
reset_peak_memory_stats, and max_memory_allocated calls.
In `@csrc/cuda/batch_invariant_logp_kernel_sm90.cu`:
- Around line 179-200: The SM90 entrypoint in batch_invariant_logp_sm90_forward
currently assumes the current CUDA context matches logits, and it also only
validates target shape, not device. Add a CUDA device guard based on logits
before any allocation or launch, and add a TORCH_CHECK that target is on the
same CUDA device as logits (and CUDA at all) alongside the existing logits
checks. Keep the validation and setup near batch_invariant_logp_sm90_forward so
direct calls through the exported binding are safe even when the Python wrapper
is bypassed.
In `@docs/operators/batch-invariant-logp.md`:
- Around line 85-100: The “Forward + backward” measurements in the
batch-invariant logp docs are not reproducible from the current benchmark path.
Update the benchmark source used by run_benchmark so it actually exercises and
times a backward pass (including VRAM tracking), or else remove/qualify this
table as coming from a different benchmark variant. Use the existing benchmark
identifiers in benchmark_batch_invariant_logp.py and its run_benchmark flow to
keep the documentation aligned with the script.
In `@rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py`:
- Around line 114-130: The SM90 wrapper currently disables target validation by
default in BatchInvariantLogp.__call__ and apply, which can allow unsafe CUDA
gathers for bad non-ignored targets. Update the default in these methods so
validate matches the native operator’s behavior (true by default), and keep an
explicit opt-out path for callers that already validated targets. Preserve the
existing ignore_index handling and ensure the wrapper forwards the validate flag
consistently through BatchInvariantLogp.
- Around line 12-19: The `_sm90_supported` guard only checks tensor device type
and alignment, so tensors from non-Hopper CUDA devices can still route into the
SM90 TMA path. Update `_sm90_supported` to also verify the input logits’ CUDA
capability (cc_major == 9) on the tensor’s current device, alongside the
existing dtype and stride-multiple-of-16 checks, so cached registry entries
correctly fall back for non-SM90 GPUs.
In `@rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py`:
- Around line 115-124: The shape validator in _validate_shapes currently allows
non-integral target_ids and empty vocabularies to slip through. Update
_validate_shapes in batch_invariant_logp.py to first reject target_ids unless
they are an integer/long type before any cast happens in apply(), and add an
explicit check that logits.size(-1) is greater than zero before proceeding. Keep
the existing logits/target_ids shape checks in place after these new
validations.
In `@rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py`:
- Around line 202-212: The Triton batch-invariant logp path defaults validation
off, which makes it inconsistent with the native backend and can hide invalid
target IDs. Update the default in the BatchInvariantLogp call/apply flow so
validation is enabled by default, matching the native implementation; adjust the
default value on the validate parameter in the relevant method(s) and ensure
BatchInvariantLogp.apply preserves that default behavior.
- Around line 219-245: `_BatchInvariantLogpFunction.apply` currently validates
shapes and index ranges but not the dtype of `target_ids`, so non-integer inputs
can be silently converted later. Add an explicit integer-type check in the
validation path of the batch invariant logp function before
`target_ids.reshape(-1)` is used, and raise a clear `ValueError` for float,
bool, complex, or other non-integer dtypes. Keep the existing shape and range
checks in place and ensure the new check is near the `validate` block in
`batch_invariant_logp.py`.
---
Nitpick comments:
In `@rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py`:
- Around line 22-35: The _fallback_op helper currently catches every Exception,
which can hide real Triton/kernel failures and incorrectly route to the PyTorch
fallback. Narrow the exception handling in batch_invariant_logp so only Triton
import/module-missing cases fall back to NativeBatchInvariantLogpOp, and let
actual kernel errors propagate; update the same logic wherever the fallback is
duplicated in the backward path referenced by TritonBatchInvariantLogpOp and
NativeBatchInvariantLogpOp.
In `@tests/test_batch_invariant_logp.py`:
- Around line 45-56: The _sm90_kernel_available helper is swallowing unrelated
import failures by catching a broad Exception when importing
rl_engine.kernels.ops.base. Narrow this to the specific import error(s) expected
for missing kernel bindings, and keep the availability check in
_sm90_kernel_available focused on genuine “not installed/not built” cases so ABI
or module breakages still surface during tests.
- Around line 831-1020: Add a SM90 validation test class for
BatchInvariantLogpSM90Op to match the existing backend coverage in
TestValidation and TestTritonValidation/TestTritonCPUValidation. In
tests/test_batch_invariant_logp.py, create cases that assert invalid inputs are
rejected for the BatchInvariantLogpSM90Op path, such as 1D logits, mismatched
logits/target shapes, and targets outside the vocabulary range. Keep the tests
grouped near the other TestCudaSM90* classes and use _get_op() so the new
coverage clearly targets the SM90 backend.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b1bd5092-d130-4cca-b342-e7dd0294cf08
📒 Files selected for processing (14)
.gitignorebenchmarks/benchmark_batch_invariant_logp.pycsrc/cuda/batch_invariant_logp_kernel_sm90.cucsrc/ops.cppdocs/.nav.ymldocs/operators/README.mddocs/operators/batch-invariant-logp.mdrl_engine/_C.pyirl_engine/kernels/ops/cuda/loss/batch_invariant_logp.pyrl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.pyrl_engine/kernels/ops/triton/loss/batch_invariant_logp.pyrl_engine/kernels/registry.pysetup.pytests/test_batch_invariant_logp.py
| def _time_ms(fn, warmup, iters): | ||
| for _ in range(warmup): | ||
| fn() | ||
| torch.cuda.synchronize() | ||
| start = torch.cuda.Event(enable_timing=True) | ||
| end = torch.cuda.Event(enable_timing=True) | ||
| start.record() | ||
| for _ in range(iters): | ||
| fn() | ||
| end.record() | ||
| torch.cuda.synchronize() | ||
| return start.elapsed_time(end) / iters | ||
|
|
||
|
|
||
| def _peak_vram_gb(fn, warmup=3, iters=5): | ||
| for _ in range(warmup): | ||
| fn() | ||
| torch.cuda.synchronize() | ||
| torch.cuda.empty_cache() | ||
| baseline = torch.cuda.memory_allocated() | ||
| torch.cuda.reset_peak_memory_stats() | ||
| for _ in range(iters): | ||
| fn() | ||
| torch.cuda.synchronize() | ||
| return (torch.cuda.max_memory_allocated() - baseline) / (1024**3) | ||
|
|
||
|
|
||
| def run_benchmark(args): | ||
| if device_ctx.device_type not in ["cuda", "xpu", "hip"]: | ||
| raise RuntimeError("batch_invariant_logp benchmark requires a compatible GPU device.") |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
🧩 Analysis chain
🌐 Web query:
Does torch.xpu support Event, synchronize, and memory stats APIs equivalent to torch.cuda for Intel GPUs?
💡 Result:
Yes, torch.xpu provides APIs that are functionally equivalent to the corresponding torch.cuda APIs for Intel GPUs [1][2][3]. Specifically, PyTorch implements the following support: - Synchronization: The torch.xpu.synchronize function is available to wait for all kernels in all streams on an XPU device to complete, mirroring the behavior of torch.cuda.synchronize [4][2][5]. - Events: The torch.xpu.Event class serves as a wrapper around an XPU event, allowing for synchronization markers and monitoring device progress, consistent with torch.cuda.Event [6][2][7]. - Memory Statistics: PyTorch provides a comprehensive set of memory management APIs under the torch.xpu.memory namespace that mirror torch.cuda memory functions [1][2]. These include: - torch.xpu.memory.memory_stats and memory_stats_as_nested_dict for allocator statistics [1][2][8]. - Metrics such as memory_allocated, memory_reserved, max_memory_allocated, and max_memory_reserved [1][2]. - Utilities like memory_snapshot, mem_get_info, and tools to manage memory fractions [1][2]. These XPU-specific APIs are designed to provide a consistent user experience when migrating code from CUDA to Intel GPU backends [9][3].
Citations:
- 1: https://docs.pytorch.org/docs/main/xpu.html
- 2: https://docs.pytorch.org/docs/stable/xpu.md
- 3: [RFC][API-Unstable] Intel GPU distributed Backend integration in
torch-xpu-opsand registeration in PyTorch pytorch/pytorch#141741 - 4: https://docs.pytorch.org/docs/main/generated/torch.xpu.synchronize.html
- 5: https://docs.pytorch.org/docs/stable/generated/torch.xpu.synchronize.md
- 6: https://docs.pytorch.org/docs/stable/generated/torch.xpu.Event.html
- 7: https://docs.pytorch.org/docs/2.12/xpu.aliases.html
- 8: https://docs.pytorch.org/docs/2.11/generated/torch.xpu.memory.memory_stats.html
- 9: https://docs.pytorch.org/docs/stable/notes/get_start_xpu.md
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Inspect the benchmark file and nearby device handling.
git ls-files benchmarks/benchmark_batch_invariant_logp.py
echo "---"
cat -n benchmarks/benchmark_batch_invariant_logp.py | sed -n '1,220p'
echo "--- device_ctx references ---"
rg -n "device_ctx|torch\.cuda|torch\.xpu|torch\.hip|xpu" benchmarks/benchmark_batch_invariant_logp.pyRepository: RL-Align/RL-Kernel
Length of output: 7659
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Check PyTorch backend alias behavior from installed docs/source if available in repo.
python3 - <<'PY'
import importlib.util, sys
mods = ["torch"]
for m in mods:
print(m, importlib.util.find_spec(m) is not None)
PY
# Try to inspect PyTorch API docs via Python introspection if torch is installed.
python3 - <<'PY'
try:
import torch
print("torch version:", torch.__version__)
for name in ["cuda", "xpu"]:
print(name, "present:", hasattr(torch, name))
if hasattr(torch, name):
mod = getattr(torch, name)
for attr in ["synchronize", "Event", "memory_allocated", "empty_cache", "reset_peak_memory_stats", "max_memory_allocated"]:
print(name, attr, hasattr(mod, attr))
except Exception as e:
print("torch import failed:", repr(e))
PYRepository: RL-Align/RL-Kernel
Length of output: 530
🌐 Web query:
PyTorch Intel XPU backend torch.cuda compatibility layer Event synchronize memory_allocated max_memory_allocated on XPU
💡 Result:
In PyTorch, the XPU backend provides a dedicated namespace (torch.xpu) that includes direct equivalents to many torch.cuda functions, rather than relying on a universal "compatibility layer" that automatically redirects calls [1][2]. Developers should use torch.xpu APIs explicitly when targeting Intel GPUs [3]. Key XPU equivalents for the requested CUDA functions are as follows: - Synchronization: Use torch.xpu.synchronize(device=None) to wait for all kernels in all streams on a specific XPU device to complete [4][5][2]. This mirrors the behavior of torch.cuda.synchronize(). - Memory Allocation Tracking: The torch.xpu namespace provides torch.xpu.memory.memory_allocated(device=None) and torch.xpu.memory.max_memory_allocated(device=None) [6][7][8]. These functions return the current and peak memory occupied by tensors (in bytes), respectively [6][8]. While there is no transparent automatic compatibility layer that maps torch.cuda to torch.xpu globally, modern development practices often involve using conditional logic or utility wrappers to check the active device and dispatch the appropriate calls (e.g., torch.cuda.synchronize() vs. torch.xpu.synchronize()) [3]. Some third-party frameworks (such as vLLM) have implemented specific internal monkey-patching or abstraction layers to map torch.cuda APIs to torch.xpu equivalents to support features like graph capture on XPU [9].
Citations:
- 1: https://docs.pytorch.org/docs/main/xpu.html
- 2: https://github.com/pytorch/pytorch/blob/e9ebbd3b/torch/xpu/__init__.py
- 3: https://discuss.pytorch.org/t/timings-for-intel-arc-graphics-xpu-vs-nvidia-rtx-3000-gpu-on-a-laptop/218200
- 4: https://docs.pytorch.org/docs/main/generated/torch.xpu.synchronize.html
- 5: https://docs.pytorch.org/docs/stable/generated/torch.xpu.synchronize.md
- 6: https://docs.pytorch.org/docs/stable/generated/torch.xpu.memory.max_memory_allocated.html
- 7: https://docs.pytorch.org/docs/stable/generated/torch.xpu.memory.memory_allocated.html
- 8: https://docs.pytorch.org/docs/2.11/generated/torch.xpu.memory.memory_allocated.html
- 9: [XPU]Support CUDAGraph on XPU Platform vllm-project/vllm#34482
🌐 Web query:
site:pytorch.org torch.cuda xpu compatibility alias torch.cuda.Event torch.cuda.synchronize xpu
💡 Result:
In PyTorch, the torch.xpu API provides synchronization and event management functionality specifically for Intel GPUs, mirroring the patterns found in torch.cuda [1][2]. For migration purposes, you should replace references to torch.cuda with torch.xpu [1]. The equivalent components are: - Event Management: Use torch.xpu.Event instead of torch.cuda.Event [3][4][5]. Like its CUDA counterpart, torch.xpu.Event is a synchronization marker used to monitor device progress, measure timing (if enabled), and synchronize XPU streams [3][6]. It supports methods such as record, synchronize, wait, and query [6][5]. - Synchronization: Use torch.xpu.synchronize instead of torch.cuda.synchronize [7][8]. This function waits for all kernels in all streams on the specified XPU device to complete [8]. Additionally, individual streams have their own synchronization methods via torch.xpu.Stream, such as stream.synchronize [9]. For further details on migrating workflows, the official PyTorch documentation on XPU semantics is the recommended resource [1][7].
Citations:
- 1: https://docs.pytorch.org/docs/stable/notes/get_start_xpu.md
- 2: https://docs.pytorch.org/docs/stable/xpu.html
- 3: https://docs.pytorch.org/docs/main/generated/torch.xpu.streams.Event.html
- 4: https://docs.pytorch.org/docs/2.11/generated/torch.xpu.Event.html
- 5: https://docs.pytorch.org/docs/stable/generated/torch.xpu.Event.md
- 6: https://docs.pytorch.org/docs/stable/generated/torch.xpu.streams.Event.md
- 7: https://docs.pytorch.org/docs/stable/xpu.md
- 8: https://docs.pytorch.org/docs/stable/generated/torch.xpu.synchronize.md
- 9: https://docs.pytorch.org/docs/stable/generated/torch.xpu.Stream_class.md
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Look for other benchmarks or utilities that handle xpu/hip via torch.cuda or torch.xpu.
rg -n "torch\.xpu|device_type.*xpu|device_type.*hip|torch\.cuda\.(Event|synchronize|memory_allocated|max_memory_allocated|empty_cache|reset_peak_memory_stats)" .Repository: RL-Align/RL-Kernel
Length of output: 7573
Dispatch timing/memory helpers by backend
run_benchmark allows "xpu", but _time_ms and _peak_vram_gb still hardcode torch.cuda. Use torch.xpu for Intel GPUs and keep torch.cuda for CUDA/ROCm so the benchmark works on every allowed device.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@benchmarks/benchmark_batch_invariant_logp.py` around lines 63 - 92, The
timing and memory helpers are hardcoded to torch.cuda even though run_benchmark
supports xpu, so the benchmark will break on Intel GPUs. Update _time_ms and
_peak_vram_gb to dispatch through the active backend in device_ctx.device_type,
using torch.xpu for xpu and torch.cuda for cuda/hip, including synchronize,
Event creation, empty_cache, memory_allocated, reset_peak_memory_stats, and
max_memory_allocated calls.
| def _peak_vram_gb(fn, warmup=3, iters=5): | ||
| for _ in range(warmup): | ||
| fn() | ||
| torch.cuda.synchronize() | ||
| torch.cuda.empty_cache() | ||
| baseline = torch.cuda.memory_allocated() | ||
| torch.cuda.reset_peak_memory_stats() | ||
| for _ in range(iters): | ||
| fn() | ||
| torch.cuda.synchronize() | ||
| return (torch.cuda.max_memory_allocated() - baseline) / (1024**3) | ||
|
|
There was a problem hiding this comment.
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
VRAM measurement ignores CLI --warmup/--iters overrides.
_peak_vram_gb hardcodes warmup=3, iters=5 (Line 77) instead of accepting args.warmup/args.iters like _time_ms does, so --warmup/--iters only affect timing, not the memory measurement, even though it is called right alongside the timing runs (Lines 115-116, 128).
Proposed fix
-def _peak_vram_gb(fn, warmup=3, iters=5):
+def _peak_vram_gb(fn, warmup, iters):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
torch.cuda.empty_cache()
baseline = torch.cuda.memory_allocated()
torch.cuda.reset_peak_memory_stats()
for _ in range(iters):
fn()
torch.cuda.synchronize()
return (torch.cuda.max_memory_allocated() - baseline) / (1024**3)And update call sites to pass args.warmup, args.iters.
Also applies to: 113-116, 127-128
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@benchmarks/benchmark_batch_invariant_logp.py` around lines 77 - 88, The VRAM
helper currently hardcodes its warmup and iteration counts, so CLI overrides are
ignored for memory measurement. Update `_peak_vram_gb` to accept the same
`warmup` and `iters` inputs used by `_time_ms`, and change the call sites in
`benchmark_batch_invariant_logp` to pass `args.warmup` and `args.iters` so
timing and VRAM runs stay consistent.
| std::vector<torch::Tensor> batch_invariant_logp_sm90_forward(torch::Tensor logits, | ||
| torch::Tensor target, | ||
| int64_t ignore_index) { | ||
| TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor"); | ||
| TORCH_CHECK(logits.dim() == 2, "logits must be 2-D [N, V]"); | ||
| TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); | ||
| const int N = logits.size(0); | ||
| const int V = logits.size(1); | ||
| TORCH_CHECK(target.numel() == N, "target must have one id per row: expected ", N, ", got ", | ||
| target.numel()); | ||
|
|
||
| // TMA requires the global row stride (V * elem_size) to be 16-byte aligned. | ||
| const int elem_size = static_cast<int>(logits.element_size()); | ||
| TORCH_CHECK((static_cast<int64_t>(V) * elem_size) % 16 == 0, | ||
| "batch_invariant_logp_sm90 requires the vocab row stride (V * elem_size) to be a " | ||
| "multiple of 16 bytes; got V=", | ||
| V, ", elem_size=", elem_size); | ||
|
|
||
| auto opts_f = logits.options().dtype(torch::kFloat); | ||
| auto logp = torch::empty({N}, opts_f); | ||
| auto lse = torch::empty({N}, opts_f); | ||
| auto target_i = target.to(torch::kInt32).contiguous(); |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify existing extension conventions for CUDA device guards / launch checks.
rg -n -C2 'CUDAGuard|OptionalCUDAGuard|C10_CUDA_KERNEL_LAUNCH_CHECK|CUDA_KERNEL_LAUNCH_CHECK' csrcRepository: RL-Align/RL-Kernel
Length of output: 565
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Inspect the SM90 batch-invariant logp kernel entrypoint and nearby CUDA launch code.
FILE="csrc/cuda/batch_invariant_logp_kernel_sm90.cu"
wc -l "$FILE"
sed -n '1,260p' "$FILE"
# Find the Python binding/callers for this kernel.
rg -n -C3 'batch_invariant_logp_sm90|_C\.batch_invariant_logp_sm90|batch_invariant_logp_sm90_forward' .Repository: RL-Align/RL-Kernel
Length of output: 22607
Guard the SM90 entrypoint with a CUDA device guard and same-device target check.
logits can be on a different CUDA device than the current context, so this launch can hit the wrong GPU. The exported binding should set the device from logits and reject non-CUDA or cross-device target tensors; the Python wrapper already normalizes target, but direct _C.batch_invariant_logp_sm90(...) calls do not.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@csrc/cuda/batch_invariant_logp_kernel_sm90.cu` around lines 179 - 200, The
SM90 entrypoint in batch_invariant_logp_sm90_forward currently assumes the
current CUDA context matches logits, and it also only validates target shape,
not device. Add a CUDA device guard based on logits before any allocation or
launch, and add a TORCH_CHECK that target is on the same CUDA device as logits
(and CUDA at all) alongside the existing logits checks. Keep the validation and
setup near batch_invariant_logp_sm90_forward so direct calls through the
exported binding are safe even when the Python wrapper is bypassed.
| **Forward + backward** | ||
|
|
||
| | shape (N x V) | native ms | triton ms | cuda ms | cuda vs native | cuda vs triton | native MB | triton MB | cuda MB | | ||
| | --- | --- | --- | --- | --- | --- | --- | --- | --- | | ||
| | 4096 x 32768 | 3.549 | 0.463 | 0.407 | 8.7x | 1.14x | 1792 | 512 | 512 | | ||
| | 4096 x 128256 | 13.173 | 1.750 | 1.510 | 8.7x | 1.16x | 7014 | 2004 | 2004 | | ||
| | 4096 x 151936 | 15.632 | 2.072 | 1.788 | 8.7x | 1.16x | 8310 | 2376 | 2376 | | ||
| | 8192 x 128256 | 26.211 | 3.416 | 2.955 | 8.9x | 1.16x | 14028 | 4008 | 4008 | | ||
|
|
||
| - Forward: ~1.7x vs Triton, ~15x vs native, with ~0 extra VRAM — the vocab is | ||
| reduced to per-row scalars, so no `[N, V]` intermediate is materialized. | ||
| - Forward + backward: ~1.15x vs Triton, ~8.8x vs native, with memory equal to | ||
| Triton. The backward's `[N, V]` cost is `grad_logits` itself (one gradient per | ||
| input logit, unavoidable for any backend); the streamed backends avoid native's | ||
| extra `[N, V]` `softmax` / `log_softmax` intermediates by recomputing from the | ||
| saved per-row `lse`. |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
"Forward + backward" table isn't reproducible from the benchmark script.
The doc claims measured "Forward + backward" latency/VRAM numbers (Lines 85-92), but benchmarks/benchmark_batch_invariant_logp.py's run_benchmark only exercises the forward pass under torch.no_grad() — there is no .backward() call or backward timing/VRAM path in the script. Either the script needs a backward benchmark added, or this table should be removed/clarified as sourced from an out-of-tree variant.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@docs/operators/batch-invariant-logp.md` around lines 85 - 100, The “Forward +
backward” measurements in the batch-invariant logp docs are not reproducible
from the current benchmark path. Update the benchmark source used by
run_benchmark so it actually exercises and times a backward pass (including VRAM
tracking), or else remove/qualify this table as coming from a different
benchmark variant. Use the existing benchmark identifiers in
benchmark_batch_invariant_logp.py and its run_benchmark flow to keep the
documentation aligned with the script.
| def _sm90_supported(logits: torch.Tensor) -> bool: | ||
| """Whether the TMA forward can run these logits directly. | ||
| bf16/fp32 only, and the TMA descriptor needs the vocab row stride | ||
| (``V * element_size``) to be a multiple of 16 bytes. | ||
| """ | ||
| if not logits.is_cuda or logits.dtype not in (torch.bfloat16, torch.float32): | ||
| return False | ||
| return (logits.size(-1) * logits.element_size()) % 16 == 0 |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Include the logits device capability in _sm90_supported.
The wrapper can be used directly, and cached registry instances can receive tensors from a different CUDA device than the one probed during registry initialization. Without a per-input cc_major == 9 check here, aligned fp32/bf16 tensors on non-Hopper GPUs can still enter the SM90 TMA kernel instead of falling back.
Suggested fix
def _sm90_supported(logits: torch.Tensor) -> bool:
@@
if not logits.is_cuda or logits.dtype not in (torch.bfloat16, torch.float32):
return False
+ cc_major, _ = torch.cuda.get_device_capability(logits.device)
+ if cc_major != 9:
+ return False
return (logits.size(-1) * logits.element_size()) % 16 == 0📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def _sm90_supported(logits: torch.Tensor) -> bool: | |
| """Whether the TMA forward can run these logits directly. | |
| bf16/fp32 only, and the TMA descriptor needs the vocab row stride | |
| (``V * element_size``) to be a multiple of 16 bytes. | |
| """ | |
| if not logits.is_cuda or logits.dtype not in (torch.bfloat16, torch.float32): | |
| return False | |
| return (logits.size(-1) * logits.element_size()) % 16 == 0 | |
| def _sm90_supported(logits: torch.Tensor) -> bool: | |
| """Whether the TMA forward can run these logits directly. | |
| bf16/fp32 only, and the TMA descriptor needs the vocab row stride | |
| (``V * element_size``) to be a multiple of 16 bytes. | |
| """ | |
| if not logits.is_cuda or logits.dtype not in (torch.bfloat16, torch.float32): | |
| return False | |
| cc_major, _ = torch.cuda.get_device_capability(logits.device) | |
| if cc_major != 9: | |
| return False | |
| return (logits.size(-1) * logits.element_size()) % 16 == 0 |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py` around lines 12 -
19, The `_sm90_supported` guard only checks tensor device type and alignment, so
tensors from non-Hopper CUDA devices can still route into the SM90 TMA path.
Update `_sm90_supported` to also verify the input logits’ CUDA capability
(cc_major == 9) on the tensor’s current device, alongside the existing dtype and
stride-multiple-of-16 checks, so cached registry entries correctly fall back for
non-SM90 GPUs.
| def __call__( | ||
| self, | ||
| logits: torch.Tensor, | ||
| target_ids: torch.Tensor, | ||
| ignore_index: int = -100, | ||
| *, | ||
| validate: bool = False, | ||
| ) -> torch.Tensor: | ||
| return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate) | ||
|
|
||
| def apply( | ||
| self, | ||
| logits: torch.Tensor, | ||
| target_ids: torch.Tensor, | ||
| ignore_index: int = -100, | ||
| *, | ||
| validate: bool = False, |
There was a problem hiding this comment.
🔒 Security & Privacy | 🟠 Major | ⚡ Quick win
Keep target validation on by default for the SM90 wrapper.
The native operator defaults to validate=True, but this wrapper defaults to False before calling an unchecked CUDA gather. A bad non-ignored target can read outside [0, V) by default; make opt-out explicit for callers that have already validated targets.
Suggested fix
def __call__(
@@
- validate: bool = False,
+ validate: bool = True,
@@
def apply(
@@
- validate: bool = False,
+ validate: bool = True,Also applies to: 145-157
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py` around lines 114 -
130, The SM90 wrapper currently disables target validation by default in
BatchInvariantLogp.__call__ and apply, which can allow unsafe CUDA gathers for
bad non-ignored targets. Update the default in these methods so validate matches
the native operator’s behavior (true by default), and keep an explicit opt-out
path for callers that already validated targets. Preserve the existing
ignore_index handling and ensure the wrapper forwards the validate flag
consistently through BatchInvariantLogp.
| def _validate_shapes(logits: torch.Tensor, target_ids: torch.Tensor) -> None: | ||
| if logits.dim() < 2: | ||
| raise ValueError( | ||
| f"logits must be at least 2-D ([*lead, V]), got shape {tuple(logits.shape)}" | ||
| ) | ||
| if logits.shape[:-1] != target_ids.shape: | ||
| raise ValueError( | ||
| f"logits leading shape {tuple(logits.shape[:-1])} must match " | ||
| f"target_ids shape {tuple(target_ids.shape)}" | ||
| ) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Locate the file and inspect the relevant section plus nearby usage.
git ls-files 'rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py'
echo '---'
wc -l rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py
echo '--- lines 1-220 ---'
cat -n rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py | sed -n '1,220p'Repository: RL-Align/RL-Kernel
Length of output: 5615
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Probe PyTorch behavior for the edge cases mentioned in the review comment.
python3 - <<'PY'
import torch
cases = [
("bool target", torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([True, False])),
("float target", torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([0.0, 1.0])),
("int target", torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([0, 1])),
]
for name, logits, target in cases:
print(f"\nCASE: {name}")
try:
gathered = logits.log_softmax(-1).gather(-1, target.long().unsqueeze(-1)).squeeze(-1)
print("gather with target.long():", gathered)
except Exception as e:
print("gather error:", type(e).__name__, e)
# Empty vocab edge
logits = torch.empty(2, 0)
target = torch.empty(2, dtype=torch.long)
print("\nCASE: empty vocab")
try:
x = logits.log_softmax(-1)
print("log_softmax:", x)
except Exception as e:
print("log_softmax error:", type(e).__name__, e)
try:
x = logits.gather(-1, target.unsqueeze(-1))
print("gather:", x)
except Exception as e:
print("gather error:", type(e).__name__, e)
PYRepository: RL-Align/RL-Kernel
Length of output: 519
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Find any direct call sites or existing validation that might already reject these inputs.
rg -n "target_ids.*long|\.long\(\).*target_ids|_validate_shapes\(" rl_engine/kernels/ops/pytorch/loss -SRepository: RL-Align/RL-Kernel
Length of output: 694
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo '== file =='
wc -l rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py
echo '--- relevant section ---'
sed -n '1,220p' rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py | nl -ba | sed -n '1,220p'
echo '== search for related casting/validation =='
rg -n "_validate_shapes|target_ids.*long|\\.long\\(\\)|gather\\(" rl_engine/kernels/ops/pytorch/loss -S || true
echo '== torch edge-case probe =='
python3 - <<'PY'
import torch
def probe(logits, target_ids, label):
print(f"\nCASE: {label}")
print("logits.shape=", tuple(logits.shape), "target_ids.shape=", tuple(target_ids.shape), "target_ids.dtype=", target_ids.dtype)
try:
t = target_ids.long()
print("target_ids.long() dtype=", t.dtype, "values=", t)
except Exception as e:
print("target_ids.long() error:", type(e).__name__, e)
return
try:
ls = logits.log_softmax(-1)
print("log_softmax shape=", tuple(ls.shape))
except Exception as e:
print("log_softmax error:", type(e).__name__, e)
ls = None
if ls is not None:
try:
gathered = ls.gather(-1, t.unsqueeze(-1)).squeeze(-1)
print("gather result shape=", tuple(gathered.shape), "values=", gathered)
except Exception as e:
print("gather error:", type(e).__name__, e)
probe(torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([True, False]), "bool target")
probe(torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([0.0, 1.0]), "float target")
probe(torch.empty(2, 0), torch.empty(2, dtype=torch.long), "empty vocab")
PYRepository: RL-Align/RL-Kernel
Length of output: 295
Reject non-integral targets and empty vocab. target_ids is cast to long in apply(), so float/bool inputs are silently truncated before any range checks; validate the dtype first. Also reject logits.size(-1) == 0 up front, since the later reduction/indexing path fails on empty vocab.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py` around lines 115
- 124, The shape validator in _validate_shapes currently allows non-integral
target_ids and empty vocabularies to slip through. Update _validate_shapes in
batch_invariant_logp.py to first reject target_ids unless they are an
integer/long type before any cast happens in apply(), and add an explicit check
that logits.size(-1) is greater than zero before proceeding. Keep the existing
logits/target_ids shape checks in place after these new validations.
| validate: bool = False, | ||
| ) -> torch.Tensor: | ||
| return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate) | ||
|
|
||
| def apply( | ||
| self, | ||
| logits: torch.Tensor, | ||
| target_ids: torch.Tensor, | ||
| ignore_index: int = -100, | ||
| *, | ||
| validate: bool = False, |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Keep validation enabled by default for backend consistency.
Native defaults to validate=True, but Triton defaults to False; invalid negative/out-of-range targets then produce silent bogus logprobs instead of the documented validation failure.
Proposed default alignment
ignore_index: int = -100,
*,
- validate: bool = False,
+ validate: bool = True,
@@
ignore_index: int = -100,
*,
- validate: bool = False,
+ validate: bool = True,📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| validate: bool = False, | |
| ) -> torch.Tensor: | |
| return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate) | |
| def apply( | |
| self, | |
| logits: torch.Tensor, | |
| target_ids: torch.Tensor, | |
| ignore_index: int = -100, | |
| *, | |
| validate: bool = False, | |
| validate: bool = True, | |
| ) -> torch.Tensor: | |
| return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate) | |
| def apply( | |
| self, | |
| logits: torch.Tensor, | |
| target_ids: torch.Tensor, | |
| ignore_index: int = -100, | |
| *, | |
| validate: bool = True, |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py` around lines 202 -
212, The Triton batch-invariant logp path defaults validation off, which makes
it inconsistent with the native backend and can hide invalid target IDs. Update
the default in the BatchInvariantLogp call/apply flow so validation is enabled
by default, matching the native implementation; adjust the default value on the
validate parameter in the relevant method(s) and ensure BatchInvariantLogp.apply
preserves that default behavior.
| if logits.dim() < 2: | ||
| raise ValueError( | ||
| f"logits must be at least 2-D ([*lead, V]), got shape " | ||
| f"{tuple(logits.shape)}" | ||
| ) | ||
| if logits.shape[:-1] != target_ids.shape: | ||
| raise ValueError( | ||
| f"logits leading shape {tuple(logits.shape[:-1])} must match " | ||
| f"target_ids shape {tuple(target_ids.shape)}" | ||
| ) | ||
|
|
||
| if validate: | ||
| vocab_size = logits.size(-1) | ||
| target_flat = target_ids.reshape(-1) | ||
| valid_targets = target_flat[target_flat != ignore_index] | ||
| if valid_targets.numel() > 0 and ( | ||
| (valid_targets < 0).any() or (valid_targets >= vocab_size).any() | ||
| ): | ||
| bad = valid_targets[ | ||
| (valid_targets < 0) | (valid_targets >= vocab_size) | ||
| ] | ||
| raise ValueError( | ||
| f"target_ids contains values outside [0, {vocab_size}): " | ||
| f"{bad.tolist()}" | ||
| ) | ||
|
|
||
| return _BatchInvariantLogpFunction.apply(logits, target_ids, ignore_index) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Map the file and nearby symbols first
ast-grep outline rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py --view expanded
# Inspect the relevant section around the reported lines and the cast mentioned in the comment
cat -n rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py | sed -n '100,170p'
cat -n rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py | sed -n '180,260p'
# Search for any existing dtype / empty-vocab validation in this module and adjacent loss implementations
rg -n "target_ids.*dtype|bool|is_floating_point|is_complex|vocab|empty|size\\(-1\\)|ignore_index|ValueError\\(" rl_engine/kernels/ops/triton/loss -g '*.py'Repository: RL-Align/RL-Kernel
Length of output: 15041
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Inspect the sibling implementation that already performs target validation.
cat -n rl_engine/kernels/ops/triton/loss/linear_logp.py | sed -n '170,215p'
# Check whether this module already has any call sites that rely on non-integer targets or empty vocab.
rg -n "TritonBatchInvariantLogpOp|batch_invariant_logp|linear_logp|target_ids" rl_engine -g '*.py'Repository: RL-Align/RL-Kernel
Length of output: 15615
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Read the native and CUDA batch-invariant logp implementations around validation.
cat -n rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py | sed -n '1,140p'
echo '---'
cat -n rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py | sed -n '1,190p'
# Probe the Triton wrapper semantics in code only: locate the exact cast and any empty-vocab assumptions.
python3 - <<'PY'
from pathlib import Path
p = Path("rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py")
text = p.read_text()
for needle in [
"target_ids.reshape(-1).to(",
"vocab_size = logits.size(-1)",
"triton.cdiv(vocab_size, _BLOCK_V)",
"if validate:",
]:
print(f"\n=== {needle} ===")
idx = text.find(needle)
print(text[max(0, idx-200):idx+500] if idx != -1 else "not found")
PYRepository: RL-Align/RL-Kernel
Length of output: 15373
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Read the Triton forward kernel and any tests covering batch_invariant_logp edge cases.
cat -n rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py | sed -n '1,120p'
echo '---'
fd -a 'batch_invariant_logp' . | sed -n '1,50p'
echo '---'
rg -n "batch_invariant_logp|TritonBatchInvariantLogpOp|NativeBatchInvariantLogpOp|empty vocab|dtype|float|bool" . -g '*test*.py' -g '*tests*.py'Repository: RL-Align/RL-Kernel
Length of output: 32428
🏁 Script executed:
#!/bin/bash
set -euo pipefail
python3 - <<'PY'
import torch
samples = [
torch.tensor([1.2, -3.7, 4.0], dtype=torch.float32),
torch.tensor([True, False, True], dtype=torch.bool),
torch.tensor([1+2j, 3+0j], dtype=torch.complex64),
]
for t in samples:
print("src", t.dtype, t.tolist())
try:
out = t.to(dtype=torch.int64)
print("dst", out.dtype, out.tolist())
except Exception as e:
print("err", type(e).__name__, str(e))
print("--- empty vocab shape probe ---")
logits = torch.empty(2, 0, dtype=torch.float32)
target = torch.zeros(2, dtype=torch.long)
print("logits shape", tuple(logits.shape), "target shape", tuple(target.shape))
try:
row_max = logits.reshape(-1, logits.size(-1)).float().max(dim=-1).values
print("native row_max", row_max)
except Exception as e:
print("native err", type(e).__name__, str(e))
PY
# Check whether tests already cover empty-vocab or dtype coercion.
rg -n "empty vocab|vocab.*0|dtype.*target|float.*target|bool.*target|complex.*target|target_ids.*float|target_ids.*bool" tests/test_batch_invariant_logp.py rl_engine/kernels/ops -g '*.py'Repository: RL-Align/RL-Kernel
Length of output: 2205
Require an integer dtype check for target_ids.
target_ids.reshape(-1).to(dtype=torch.int64) silently truncates floats and accepts bool/complex inputs, so invalid targets can become wrong indices instead of failing fast.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py` around lines 219 -
245, `_BatchInvariantLogpFunction.apply` currently validates shapes and index
ranges but not the dtype of `target_ids`, so non-integer inputs can be silently
converted later. Add an explicit integer-type check in the validation path of
the batch invariant logp function before `target_ids.reshape(-1)` is used, and
raise a clear `ValueError` for float, bool, complex, or other non-integer
dtypes. Keep the existing shape and range checks in place and ensure the new
check is near the `validate` block in `batch_invariant_logp.py`.
[WS1][kernels] SM90 Batch-Invariant logprob CUDA Kernel
Summary
Adds a Hopper (SM90) TMA CUDA kernel for the batch-invariant selected-token log-probability operator introduced in #199 plus a Native/Triton/CUDA benchmark. The operator computes, from already materialized logits:
The kernel streams the vocab through shared memory with TMA bulk-tensor copies and folds it into a running online log-sum-exp, so it never materializes the
[N, V]softmax and its extra forward memory is ~0. It is registered as thetop-priority CUDA backend on Hopper and falls back cleanly everywhere else.
csrc/cuda/batch_invariant_logp_kernel_sm90.cu)BatchInvariantLogpSM90Op)lse(no[N, V]materialization).batch_invariant_logplist only when the SM90 symbol is compiled oncc_major == 9.benchmarks/benchmark_batch_invariant_logp.pycompares Native/Triton/CUDA. New.This builds on the open branch for #199 (Native + Triton). Reviewed after #199 merges.
Implementation
csrc/cuda/batch_invariant_logp_kernel_sm90.cu.csrc/ops.cppand build source insetup.py'ssm90_srcs; stub inrl_engine/_C.pyi.rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py:_BatchInvariantLogpSM90Functionforward calls the compiled kernel;_batch_invariant_logp_bwd_kernel.rl_engine/kernels/registry.py.docs/operators/batch-invariant-logp.md; added the benchmark script.Validation environment
nvccnvidiachannel)KERNEL_ALIGN_FORCE_SM90=1,_C.batch_invariant_logp_sm90availableTORCH_CUDA_ARCH_LIST=9.0a,KERNEL_ALIGN_DEV_RPATH=1Correctness / Tests
Unit coverage
Result: 67 passed (14 new SM90-kernel tests: correctness fp32/bf16, large and unaligned vocab, single-token, 3-D, bitwise batch-invariance, backward, ignore-index, fp16 fallback; plus the registry-dispatch test now accepting the CUDA op).
Result: 14 passed. Registry selects
BatchInvariantLogpSM90Opon the H200:Vocab sweep (bf16, vs
log_softmax + gather)Correct on both TMA-aligned and unaligned (tail-path) vocab sizes:
Benchmarks
python benchmarks/benchmark_batch_invariant_logp.py— bf16, H200, 20 iters + 5 warmup. Memory is peak extra above baseline (MB).Forward
Forward + backward
Files
csrc/cuda/batch_invariant_logp_kernel_sm90.cu(new)rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py(new)benchmarks/benchmark_batch_invariant_logp.py(new)csrc/ops.cpp,setup.py,rl_engine/_C.pyi,rl_engine/kernels/registry.pytests/test_batch_invariant_logp.py,docs/operators/batch-invariant-logp.mdSummary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests