[WS1][kernels] Batch-invariant logprob (CUDA) #204
base: main
Changes from all commits
d829be2
d748a8f
60697a5
c31cc31
603a7e1
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -206,3 +206,6 @@ cython_debug/ | |
| marimo/_static/ | ||
| marimo/_lsp/ | ||
| __marimo__/ | ||
|
|
||
| # Local dev notes (not for upstream) | ||
| _dev_notes/ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,174 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2026 RL-Kernel Contributors | ||
|
|
||
| """Benchmark batch-invariant logp: Native vs Triton vs CUDA (SM90 TMA). | ||
|
|
||
| All three backends compute ``logits[t, target[t]] - logsumexp(logits[t, :])`` | ||
| from a materialized ``[N, V]`` logits tensor with a locked, per-row reduction | ||
| order (batch-invariant). The comparison here is latency and peak VRAM across a | ||
| vocab sweep: | ||
|
|
||
| - Native materializes ``log_softmax`` over the full ``[N, V]`` tensor. | ||
| - Triton streams the vocab through an online softmax (grid = one program/row). | ||
| - CUDA is the Hopper TMA online-softmax kernel (one CTA/row); only present when | ||
| the extension is built with ``KERNEL_ALIGN_FORCE_SM90=1`` on an SM90 device. | ||
|
|
||
| Usage: | ||
| python benchmarks/benchmark_batch_invariant_logp.py | ||
| python benchmarks/benchmark_batch_invariant_logp.py --configs "4096,128256;8192,151936" | ||
| """ | ||
|
|
||
| import argparse | ||
|
|
||
| import torch | ||
| from tabulate import tabulate | ||
|
|
||
| from rl_engine.kernels.ops.pytorch.loss.batch_invariant_logp import NativeBatchInvariantLogpOp | ||
| from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import TritonBatchInvariantLogpOp | ||
| from rl_engine.platforms.device import device_ctx | ||
| from rl_engine.utils.logger import logger | ||
|
|
||
|
|
||
| def _maybe_sm90_op(): | ||
| """The Hopper TMA op, or None when unavailable (non-Hopper / not built).""" | ||
| from rl_engine.kernels.ops.base import _C, _EXT_AVAILABLE | ||
|
|
||
| if not ( | ||
| torch.cuda.is_available() | ||
| and torch.cuda.get_device_capability()[0] == 9 | ||
| and _EXT_AVAILABLE | ||
| and hasattr(_C, "batch_invariant_logp_sm90") | ||
| ): | ||
| return None | ||
| from rl_engine.kernels.ops.cuda.loss.batch_invariant_logp import BatchInvariantLogpSM90Op | ||
|
|
||
| return BatchInvariantLogpSM90Op() | ||
|
|
||
|
|
||
| # (num_tokens, vocab); vocab kept a multiple of 8 so the bf16 TMA path runs. | ||
| DEFAULT_CONFIGS = [ | ||
| (4096, 32768), | ||
| (4096, 128256), | ||
| (4096, 151936), | ||
| (8192, 128256), | ||
| ] | ||
|
|
||
|
|
||
| def _make_inputs(num_tokens, vocab, device, dtype): | ||
| logits = torch.randn(num_tokens, vocab, device=device, dtype=dtype) | ||
| target = torch.randint(0, vocab, (num_tokens,), device=device) | ||
| return logits, target | ||
|
|
||
|
|
||
| def _time_ms(fn, warmup, iters): | ||
| for _ in range(warmup): | ||
| fn() | ||
| torch.cuda.synchronize() | ||
| start = torch.cuda.Event(enable_timing=True) | ||
| end = torch.cuda.Event(enable_timing=True) | ||
| start.record() | ||
| for _ in range(iters): | ||
| fn() | ||
| end.record() | ||
| torch.cuda.synchronize() | ||
| return start.elapsed_time(end) / iters | ||
|
|
||
|
|
||
| def _peak_vram_gb(fn, warmup=3, iters=5): | ||
| for _ in range(warmup): | ||
| fn() | ||
| torch.cuda.synchronize() | ||
| torch.cuda.empty_cache() | ||
| baseline = torch.cuda.memory_allocated() | ||
| torch.cuda.reset_peak_memory_stats() | ||
| for _ in range(iters): | ||
| fn() | ||
| torch.cuda.synchronize() | ||
| return (torch.cuda.max_memory_allocated() - baseline) / (1024**3) | ||
|
|
||
|
|
||
| def run_benchmark(args): | ||
| if device_ctx.device_type not in ["cuda", "xpu", "hip"]: | ||
| raise RuntimeError("batch_invariant_logp benchmark requires a compatible GPU device.") | ||
|
Comment on lines +63 to +92

🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🌐 Web query:

💡 Result: Yes, torch.xpu provides APIs that are functionally equivalent to the corresponding torch.cuda APIs for Intel GPUs [1][2][3]. Specifically, PyTorch implements the following support:
- Synchronization: The torch.xpu.synchronize function is available to wait for all kernels in all streams on an XPU device to complete, mirroring the behavior of torch.cuda.synchronize [4][2][5].
- Events: The torch.xpu.Event class serves as a wrapper around an XPU event, allowing for synchronization markers and monitoring device progress, consistent with torch.cuda.Event [6][2][7].
- Memory Statistics: PyTorch provides a comprehensive set of memory management APIs under the torch.xpu.memory namespace that mirror torch.cuda memory functions [1][2]. These include:
  - torch.xpu.memory.memory_stats and memory_stats_as_nested_dict for allocator statistics [1][2][8].
  - Metrics such as memory_allocated, memory_reserved, max_memory_allocated, and max_memory_reserved [1][2].
  - Utilities like memory_snapshot, mem_get_info, and tools to manage memory fractions [1][2].

These XPU-specific APIs are designed to provide a consistent user experience when migrating code from CUDA to Intel GPU backends [9][3].

🏁 Script executed:

#!/bin/bash
set -euo pipefail
# Inspect the benchmark file and nearby device handling.
git ls-files benchmarks/benchmark_batch_invariant_logp.py
echo "---"
cat -n benchmarks/benchmark_batch_invariant_logp.py | sed -n '1,220p'
echo "--- device_ctx references ---"
rg -n "device_ctx|torch\.cuda|torch\.xpu|torch\.hip|xpu" benchmarks/benchmark_batch_invariant_logp.py

Repository: RL-Align/RL-Kernel
Length of output: 7659

🏁 Script executed:

#!/bin/bash
set -euo pipefail
# Check PyTorch backend alias behavior from installed docs/source if available in repo.
python3 - <<'PY'
import importlib.util, sys
mods = ["torch"]
for m in mods:
    print(m, importlib.util.find_spec(m) is not None)
PY
# Try to inspect PyTorch API docs via Python introspection if torch is installed.
python3 - <<'PY'
try:
    import torch
    print("torch version:", torch.__version__)
    for name in ["cuda", "xpu"]:
        print(name, "present:", hasattr(torch, name))
        if hasattr(torch, name):
            mod = getattr(torch, name)
            for attr in ["synchronize", "Event", "memory_allocated", "empty_cache", "reset_peak_memory_stats", "max_memory_allocated"]:
                print(name, attr, hasattr(mod, attr))
except Exception as e:
    print("torch import failed:", repr(e))
PY

Repository: RL-Align/RL-Kernel
Length of output: 530

🌐 Web query:

💡 Result: In PyTorch, the XPU backend provides a dedicated namespace (

🌐 Web query:

💡 Result: In PyTorch, the torch.xpu API provides synchronization and event management functionality specifically for Intel GPUs, mirroring the patterns found in torch.cuda [1][2]. For migration purposes, you should replace references to torch.cuda with torch.xpu [1]. The equivalent components are:
- Event Management: Use torch.xpu.Event instead of torch.cuda.Event [3][4][5]. Like its CUDA counterpart, torch.xpu.Event is a synchronization marker used to monitor device progress, measure timing (if enabled), and synchronize XPU streams [3][6]. It supports methods such as record, synchronize, wait, and query [6][5].
- Synchronization: Use torch.xpu.synchronize instead of torch.cuda.synchronize [7][8]. This function waits for all kernels in all streams on the specified XPU device to complete [8]. Additionally, individual streams have their own synchronization methods via torch.xpu.Stream, such as stream.synchronize [9].

For further details on migrating workflows, the official PyTorch documentation on XPU semantics is the recommended resource [1][7].

🏁 Script executed:

#!/bin/bash
set -euo pipefail
# Look for other benchmarks or utilities that handle xpu/hip via torch.cuda or torch.xpu.
rg -n "torch\.xpu|device_type.*xpu|device_type.*hip|torch\.cuda\.(Event|synchronize|memory_allocated|max_memory_allocated|empty_cache|reset_peak_memory_stats)" .

Repository: RL-Align/RL-Kernel
Length of output: 7573

Dispatch timing/memory helpers by backend
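A minimal sketch of what backend dispatch could look like for the timing helper, offered as an illustration and not as the change the PR adopts. The names `resolve_backend` and `time_ms` are hypothetical, and the mapping of `"hip"` onto `torch.cuda` is an assumption based on ROCm builds of PyTorch exposing their devices through that namespace. The torch module is passed in as an argument so the lookup stays easy to test.

```python
def resolve_backend(torch_module, device_type):
    """Pick the torch namespace that owns events and synchronization.

    Assumption: "xpu" devices are served by torch.xpu; "cuda" and "hip"
    devices are both served by torch.cuda.
    """
    name = "xpu" if device_type == "xpu" else "cuda"
    backend = getattr(torch_module, name, None)
    if backend is None:
        raise RuntimeError(f"torch.{name} is not available in this build")
    return backend


def time_ms(fn, backend, warmup, iters):
    """Average elapsed device time per call of fn, in milliseconds."""
    for _ in range(warmup):
        fn()
    backend.synchronize()
    start = backend.Event(enable_timing=True)
    end = backend.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    backend.synchronize()
    return start.elapsed_time(end) / iters
```

In the benchmark this would be called roughly as `backend = resolve_backend(torch, device_ctx.device_type)` followed by `time_ms(fn, backend, args.warmup, args.iters)`.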
||
|
|
||
| device = device_ctx.device | ||
| dtype = torch.bfloat16 | ||
| native = NativeBatchInvariantLogpOp() | ||
| triton_op = TritonBatchInvariantLogpOp() | ||
| sm90_op = _maybe_sm90_op() | ||
|
|
||
| logger.info( | ||
| f"batch_invariant_logp benchmark on {device} (dtype={dtype}); " | ||
| f"SM90 TMA backend {'enabled' if sm90_op is not None else 'unavailable'}" | ||
| ) | ||
|
|
||
| rows = [] | ||
| for num_tokens, vocab in args.configs: | ||
| logits, target = _make_inputs(num_tokens, vocab, device, dtype) | ||
|
|
||
| def fwd(op, x=logits, t=target): | ||
| with torch.no_grad(): | ||
| op(x, t, validate=False) | ||
|
|
||
| n_fwd = _time_ms(lambda: fwd(native), args.warmup, args.iters) | ||
| t_fwd = _time_ms(lambda: fwd(triton_op), args.warmup, args.iters) | ||
| n_vram = _peak_vram_gb(lambda: fwd(native)) | ||
| t_vram = _peak_vram_gb(lambda: fwd(triton_op)) | ||
|
|
||
| row = [ | ||
| f"{num_tokens}x{vocab}", | ||
| f"{n_fwd:.3f}", | ||
| f"{t_fwd:.3f}", | ||
| f"{n_fwd/t_fwd:.2f}x", | ||
| f"{n_vram*1024:.0f}", | ||
| f"{t_vram*1024:.0f}", | ||
| ] | ||
| if sm90_op is not None: | ||
| s_fwd = _time_ms(lambda: fwd(sm90_op), args.warmup, args.iters) | ||
| s_vram = _peak_vram_gb(lambda: fwd(sm90_op)) | ||
| row += [ | ||
| f"{s_fwd:.3f}", | ||
| f"{n_fwd/s_fwd:.2f}x", | ||
| f"{t_fwd/s_fwd:.2f}x", | ||
| f"{s_vram*1024:.0f}", | ||
| ] | ||
| rows.append(row) | ||
|
|
||
| headers = [ | ||
| "shape (N x V)", | ||
| "native fwd ms", | ||
| "triton fwd ms", | ||
| "fwd speedup", | ||
| "native fwd MB", | ||
| "triton fwd MB", | ||
| ] | ||
| if sm90_op is not None: | ||
| headers += [ | ||
| "cuda fwd ms", | ||
| "cuda vs native", | ||
| "cuda vs triton", | ||
| "cuda fwd MB", | ||
| ] | ||
| print(tabulate(rows, headers=headers, tablefmt="github")) | ||
|
|
||
|
|
||
| def parse_args(): | ||
| parser = argparse.ArgumentParser(description=__doc__) | ||
| parser.add_argument("--iters", type=int, default=20) | ||
| parser.add_argument("--warmup", type=int, default=5) | ||
| parser.add_argument( | ||
| "--configs", | ||
| type=str, | ||
| default=None, | ||
| help="Semicolon-separated 'tokens,vocab' tuples, e.g. '4096,128256;8192,151936'.", | ||
| ) | ||
| args = parser.parse_args() | ||
| if args.configs: | ||
| args.configs = [tuple(int(x) for x in tup.split(",")) for tup in args.configs.split(";")] | ||
| else: | ||
| args.configs = DEFAULT_CONFIGS | ||
| return args | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| run_benchmark(parse_args()) | ||
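For readers who want the math behind the benchmark docstring, here is a small pure-Python reference for `logits[t, target[t]] - logsumexp(logits[t, :])` computed with an online (tile-by-tile) merge. It is a sketch for intuition only: it runs in float64 on Python lists, the function names and the `tile` parameter are hypothetical, and it assumes all logits are finite. It is not the Triton or CUDA implementation.

```python
import math


def logsumexp_online(row, tile=4096):
    """Log-sum-exp of one row, merging fixed-size tiles in a fixed order."""
    row_max = -math.inf
    row_sum = 0.0
    for start in range(0, len(row), tile):
        chunk = row[start:start + tile]
        tile_max = max(chunk)
        tile_sum = sum(math.exp(x - tile_max) for x in chunk)
        # Rescale both partial sums to the new running maximum, then add.
        new_max = max(row_max, tile_max)
        row_sum = (row_sum * math.exp(row_max - new_max)
                   + tile_sum * math.exp(tile_max - new_max))
        row_max = new_max
    return row_max + math.log(row_sum)


def selected_logp(row, target, tile=4096):
    """Log-probability of the target id under softmax(row)."""
    return row[target] - logsumexp_online(row, tile)
```

Because the tile boundaries depend only on the vocabulary size and not on how many rows are processed together, each row is reduced in the same order regardless of batch size, which is the property the benchmark calls batch-invariant.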
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,213 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // Copyright (c) 2026 RL-Kernel Contributors | ||
|
|
||
| // Hopper (SM90) batch-invariant selected-token log-prob | ||
| // logp[n] = logits[n, target[n]] - logsumexp(logits[n, :]) | ||
|
|
||
| #include "../utils/tma_utils.cuh" | ||
| #include <cub/cub.cuh> | ||
| #include <cuda_bf16.h> | ||
| #include <math_constants.h> | ||
| #include <torch/extension.h> | ||
|
|
||
| namespace { | ||
|
|
||
| // A single TMA box dimension is capped at 256 elements | ||
| #define SMEM_TILE 4096 | ||
| #define TMA_BOX 256 | ||
| static constexpr int LOADS_PER_TILE = SMEM_TILE / TMA_BOX; | ||
|
|
||
| template <typename T> __device__ __forceinline__ float to_float(T x); | ||
| template <> __device__ __forceinline__ float to_float<nv_bfloat16>(nv_bfloat16 x) { | ||
| return __bfloat162float(x); | ||
| } | ||
| template <> __device__ __forceinline__ float to_float<float>(float x) { return x; } | ||
|
|
||
| template <typename T, int NUM_WARPS> | ||
| __global__ void batch_invariant_logp_sm90_kernel(const __grid_constant__ CUtensorMap logits_tmap, | ||
| const int *__restrict__ target, | ||
| const T *__restrict__ logits_gmem, | ||
| float *__restrict__ output_logp, | ||
| float *__restrict__ output_lse, int num_tokens, | ||
| int vocab_size, int ignore_index) { | ||
| constexpr int NUM_THREADS = NUM_WARPS * 32; | ||
|
|
||
| // one CTA per row; each warp streams a TMA tile of the row into shared memory | ||
| const int tid = threadIdx.x; | ||
| const int row_idx = blockIdx.x; | ||
|
|
||
| extern __shared__ __align__(1024) char smem[]; | ||
| const int smem_addr = static_cast<int>(__cvta_generic_to_shared(smem)); | ||
| T *smem_logits = reinterpret_cast<T *>(smem); | ||
| const int tma_mbar_addr = smem_addr + (SMEM_TILE * sizeof(T)); | ||
|
|
||
| if (tid == 0) { | ||
| mbarrier_init(tma_mbar_addr, 1); | ||
| asm volatile("fence.mbarrier_init.release.cluster;"); | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| const int v_aligned = (vocab_size / TMA_BOX) * TMA_BOX; | ||
| const int num_tiles = (v_aligned + SMEM_TILE - 1) / SMEM_TILE; | ||
| int phase = 0; | ||
|
|
||
| using BlockReduce = cub::BlockReduce<float, NUM_THREADS>; | ||
| __shared__ typename BlockReduce::TempStorage temp_storage; | ||
| __shared__ float s_tile_max; | ||
|
|
||
| float row_max = -CUDART_INF_F; | ||
| float row_sum = 0.0f; | ||
|
|
||
| for (int step = 0; step < num_tiles; ++step) { | ||
| const int col_offset = step * SMEM_TILE; | ||
| const int current_tile_size = min(SMEM_TILE, v_aligned - col_offset); | ||
|
|
||
| if (tid == 0) { | ||
| for (int j = 0; j < LOADS_PER_TILE; ++j) { | ||
| const int col = col_offset + j * TMA_BOX; | ||
| if (col < v_aligned) { | ||
| tma_2d_g2s(smem_addr + j * TMA_BOX * sizeof(T), &logits_tmap, col, row_idx, | ||
| tma_mbar_addr); | ||
| } | ||
| } | ||
| mbarrier_arrive_expect_tx(tma_mbar_addr, current_tile_size * sizeof(T)); | ||
| } | ||
| mbarrier_wait(tma_mbar_addr, phase); | ||
| phase ^= 1; | ||
|
|
||
| // tile max (fixed strided partition -> deterministic) | ||
| float tile_max = -CUDART_INF_F; | ||
| for (int i = tid; i < current_tile_size; i += NUM_THREADS) { | ||
| tile_max = max(tile_max, to_float<T>(smem_logits[i])); | ||
| } | ||
| float block_tile_max = BlockReduce(temp_storage).Reduce(tile_max, cub::Max()); | ||
| if (tid == 0) | ||
| s_tile_max = block_tile_max; | ||
| __syncthreads(); | ||
|
|
||
| // tile sum of exp(x - tile_max). | ||
| float tile_sum = 0.0f; | ||
| for (int i = tid; i < current_tile_size; i += NUM_THREADS) { | ||
| tile_sum += expf(to_float<T>(smem_logits[i]) - s_tile_max); | ||
| } | ||
| float block_tile_sum = BlockReduce(temp_storage).Reduce(tile_sum, cub::Sum()); | ||
|
|
||
| // Online log-sum-exp merge of this tile into the running row state. | ||
| if (tid == 0) { | ||
| const float new_max = max(row_max, s_tile_max); | ||
| row_sum = row_sum * expf(row_max - new_max) + block_tile_sum * expf(s_tile_max - new_max); | ||
| row_max = new_max; | ||
| } | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| // Tail [v_aligned, V): fewer than TMA_BOX elements, read straight from global. | ||
| const int tail = vocab_size - v_aligned; | ||
| if (tail > 0) { | ||
| const int64_t base = (int64_t)row_idx * vocab_size + v_aligned; | ||
| float tail_max = -CUDART_INF_F; | ||
| for (int i = tid; i < tail; i += NUM_THREADS) { | ||
| tail_max = max(tail_max, to_float<T>(logits_gmem[base + i])); | ||
| } | ||
| float block_tail_max = BlockReduce(temp_storage).Reduce(tail_max, cub::Max()); | ||
| if (tid == 0) | ||
| s_tile_max = block_tail_max; | ||
| __syncthreads(); | ||
|
|
||
| float tail_sum = 0.0f; | ||
| for (int i = tid; i < tail; i += NUM_THREADS) { | ||
| tail_sum += expf(to_float<T>(logits_gmem[base + i]) - s_tile_max); | ||
| } | ||
| float block_tail_sum = BlockReduce(temp_storage).Reduce(tail_sum, cub::Sum()); | ||
| if (tid == 0) { | ||
| const float new_max = max(row_max, s_tile_max); | ||
| row_sum = row_sum * expf(row_max - new_max) + block_tail_sum * expf(s_tile_max - new_max); | ||
| row_max = new_max; | ||
| } | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| if (tid == 0) { | ||
| const float lse = row_max + logf(row_sum); | ||
| const int tgt = target[row_idx]; | ||
| if (tgt == ignore_index) { | ||
| output_logp[row_idx] = 0.0f; | ||
| } else { | ||
| const float tgt_logit = to_float<T>(logits_gmem[(int64_t)row_idx * vocab_size + tgt]); | ||
| output_logp[row_idx] = tgt_logit - lse; | ||
| } | ||
| output_lse[row_idx] = lse; | ||
| } | ||
| } | ||
|
|
||
| template <typename T> | ||
| void launch_batch_invariant_logp_sm90(torch::Tensor logits, torch::Tensor target, | ||
| torch::Tensor logp, torch::Tensor lse, int N, int V, | ||
| int ignore_index) { | ||
| constexpr int NUM_WARPS = 4; | ||
| CUtensorMap logits_tmap; | ||
| // Global [N, V]; TMA box = [1 row, TMA_BOX cols], unswizzled row-major. | ||
| init_tensor_map<T>(&logits_tmap, reinterpret_cast<const T *>(logits.data_ptr<T>()), N, V, 1, | ||
| TMA_BOX); | ||
|
|
||
| const int smem_size = (SMEM_TILE * sizeof(T)) + 16; | ||
| batch_invariant_logp_sm90_kernel<T, NUM_WARPS><<<N, NUM_WARPS * 32, smem_size>>>( | ||
| logits_tmap, target.data_ptr<int>(), reinterpret_cast<const T *>(logits.data_ptr<T>()), | ||
| logp.data_ptr<float>(), lse.data_ptr<float>(), N, V, ignore_index); | ||
| } | ||
|
|
||
| // at::BFloat16 and nv_bfloat16 are layout-compatible; data_ptr<T> needs the ATen | ||
| // type, so bf16 gets its own launch that reinterprets the pointer. | ||
| void launch_batch_invariant_logp_sm90_bf16(torch::Tensor logits, torch::Tensor target, | ||
| torch::Tensor logp, torch::Tensor lse, int N, int V, | ||
| int ignore_index) { | ||
| constexpr int NUM_WARPS = 4; | ||
| CUtensorMap logits_tmap; | ||
| init_tensor_map<nv_bfloat16>( | ||
| &logits_tmap, reinterpret_cast<const nv_bfloat16 *>(logits.data_ptr<at::BFloat16>()), N, V, | ||
| 1, TMA_BOX); | ||
|
|
||
| const int smem_size = (SMEM_TILE * sizeof(nv_bfloat16)) + 16; | ||
| batch_invariant_logp_sm90_kernel<nv_bfloat16, NUM_WARPS><<<N, NUM_WARPS * 32, smem_size>>>( | ||
| logits_tmap, target.data_ptr<int>(), | ||
| reinterpret_cast<const nv_bfloat16 *>(logits.data_ptr<at::BFloat16>()), | ||
| logp.data_ptr<float>(), lse.data_ptr<float>(), N, V, ignore_index); | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| std::vector<torch::Tensor> batch_invariant_logp_sm90_forward(torch::Tensor logits, | ||
| torch::Tensor target, | ||
| int64_t ignore_index) { | ||
| TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor"); | ||
| TORCH_CHECK(logits.dim() == 2, "logits must be 2-D [N, V]"); | ||
| TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); | ||
| const int N = logits.size(0); | ||
| const int V = logits.size(1); | ||
| TORCH_CHECK(target.numel() == N, "target must have one id per row: expected ", N, ", got ", | ||
| target.numel()); | ||
|
|
||
| // TMA requires the global row stride (V * elem_size) to be 16-byte aligned. | ||
| const int elem_size = static_cast<int>(logits.element_size()); | ||
| TORCH_CHECK((static_cast<int64_t>(V) * elem_size) % 16 == 0, | ||
| "batch_invariant_logp_sm90 requires the vocab row stride (V * elem_size) to be a " | ||
| "multiple of 16 bytes; got V=", | ||
| V, ", elem_size=", elem_size); | ||
|
|
||
| auto opts_f = logits.options().dtype(torch::kFloat); | ||
| auto logp = torch::empty({N}, opts_f); | ||
| auto lse = torch::empty({N}, opts_f); | ||
| auto target_i = target.to(torch::kInt32).contiguous(); | ||
|
Comment on lines +179 to +200

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify existing extension conventions for CUDA device guards / launch checks.
rg -n -C2 'CUDAGuard|OptionalCUDAGuard|C10_CUDA_KERNEL_LAUNCH_CHECK|CUDA_KERNEL_LAUNCH_CHECK' csrc

Repository: RL-Align/RL-Kernel
Length of output: 565

🏁 Script executed:

#!/bin/bash
set -euo pipefail
# Inspect the SM90 batch-invariant logp kernel entrypoint and nearby CUDA launch code.
FILE="csrc/cuda/batch_invariant_logp_kernel_sm90.cu"
wc -l "$FILE"
sed -n '1,260p' "$FILE"
# Find the Python binding/callers for this kernel.
rg -n -C3 'batch_invariant_logp_sm90|_C\.batch_invariant_logp_sm90|batch_invariant_logp_sm90_forward' .

Repository: RL-Align/RL-Kernel
Length of output: 22607

Guard the SM90 entrypoint with a CUDA device guard and same-device target check.
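The guard this comment asks for belongs in the C++ entrypoint. As a rough illustration of the same two checks, here is a caller-side Python sketch. It is an assumption-laden stand-in, not the proposed fix: `check_same_device` and `run_on_logits_device` are hypothetical helpers, and `op` stands for any callable that takes `(logits, target)`.

```python
def check_same_device(logits, target):
    """Raise if logits is not on CUDA or target lives on another device."""
    if not logits.is_cuda:
        raise ValueError("logits must be a CUDA tensor")
    if target.device != logits.device:
        raise ValueError(
            f"target is on {target.device}, but logits is on {logits.device}"
        )


def run_on_logits_device(op, logits, target):
    """Call op with the device of logits made current for the launch."""
    import torch  # imported lazily so the check above stays torch-free

    check_same_device(logits, target)
    # torch.cuda.device is a context manager that switches the current device.
    with torch.cuda.device(logits.device):
        return op(logits, target)
```

Without a guard of this kind, a kernel launched from a process whose current device differs from the device holding `logits` would run on the wrong GPU.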
||
|
|
||
| if (logits.scalar_type() == at::kBFloat16) { | ||
| launch_batch_invariant_logp_sm90_bf16(logits, target_i, logp, lse, N, V, | ||
| static_cast<int>(ignore_index)); | ||
| } else if (logits.scalar_type() == at::kFloat) { | ||
| launch_batch_invariant_logp_sm90<float>(logits, target_i, logp, lse, N, V, | ||
| static_cast<int>(ignore_index)); | ||
| } else { | ||
| TORCH_CHECK(false, "batch_invariant_logp_sm90 supports only bfloat16 and float32 logits"); | ||
| } | ||
|
|
||
| return {logp, lse}; | ||
| } | ||
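The tiling arithmetic in the kernel can be checked on the host. The sketch below mirrors the `SMEM_TILE`, `TMA_BOX`, `v_aligned`, `num_tiles`, and tail computations from the diff in plain Python; `tile_plan` and `smem_bytes` are hypothetical helper names introduced only for this illustration.

```python
SMEM_TILE = 4096  # elements staged in shared memory per step (from the kernel)
TMA_BOX = 256     # elements per TMA box (from the kernel)


def tile_plan(vocab_size):
    """Split one row into TMA-loaded tiles plus a tail read from global memory."""
    v_aligned = (vocab_size // TMA_BOX) * TMA_BOX
    num_tiles = (v_aligned + SMEM_TILE - 1) // SMEM_TILE
    tile_sizes = [
        min(SMEM_TILE, v_aligned - step * SMEM_TILE) for step in range(num_tiles)
    ]
    tail = vocab_size - v_aligned  # always fewer than TMA_BOX elements
    return {
        "v_aligned": v_aligned,
        "num_tiles": num_tiles,
        "tile_sizes": tile_sizes,
        "tail": tail,
    }


def smem_bytes(elem_size):
    """Dynamic shared memory requested at launch: one tile plus 16 bytes."""
    return SMEM_TILE * elem_size + 16
```

For the default benchmark shapes, a vocabulary of 128256 is an exact multiple of 256 and has no tail, while 151936 leaves a 128-element tail that the kernel reads directly from global memory.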
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

VRAM measurement ignores CLI `--warmup`/`--iters` overrides.

`_peak_vram_gb` hardcodes `warmup=3, iters=5` (Line 77) instead of accepting `args.warmup`/`args.iters` like `_time_ms` does, so `--warmup`/`--iters` only affect timing, not the memory measurement, even though it is called right alongside the timing runs (Lines 115-116, 128).

Proposed fix: have `_peak_vram_gb` take `warmup` and `iters` as arguments, and update call sites to pass `args.warmup, args.iters`.

Also applies to: 113-116, 127-128
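One possible shape for that change, as a sketch under assumptions (the reviewer's actual diff is not shown here). The helper below drops the hardcoded defaults so callers must pass `warmup` and `iters`. The optional `backend` parameter is a hypothetical addition that defaults to `torch.cuda` and exists mainly so the logic can be exercised without a GPU.

```python
def peak_vram_gb(fn, warmup, iters, backend=None):
    """Peak extra device memory used by fn, in GiB, over `iters` calls."""
    if backend is None:
        import torch  # lazy import; assumes a CUDA-capable PyTorch build

        backend = torch.cuda
    for _ in range(warmup):
        fn()
    backend.synchronize()
    backend.empty_cache()
    # Memory already held (inputs and so on) is subtracted from the peak.
    baseline = backend.memory_allocated()
    backend.reset_peak_memory_stats()
    for _ in range(iters):
        fn()
    backend.synchronize()
    return (backend.max_memory_allocated() - baseline) / (1024**3)
```

A call site would then read `peak_vram_gb(lambda: fwd(native), args.warmup, args.iters)`, so the timing and memory runs use the same iteration counts.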