diff --git a/.gitignore b/.gitignore index ae89c0d..edc4b00 100644 --- a/.gitignore +++ b/.gitignore @@ -206,3 +206,6 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +# Local dev notes (not for upstream) +_dev_notes/ diff --git a/benchmarks/benchmark_batch_invariant_logp.py b/benchmarks/benchmark_batch_invariant_logp.py new file mode 100644 index 0000000..0b7c7f6 --- /dev/null +++ b/benchmarks/benchmark_batch_invariant_logp.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Benchmark batch-invariant logp: Native vs Triton vs CUDA (SM90 TMA). + +All three backends compute ``logits[t, target[t]] - logsumexp(logits[t, :])`` +from a materialized ``[N, V]`` logits tensor with a locked, per-row reduction +order (batch-invariant). The comparison here is latency and peak VRAM across a +vocab sweep: + +- Native materializes ``log_softmax`` over the full ``[N, V]`` tensor. +- Triton streams the vocab through an online softmax (grid = one program/row). +- CUDA is the Hopper TMA online-softmax kernel (one CTA/row); only present when + the extension is built with ``KERNEL_ALIGN_FORCE_SM90=1`` on an SM90 device. 
+ +Usage: + python benchmarks/benchmark_batch_invariant_logp.py + python benchmarks/benchmark_batch_invariant_logp.py --configs "4096,128256;8192,151936" +""" + +import argparse + +import torch +from tabulate import tabulate + +from rl_engine.kernels.ops.pytorch.loss.batch_invariant_logp import NativeBatchInvariantLogpOp +from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import TritonBatchInvariantLogpOp +from rl_engine.platforms.device import device_ctx +from rl_engine.utils.logger import logger + + +def _maybe_sm90_op(): + """The Hopper TMA op, or None when unavailable (non-Hopper / not built).""" + from rl_engine.kernels.ops.base import _C, _EXT_AVAILABLE + + if not ( + torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] == 9 + and _EXT_AVAILABLE + and hasattr(_C, "batch_invariant_logp_sm90") + ): + return None + from rl_engine.kernels.ops.cuda.loss.batch_invariant_logp import BatchInvariantLogpSM90Op + + return BatchInvariantLogpSM90Op() + + +# (num_tokens, vocab); vocab kept a multiple of 8 so the bf16 TMA path runs. 
+DEFAULT_CONFIGS = [ + (4096, 32768), + (4096, 128256), + (4096, 151936), + (8192, 128256), +] + + +def _make_inputs(num_tokens, vocab, device, dtype): + logits = torch.randn(num_tokens, vocab, device=device, dtype=dtype) + target = torch.randint(0, vocab, (num_tokens,), device=device) + return logits, target + + +def _time_ms(fn, warmup, iters): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + fn() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) / iters + + +def _peak_vram_gb(fn, warmup=3, iters=5): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + torch.cuda.empty_cache() + baseline = torch.cuda.memory_allocated() + torch.cuda.reset_peak_memory_stats() + for _ in range(iters): + fn() + torch.cuda.synchronize() + return (torch.cuda.max_memory_allocated() - baseline) / (1024**3) + + +def run_benchmark(args): + if device_ctx.device_type not in ["cuda", "xpu", "hip"]: + raise RuntimeError("batch_invariant_logp benchmark requires a compatible GPU device.") + + device = device_ctx.device + dtype = torch.bfloat16 + native = NativeBatchInvariantLogpOp() + triton_op = TritonBatchInvariantLogpOp() + sm90_op = _maybe_sm90_op() + + logger.info( + f"batch_invariant_logp benchmark on {device} (dtype={dtype}); " + f"SM90 TMA backend {'enabled' if sm90_op is not None else 'unavailable'}" + ) + + rows = [] + for num_tokens, vocab in args.configs: + logits, target = _make_inputs(num_tokens, vocab, device, dtype) + + def fwd(op, x=logits, t=target): + with torch.no_grad(): + op(x, t, validate=False) + + n_fwd = _time_ms(lambda: fwd(native), args.warmup, args.iters) + t_fwd = _time_ms(lambda: fwd(triton_op), args.warmup, args.iters) + n_vram = _peak_vram_gb(lambda: fwd(native)) + t_vram = _peak_vram_gb(lambda: fwd(triton_op)) + + row = [ + f"{num_tokens}x{vocab}", + f"{n_fwd:.3f}", + f"{t_fwd:.3f}", 
+ f"{n_fwd/t_fwd:.2f}x", + f"{n_vram*1024:.0f}", + f"{t_vram*1024:.0f}", + ] + if sm90_op is not None: + s_fwd = _time_ms(lambda: fwd(sm90_op), args.warmup, args.iters) + s_vram = _peak_vram_gb(lambda: fwd(sm90_op)) + row += [ + f"{s_fwd:.3f}", + f"{n_fwd/s_fwd:.2f}x", + f"{t_fwd/s_fwd:.2f}x", + f"{s_vram*1024:.0f}", + ] + rows.append(row) + + headers = [ + "shape (N x V)", + "native fwd ms", + "triton fwd ms", + "fwd speedup", + "native fwd MB", + "triton fwd MB", + ] + if sm90_op is not None: + headers += [ + "cuda fwd ms", + "cuda vs native", + "cuda vs triton", + "cuda fwd MB", + ] + print(tabulate(rows, headers=headers, tablefmt="github")) + + +def parse_args(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument( + "--configs", + type=str, + default=None, + help="Semicolon-separated 'tokens,vocab' tuples, e.g. '4096,128256;8192,151936'.", + ) + args = parser.parse_args() + if args.configs: + args.configs = [tuple(int(x) for x in tup.split(",")) for tup in args.configs.split(";")] + else: + args.configs = DEFAULT_CONFIGS + return args + + +if __name__ == "__main__": + run_benchmark(parse_args()) diff --git a/csrc/cuda/batch_invariant_logp_kernel_sm90.cu b/csrc/cuda/batch_invariant_logp_kernel_sm90.cu new file mode 100644 index 0000000..6496dda --- /dev/null +++ b/csrc/cuda/batch_invariant_logp_kernel_sm90.cu @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2026 RL-Kernel Contributors + +// Hopper (SM90) batch-invariant selected-token log-prob +// logp[n] = logits[n, target[n]] - logsumexp(logits[n, :]) + +#include "../utils/tma_utils.cuh" +#include +#include +#include +#include + +namespace { + +// A single TMA box dimension is capped at 256 elements +#define SMEM_TILE 4096 +#define TMA_BOX 256 +static constexpr int LOADS_PER_TILE = SMEM_TILE / TMA_BOX; + +template __device__ __forceinline__ float 
to_float(T x); +template <> __device__ __forceinline__ float to_float(nv_bfloat16 x) { + return __bfloat162float(x); +} +template <> __device__ __forceinline__ float to_float(float x) { return x; } + +template +__global__ void batch_invariant_logp_sm90_kernel(const __grid_constant__ CUtensorMap logits_tmap, + const int *__restrict__ target, + const T *__restrict__ logits_gmem, + float *__restrict__ output_logp, + float *__restrict__ output_lse, int num_tokens, + int vocab_size, int ignore_index) { + constexpr int NUM_THREADS = NUM_WARPS * 32; + + // one CTA per row; thread 0 issues the TMA loads that stream each tile of the row into shared memory + const int tid = threadIdx.x; + const int row_idx = blockIdx.x; + + extern __shared__ __align__(1024) char smem[]; + const int smem_addr = static_cast(__cvta_generic_to_shared(smem)); + T *smem_logits = reinterpret_cast(smem); + const int tma_mbar_addr = smem_addr + (SMEM_TILE * sizeof(T)); + + if (tid == 0) { + mbarrier_init(tma_mbar_addr, 1); + asm volatile("fence.mbarrier_init.release.cluster;"); + } + __syncthreads(); + + const int v_aligned = (vocab_size / TMA_BOX) * TMA_BOX; + const int num_tiles = (v_aligned + SMEM_TILE - 1) / SMEM_TILE; + int phase = 0; + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ float s_tile_max; + + float row_max = -CUDART_INF_F; + float row_sum = 0.0f; + + for (int step = 0; step < num_tiles; ++step) { + const int col_offset = step * SMEM_TILE; + const int current_tile_size = min(SMEM_TILE, v_aligned - col_offset); + + if (tid == 0) { + for (int j = 0; j < LOADS_PER_TILE; ++j) { + const int col = col_offset + j * TMA_BOX; + if (col < v_aligned) { + tma_2d_g2s(smem_addr + j * TMA_BOX * sizeof(T), &logits_tmap, col, row_idx, + tma_mbar_addr); + } + } + mbarrier_arrive_expect_tx(tma_mbar_addr, current_tile_size * sizeof(T)); + } + mbarrier_wait(tma_mbar_addr, phase); + phase ^= 1; + + // tile max (fixed strided partition -> deterministic) + float 
tile_max = -CUDART_INF_F; + for (int i = tid; i < current_tile_size; i += NUM_THREADS) { + tile_max = max(tile_max, to_float(smem_logits[i])); + } + float block_tile_max = BlockReduce(temp_storage).Reduce(tile_max, cub::Max()); + if (tid == 0) + s_tile_max = block_tile_max; + __syncthreads(); + + // tile sum of exp(x - tile_max). + float tile_sum = 0.0f; + for (int i = tid; i < current_tile_size; i += NUM_THREADS) { + tile_sum += expf(to_float(smem_logits[i]) - s_tile_max); + } + float block_tile_sum = BlockReduce(temp_storage).Reduce(tile_sum, cub::Sum()); + + // Online log-sum-exp merge of this tile into the running row state. + if (tid == 0) { + const float new_max = max(row_max, s_tile_max); + row_sum = row_sum * expf(row_max - new_max) + block_tile_sum * expf(s_tile_max - new_max); + row_max = new_max; + } + __syncthreads(); + } + + // Tail [v_aligned, V): fewer than TMA_BOX elements, read straight from global. + const int tail = vocab_size - v_aligned; + if (tail > 0) { + const int64_t base = (int64_t)row_idx * vocab_size + v_aligned; + float tail_max = -CUDART_INF_F; + for (int i = tid; i < tail; i += NUM_THREADS) { + tail_max = max(tail_max, to_float(logits_gmem[base + i])); + } + float block_tail_max = BlockReduce(temp_storage).Reduce(tail_max, cub::Max()); + if (tid == 0) + s_tile_max = block_tail_max; + __syncthreads(); + + float tail_sum = 0.0f; + for (int i = tid; i < tail; i += NUM_THREADS) { + tail_sum += expf(to_float(logits_gmem[base + i]) - s_tile_max); + } + float block_tail_sum = BlockReduce(temp_storage).Reduce(tail_sum, cub::Sum()); + if (tid == 0) { + const float new_max = max(row_max, s_tile_max); + row_sum = row_sum * expf(row_max - new_max) + block_tail_sum * expf(s_tile_max - new_max); + row_max = new_max; + } + __syncthreads(); + } + + if (tid == 0) { + const float lse = row_max + logf(row_sum); + const int tgt = target[row_idx]; + if (tgt == ignore_index) { + output_logp[row_idx] = 0.0f; + } else { + const float tgt_logit = 
to_float(logits_gmem[(int64_t)row_idx * vocab_size + tgt]); + output_logp[row_idx] = tgt_logit - lse; + } + output_lse[row_idx] = lse; + } +} + +template +void launch_batch_invariant_logp_sm90(torch::Tensor logits, torch::Tensor target, + torch::Tensor logp, torch::Tensor lse, int N, int V, + int ignore_index) { + constexpr int NUM_WARPS = 4; + CUtensorMap logits_tmap; + // Global [N, V]; TMA box = [1 row, TMA_BOX cols], unswizzled row-major. + init_tensor_map(&logits_tmap, reinterpret_cast(logits.data_ptr()), N, V, 1, + TMA_BOX); + + const int smem_size = (SMEM_TILE * sizeof(T)) + 16; + batch_invariant_logp_sm90_kernel<<>>( + logits_tmap, target.data_ptr(), reinterpret_cast(logits.data_ptr()), + logp.data_ptr(), lse.data_ptr(), N, V, ignore_index); +} + +// at::BFloat16 and nv_bfloat16 are layout-compatible; data_ptr needs the ATen +// type, so bf16 gets its own launch that reinterprets the pointer. +void launch_batch_invariant_logp_sm90_bf16(torch::Tensor logits, torch::Tensor target, + torch::Tensor logp, torch::Tensor lse, int N, int V, + int ignore_index) { + constexpr int NUM_WARPS = 4; + CUtensorMap logits_tmap; + init_tensor_map( + &logits_tmap, reinterpret_cast(logits.data_ptr()), N, V, + 1, TMA_BOX); + + const int smem_size = (SMEM_TILE * sizeof(nv_bfloat16)) + 16; + batch_invariant_logp_sm90_kernel<<>>( + logits_tmap, target.data_ptr(), + reinterpret_cast(logits.data_ptr()), + logp.data_ptr(), lse.data_ptr(), N, V, ignore_index); +} + +} // namespace + +std::vector batch_invariant_logp_sm90_forward(torch::Tensor logits, + torch::Tensor target, + int64_t ignore_index) { + TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor"); + TORCH_CHECK(logits.dim() == 2, "logits must be 2-D [N, V]"); + TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); + const int N = logits.size(0); + const int V = logits.size(1); + TORCH_CHECK(target.numel() == N, "target must have one id per row: expected ", N, ", got ", + target.numel()); + + // TMA requires 
the global row stride (V * elem_size) to be 16-byte aligned. + const int elem_size = static_cast(logits.element_size()); + TORCH_CHECK((static_cast(V) * elem_size) % 16 == 0, + "batch_invariant_logp_sm90 requires the vocab row stride (V * elem_size) to be a " + "multiple of 16 bytes; got V=", + V, ", elem_size=", elem_size); + + auto opts_f = logits.options().dtype(torch::kFloat); + auto logp = torch::empty({N}, opts_f); + auto lse = torch::empty({N}, opts_f); + auto target_i = target.to(torch::kInt32).contiguous(); + + if (logits.scalar_type() == at::kBFloat16) { + launch_batch_invariant_logp_sm90_bf16(logits, target_i, logp, lse, N, V, + static_cast(ignore_index)); + } else if (logits.scalar_type() == at::kFloat) { + launch_batch_invariant_logp_sm90(logits, target_i, logp, lse, N, V, + static_cast(ignore_index)); + } else { + TORCH_CHECK(false, "batch_invariant_logp_sm90 supports only bfloat16 and float32 logits"); + } + + return {logp, lse}; +} diff --git a/csrc/ops.cpp b/csrc/ops.cpp index f48e4f9..361287d 100644 --- a/csrc/ops.cpp +++ b/csrc/ops.cpp @@ -13,6 +13,9 @@ std::vector fused_linear_logp_sm90_forward(torch::Tensor hidden, torch::Tensor weight, torch::Tensor target, torch::optional bias); +std::vector batch_invariant_logp_sm90_forward(torch::Tensor logits, + torch::Tensor target, + int64_t ignore_index); #endif #if defined(__CUDACC__) || defined(KERNEL_ALIGN_WITH_CUDA) @@ -81,6 +84,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_logp_sm90", &fused_logp_sm90_forward, "TMA-accelerated Online Softmax Fused LogP"); m.def("fused_linear_logp_sm90", &fused_linear_logp_sm90_forward, "TMA+WGMMA fused linear log-prob (hidden @ W^T -> selected-token logp), SM90"); + m.def("batch_invariant_logp_sm90", &batch_invariant_logp_sm90_forward, + "TMA online-softmax batch-invariant selected-token log-prob from logits, SM90"); #endif #if defined(__CUDACC__) || defined(KERNEL_ALIGN_WITH_CUDA) diff --git a/docs/.nav.yml b/docs/.nav.yml index 60525c2..3a8ddc0 
100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -13,6 +13,7 @@ nav: - operators/README.md - operators/fused-logp.md - operators/linear-logp.md + - operators/batch-invariant-logp.md - operators/grpo-loss.md - operators/ratio-kl.md - operators/sampling.md diff --git a/docs/operators/README.md b/docs/operators/README.md index c4eae60..d0174b4 100644 --- a/docs/operators/README.md +++ b/docs/operators/README.md @@ -20,6 +20,7 @@ Every operator page should include: - [Fused LogP](fused-logp.md) - [Fused Linear LogP](linear-logp.md) +- [Batch-Invariant LogP](batch-invariant-logp.md) - [GRPO Loss](grpo-loss.md) - [Policy Ratio + KL Penalty](ratio-kl.md) - [Sampling](sampling.md) diff --git a/docs/operators/batch-invariant-logp.md b/docs/operators/batch-invariant-logp.md new file mode 100644 index 0000000..6523aee --- /dev/null +++ b/docs/operators/batch-invariant-logp.md @@ -0,0 +1,225 @@ +# Batch-Invariant LogP + +Batch-Invariant LogP computes selected token log-probabilities from already +materialized logits: + +```text +out[row] = logits[row, target_ids[row]] - logsumexp(logits[row, :]) +``` + +It targets RL post-training paths where policy log-probs are compared across +different packing, padding, and batch layouts. The key contract is +batch-invariance: for a fixed row of logits and target id, the result must not +change when that row is evaluated alone, at a different batch position, or with +different neighboring rows. + +Unlike `linear_logp`, this operator does not fuse the LM-head projection. It +takes `[*, V]` logits as input and returns one selected log-probability per row. 
+ +## Entry Point + +```python +from rl_engine.kernels.registry import kernel_registry + +batch_invariant_logp = kernel_registry.get_op("batch_invariant_logp") + +logp = batch_invariant_logp( + logits, # [B, T, V] or [N, V], differentiable + target_ids, # [B, T] or [N], int + ignore_index=-100, + validate=False, # Triton fast path; use True to debug-check target range +) # -> [B, T] or [N], float32 + +logp.sum().backward() # gradients flow into logits only +``` + +## Backends + +| Backend | Wrapper | Status | +| --- | --- | --- | +| CUDA (SM90 TMA) | `BatchInvariantLogpSM90Op` | Hopper TMA online-softmax forward; backward reuses the Triton tile-wise kernel. Inputs other than bf16/fp32 CUDA logits with a 16-byte-aligned vocab row (`V * element_size`) fall back to Triton, then PyTorch. | +| CUDA / ROCm (Triton) | `TritonBatchInvariantLogpOp` | Triton online-softmax forward and tile-wise backward. Requires a GPU tensor. | +| PyTorch native | `NativeBatchInvariantLogpOp` | FP32 reference path; CPU fallback and Triton-less fallback. | + +Current dispatch: + +```text +CUDA (Hopper, SM90 kernel compiled): CUDA (SM90 TMA) -> Triton -> PyTorch +CUDA / ROCm (otherwise): Triton -> PyTorch +CPU: PyTorch +``` + +The SM90 backend is hardware-gated: it is only inserted at the front of the +CUDA priority list when the extension exposes `_C.batch_invariant_logp_sm90` +(built with `KERNEL_ALIGN_FORCE_SM90=1`) on an SM90 device. On any other build +or device, dispatch is unchanged (Triton -> PyTorch). + +## Benchmarks + +`benchmarks/benchmark_batch_invariant_logp.py` compares Native, Triton, and the +CUDA SM90 backend (forward latency and peak VRAM across a vocab sweep, bf16): + +```bash +python benchmarks/benchmark_batch_invariant_logp.py +python benchmarks/benchmark_batch_invariant_logp.py --configs "4096,128256;8192,151936" +``` + +The CUDA column is only shown when the SM90 kernel is compiled in; otherwise the +benchmark reports Native vs Triton only. + +### Measured results + +Environment: NVIDIA H200 (Hopper, SM90, cc 9.0), CUDA 12.8 / `nvcc` 12.8.93, +PyTorch 2.11.0+cu128, `KERNEL_ALIGN_FORCE_SM90=1`. dtype bf16, 20 iters + 5 +warmup. 
"MB" is peak extra device memory above baseline. + +**Forward** + +| shape (N x V) | native ms | triton ms | cuda ms | cuda vs native | cuda vs triton | native MB | triton MB | cuda MB | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | +| 4096 x 32768 | 1.341 | 0.148 | 0.090 | 14.8x | 1.64x | 1536 | 0 | 0 | +| 4096 x 128256 | 4.964 | 0.566 | 0.323 | 15.4x | 1.75x | 6012 | 0 | 0 | +| 4096 x 151936 | 5.908 | 0.669 | 0.384 | 15.4x | 1.74x | 7122 | 0 | 0 | +| 8192 x 128256 | 9.904 | 1.056 | 0.597 | 16.6x | 1.77x | 12024 | 0 | 0 | + +**Forward + backward** + +| shape (N x V) | native ms | triton ms | cuda ms | cuda vs native | cuda vs triton | native MB | triton MB | cuda MB | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | +| 4096 x 32768 | 3.549 | 0.463 | 0.407 | 8.7x | 1.14x | 1792 | 512 | 512 | +| 4096 x 128256 | 13.173 | 1.750 | 1.510 | 8.7x | 1.16x | 7014 | 2004 | 2004 | +| 4096 x 151936 | 15.632 | 2.072 | 1.788 | 8.7x | 1.16x | 8310 | 2376 | 2376 | +| 8192 x 128256 | 26.211 | 3.416 | 2.955 | 8.9x | 1.16x | 14028 | 4008 | 4008 | + +- Forward: ~1.7x vs Triton, ~15x vs native, with ~0 extra VRAM — the vocab is + reduced to per-row scalars, so no `[N, V]` intermediate is materialized. +- Forward + backward: ~1.15x vs Triton, ~8.8x vs native, with memory equal to + Triton. The backward's `[N, V]` cost is `grad_logits` itself (one gradient per + input logit, unavoidable for any backend); the streamed backends avoid native's + extra `[N, V]` `softmax` / `log_softmax` intermediates by recomputing from the + saved per-row `lse`. + +## Tensor Contract + +| Argument | Shape | Dtype | Requirements | +| --- | --- | --- | --- | +| `logits` | `[N, V]` / `[B, T, V]` / `[*lead, V]` | fp32 / fp16 / bf16 | Differentiable input; last dimension is vocab. | +| `target_ids` | `[N]` / `[B, T]` / `[*lead]` | int | Same leading shape as `logits`; non-ignored values in `[0, V)`. | +| `ignore_index` | scalar int | Python int | Default `-100`. 
Ignored rows output zero and receive zero gradient. | +| Output | `[N]` / `[B, T]` / `[*lead]` | float32 | Selected log-probability per row. | + +`target_ids` is integer and non-differentiable. Gradients flow only into +`logits`. + +## Reference Semantics + +For non-ignored rows: + +```python +logits_2d = logits.reshape(-1, logits.size(-1)).float() +target_1d = target_ids.reshape(-1).long() + +log_probs = torch.log_softmax(logits_2d, dim=-1) +selected = torch.gather( + log_probs, + dim=-1, + index=target_1d.unsqueeze(-1), +).squeeze(-1) + +out = selected.reshape(target_ids.shape) +``` + +For ignored rows: + +```text +target_ids[row] == ignore_index +out[row] = 0.0 +grad_logits[row, :] = 0.0 +``` + +Non-ignored target ids outside `[0, V)` are invalid. In particular, +`target=-1` is invalid unless `ignore_index=-1`. + +The PyTorch native backend validates target ranges by default. The Triton and +CUDA SM90 backends default to `validate=False` to avoid CUDA stream +synchronization in training hot paths. Use `validate=True` during debugging or +in tests when calling the Triton or CUDA SM90 backend with untrusted targets. + +## Batch-Invariance + +The operator is designed so each row is computed independently: + +- The PyTorch path reshapes to `[N, V]` and applies row-wise reductions. +- The Triton forward uses `grid=(num_tokens,)`, so one program owns exactly one + row. +- Triton vocab traversal uses a fixed `_BLOCK_V=1024` and does not autotune by + batch size. +- Triton forward scans vocab tiles left-to-right using online logsumexp. +- Triton backward uses `grid=(num_tokens, vocab_tiles)` and writes one row tile + per program. It reuses the forward-saved per-row `lse`, so no backward + reduction crosses row boundaries. +- No atomic writes are used. + +These constraints ensure the result for a row depends only on that row's logits +and target id, not on batch size, row position, or neighboring rows. + +## Accuracy + +All backends accumulate reductions in float32 and return float32 outputs. 
Tests +compare against `torch.log_softmax(...).gather(...)` with dtype-appropriate +tolerances: + +```text +fp32 forward: atol around 1e-5 +fp16/bf16 forward: atol around 1e-4 +fp16/bf16 backward: checked against fp32 reference with relaxed tolerance +``` + +CPU-vs-CUDA comparisons use tolerance-based checks; batch-invariance checks +within the same backend use exact equality where appropriate. + +## Minimal Example + +```python +import torch + +from rl_engine.kernels.registry import kernel_registry + +op = kernel_registry.get_op("batch_invariant_logp") + +logits = torch.randn(2, 4, 300, device="cuda", dtype=torch.bfloat16) +target_ids = torch.randint(0, 300, (2, 4), device="cuda") +target_ids[0, 0] = -100 + +out = op(logits, target_ids, ignore_index=-100) +assert out.shape == target_ids.shape +assert out.dtype == torch.float32 +assert out[0, 0].item() == 0.0 + +out.sum().backward() +``` + +## Tests + +```bash +python -m pytest tests/test_batch_invariant_logp.py -q -rs +``` + +All backends (Native, Triton) are tested in a single file. Coverage includes: +correctness, leading-shape preservation, batch-invariance (bitwise), validation, +ignore-index behavior, backward correctness, CUDA smoke cases, registry +dispatch, and Triton-specific fp32/fp16/bf16 correctness, large vocab, backward +gradient batch-invariance, and ignored-row zero gradients. + +Triton tests skip when Triton or CUDA is unavailable. On Windows, run via +WSL/Linux with CUDA. 
+ +## Implementation Files + +- `rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py` +- `rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py` +- `rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py` +- `csrc/cuda/batch_invariant_logp_kernel_sm90.cu` +- `rl_engine/kernels/registry.py` +- `tests/test_batch_invariant_logp.py` +- `benchmarks/benchmark_batch_invariant_logp.py` diff --git a/rl_engine/_C.pyi b/rl_engine/_C.pyi index fdd6982..6dc0c92 100644 --- a/rl_engine/_C.pyi +++ b/rl_engine/_C.pyi @@ -4,6 +4,11 @@ import torch def fused_logp(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: ... def fused_logp_sm90(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: ... +def batch_invariant_logp_sm90( + logits: torch.Tensor, + target: torch.Tensor, + ignore_index: int, +) -> list[torch.Tensor]: ... def fused_logp_forward_out( logits: torch.Tensor, token_ids: torch.Tensor, diff --git a/rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py b/rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py new file mode 100644 index 0000000..f78ad93 --- /dev/null +++ b/rl_engine/kernels/ops/cuda/loss/batch_invariant_logp.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import torch + +from rl_engine.kernels.ops.base import _C, _EXT_AVAILABLE +from rl_engine.utils.logger import logger + + +def _sm90_supported(logits: torch.Tensor) -> bool: + """Whether the TMA forward can run these logits directly. + bf16/fp32 only, and the TMA descriptor needs the vocab row stride + (``V * element_size``) to be a multiple of 16 bytes. + """ + if not logits.is_cuda or logits.dtype not in (torch.bfloat16, torch.float32): + return False + return (logits.size(-1) * logits.element_size()) % 16 == 0 + + +def _fallback_op(): + """Portable op for inputs the SM90 forward cannot take. 
Triton, else native.""" + try: + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + + return TritonBatchInvariantLogpOp() + except Exception: # pragma: no cover - Triton missing + from rl_engine.kernels.ops.pytorch.loss.batch_invariant_logp import ( + NativeBatchInvariantLogpOp, + ) + + return NativeBatchInvariantLogpOp() + + +class _BatchInvariantLogpSM90Function(torch.autograd.Function): + # Autograd wrapper: SM90 TMA forward + tile-wise softmax backward. + + @staticmethod + def forward(ctx, logits, target_ids, ignore_index): + lead_shape = logits.shape[:-1] + vocab_size = logits.size(-1) + + logits_2d = logits.reshape(-1, vocab_size).contiguous() + target_1d = target_ids.reshape(-1).to(device=logits.device, dtype=torch.int64).contiguous() + + logp, lse = _C.batch_invariant_logp_sm90(logits_2d, target_1d, int(ignore_index)) + + ctx.save_for_backward(logits_2d, target_1d, lse) + ctx.ignore_index = ignore_index + ctx.lead_shape = lead_shape + ctx.vocab_size = vocab_size + return logp.reshape(lead_shape) + + @staticmethod + def backward(ctx, grad_output): + logits_2d, target_1d, lse = ctx.saved_tensors + ignore_index = ctx.ignore_index + vocab_size = ctx.vocab_size + num_tokens = logits_2d.shape[0] + + grad_flat = grad_output.reshape(-1).contiguous().to(torch.float32) + grad_logits = torch.empty_like(logits_2d) + + # Reuse Triton's tile-wise backward + try: + import triton + + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + _BLOCK_V, + _batch_invariant_logp_bwd_kernel, + ) + + grid = (num_tokens, triton.cdiv(vocab_size, _BLOCK_V)) + _batch_invariant_logp_bwd_kernel[grid]( + logits_2d, + target_1d, + lse, + grad_flat, + grad_logits, + num_tokens, + vocab_size, + logits_2d.stride(0), + ignore_index=ignore_index, + BLOCK_V=_BLOCK_V, + ) + except Exception: # pragma: no cover - Triton missing + valid = target_1d != ignore_index + safe_target = torch.where(valid, target_1d, 
torch.zeros_like(target_1d)) + probs = torch.exp(logits_2d.float() - lse.unsqueeze(1)) + onehot = torch.zeros_like(probs) + onehot.scatter_(1, safe_target.unsqueeze(1), 1.0) + grad = grad_flat.unsqueeze(1) * (onehot - probs) + grad = torch.where(valid.unsqueeze(1), grad, torch.zeros_like(grad)) + grad_logits = grad.to(logits_2d.dtype) + + grad_logits = grad_logits.reshape(ctx.lead_shape + (vocab_size,)) + return grad_logits, None, None + + +class BatchInvariantLogpSM90Op: + # SM90 (Hopper) TMA batch-invariant selected-token log-probability. + + def __init__(self) -> None: + if not _EXT_AVAILABLE or not hasattr(_C, "batch_invariant_logp_sm90"): + raise RuntimeError( + "batch_invariant_logp_sm90 is not compiled into the extension. Rebuild with " + "KERNEL_ALIGN_FORCE_SM90=1 on an SM90 (Hopper) device: 'pip install -e .'" + ) + logger.info("Successfully linked to precompiled _C.batch_invariant_logp_sm90 kernel.") + + def __call__( + self, + logits: torch.Tensor, + target_ids: torch.Tensor, + ignore_index: int = -100, + *, + validate: bool = False, + ) -> torch.Tensor: + return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate) + + def apply( + self, + logits: torch.Tensor, + target_ids: torch.Tensor, + ignore_index: int = -100, + *, + validate: bool = False, + ) -> torch.Tensor: + if logits.dim() < 2: + raise ValueError( + f"logits must be at least 2-D ([*lead, V]), got shape {tuple(logits.shape)}" + ) + if logits.shape[:-1] != target_ids.shape: + raise ValueError( + f"logits leading shape {tuple(logits.shape[:-1])} must match " + f"target_ids shape {tuple(target_ids.shape)}" + ) + + if not _sm90_supported(logits): + return _fallback_op()(logits, target_ids, ignore_index=ignore_index, validate=validate) + + if validate: + vocab_size = logits.size(-1) + target_flat = target_ids.reshape(-1) + valid_targets = target_flat[target_flat != ignore_index] + if valid_targets.numel() > 0 and ( + (valid_targets < 0).any() or (valid_targets >= 
vocab_size).any() + ): + bad = valid_targets[(valid_targets < 0) | (valid_targets >= vocab_size)] + raise ValueError( + f"target_ids contains values outside [0, {vocab_size}): {bad.tolist()}" + ) + + return _BatchInvariantLogpSM90Function.apply(logits, target_ids, ignore_index) diff --git a/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py b/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py new file mode 100644 index 0000000..043aa71 --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import torch + + +class NativeBatchInvariantLogpOp: + """Batch-invariant selected-token log-probability. + + ``selected_logprob[t] = logits[t, target_ids[t]] - logsumexp(logits[t, :])`` + + All reductions run in FP32. The row-wise max -> subtract -> exp -> sum -> log + pipeline is fully independent per row, so the result for any row depends + only on that row's logits and target - never on batch size or layout. 
+ """ + + def __init__(self) -> None: + pass + + def __call__( + self, + logits: torch.Tensor, + target_ids: torch.Tensor, + ignore_index: int = -100, + *, + validate: bool = True, + ) -> torch.Tensor: + return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate) + + def apply( + self, + logits: torch.Tensor, + target_ids: torch.Tensor, + ignore_index: int = -100, + *, + validate: bool = True, + ) -> torch.Tensor: + self._validate_shapes(logits, target_ids) + + lead_shape = logits.shape[:-1] + vocab_size = logits.size(-1) + + logits_2d = logits.reshape(-1, vocab_size).float() + target_1d = target_ids.reshape(-1).to(logits.device, dtype=torch.long) + + selected_logp = self._row_wise_selected_logprob( + logits_2d, target_1d, ignore_index=ignore_index, validate=validate + ) + + return selected_logp.reshape(lead_shape) + + # ---------------------------------------------------------------------- # + # Core Computation + # ---------------------------------------------------------------------- # + @staticmethod + def _row_wise_selected_logprob( + logits_2d: torch.Tensor, + target_1d: torch.Tensor, + *, + ignore_index: int, + validate: bool = True, + ) -> torch.Tensor: + """Per-row selected logprob with locked reduction order. + + The three reduction steps (max, sum-exp, gather) operate on each row + independently. PyTorch's ``max(dim=-1)`` and ``sum(dim=-1)`` iterate + the vocab dimension in a fixed, deterministic order for a given row + length, and that order does **not** change when more rows are added + to the batch. This is the property that makes the op batch-invariant. + + Accumulation is done entirely in FP32 to avoid half-precision + catastrophic cancellation during the ``exp(logit - max)`` step. 
+ """ + vocab_size = logits_2d.size(1) + + valid_mask = target_1d != ignore_index + + if validate: + valid_targets = target_1d[valid_mask] + if valid_targets.numel() > 0 and ( + (valid_targets < 0).any() or (valid_targets >= vocab_size).any() + ): + bad = valid_targets[(valid_targets < 0) | (valid_targets >= vocab_size)] + raise ValueError( + f"target_ids contains values outside [0, {vocab_size}): {bad.tolist()}" + ) + + safe_target = target_1d.clone() + safe_target[~valid_mask] = 0 + + # logsumexp(z) = log(sum(exp(z - max(z)))) + max(z) + row_max = logits_2d.max(dim=-1).values + shifted = logits_2d - row_max.unsqueeze(-1) + exp_shifted = shifted.exp() + sum_exp = exp_shifted.sum(dim=-1) + log_sum_exp = sum_exp.log() + row_max + + row_indices = torch.arange(logits_2d.size(0), device=logits_2d.device) + selected_logit = logits_2d[row_indices, safe_target] + + selected_logp = selected_logit - log_sum_exp + + selected_logp = selected_logp.where( + valid_mask, torch.zeros_like(selected_logp) + ) + + return selected_logp + + # ---------------------------------------------------------------------- # + # Helper + # ---------------------------------------------------------------------- # + @staticmethod + def _validate_shapes(logits: torch.Tensor, target_ids: torch.Tensor) -> None: + if logits.dim() < 2: + raise ValueError( + f"logits must be at least 2-D ([*lead, V]), got shape {tuple(logits.shape)}" + ) + if logits.shape[:-1] != target_ids.shape: + raise ValueError( + f"logits leading shape {tuple(logits.shape[:-1])} must match " + f"target_ids shape {tuple(target_ids.shape)}" + ) diff --git a/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py b/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py new file mode 100644 index 0000000..12e3472 --- /dev/null +++ b/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + 
@triton.jit
def _batch_invariant_logp_kernel(
    logits_ptr,  # base pointer of the logits, addressed as [N, V]
    target_ptr,  # base pointer of the target ids, one per row [N]
    output_ptr,  # destination for the selected log-probability [N]
    lse_ptr,  # destination for the per-row log-sum-exp (reused by backward) [N]
    num_tokens,  # N: number of rows; programs with row_id >= N exit early
    vocab_size,  # V: number of columns per row
    stride_row,  # element stride between consecutive rows of logits
    ignore_index: tl.constexpr,  # target value marking a row as ignored
    BLOCK_V: tl.constexpr,  # number of vocab columns processed per loop step
):
    """Compute one row's selected log-probability with a one-pass online softmax.

    Each program handles exactly one row (``tl.program_id(0)``) and walks the
    vocabulary in tiles of ``BLOCK_V`` columns, from column 0 upwards:

        m = -inf, s = 0, z_target = 0
        for each vocab tile [v0, v0 + BLOCK_V):
            load the tile and cast it to fp32
            add the target logit to z_target if the target falls in the tile
            rescale the running sum s to the new running max m
        lse    = m + log(s)
        output = z_target - lse        (0.0 when the row is ignored)

    Both ``output`` and ``lse`` are written; ``lse`` is stored even for
    ignored rows. Columns within a row are addressed with unit stride
    (``row_base + cols``).

    NOTE(review): if every logit seen up to and including some tile is
    ``-inf``, ``m`` and ``new_m`` are both ``-inf`` and ``m - new_m`` is NaN,
    which then propagates into ``s``. Confirm whether callers can pass rows
    with a fully masked (``-inf``) leading tile.
    """
    row_id = tl.program_id(0)
    if row_id >= num_tokens:
        return

    target_id = tl.load(target_ptr + row_id)
    is_ignored = target_id == ignore_index
    # Ignored rows are redirected to column 0 so the lookup below stays
    # well-defined; their result is replaced by 0.0 before the store.
    safe_target = tl.where(is_ignored, 0, target_id)

    # Running max, running sum of exp(x - m), and the target logit.
    m = tl.full((), float("-inf"), dtype=tl.float32)
    s = tl.zeros((), dtype=tl.float32)
    z_target = tl.zeros((), dtype=tl.float32)

    # 64-bit row offset so row_id * stride_row cannot overflow 32 bits.
    row_base = row_id.to(tl.int64) * stride_row

    for v0 in range(0, vocab_size, BLOCK_V):
        cols = v0 + tl.arange(0, BLOCK_V)
        mask = cols < vocab_size

        # Out-of-range lanes read -inf: they cannot win the max and they
        # contribute exp(-inf) = 0 to the sum.
        tile = tl.load(
            logits_ptr + row_base + cols,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)

        # At most one lane over the whole loop matches, so the accumulated
        # sum equals the target logit.
        is_target = (cols == safe_target) & mask
        z_target += tl.sum(tl.where(is_target, tile, 0.0))

        # Online softmax update: rescale the old sum to the new max, then
        # add this tile's contribution.
        tile_max = tl.max(tile)
        new_m = tl.maximum(m, tile_max)
        s = s * tl.exp(m - new_m) + tl.sum(tl.exp(tile - new_m))
        m = new_m

    lse = m + tl.log(s)
    result = z_target - lse
    result = tl.where(is_ignored, 0.0, result)

    tl.store(output_ptr + row_id, result)
    tl.store(lse_ptr + row_id, lse)
class _BatchInvariantLogpFunction(torch.autograd.Function):
    """Autograd wrapper for the Triton batch-invariant logp kernels.

    ``forward`` flattens the inputs to ``[N, V]`` / ``[N]``, launches one
    program per row, and saves the flattened logits, targets and per-row
    log-sum-exp for ``backward``. ``backward`` launches one program per
    ``(row, vocab tile)`` and returns a gradient for ``logits`` only.

    Both launches run inside ``torch.cuda.device_of(logits_2d)``: Triton
    launches on the *current* device, so without the guard a tensor living on
    a non-current GPU would be dereferenced on the wrong device. The guard is
    a no-op for tensors that are not CUDA tensors.
    """

    @staticmethod
    def forward(ctx, logits, target_ids, ignore_index):
        """Return the selected log-probabilities with shape ``logits.shape[:-1]``.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            logits: Tensor whose last dimension is the vocabulary.
            target_ids: Integer ids with shape ``logits.shape[:-1]``; moved to
                the logits' device and cast to int64 here.
            ignore_index: Target value marking rows whose output is 0.

        Returns:
            An FP32 tensor of shape ``logits.shape[:-1]``.
        """
        lead_shape = logits.shape[:-1]
        vocab_size = logits.size(-1)

        logits_2d = logits.reshape(-1, vocab_size).contiguous()
        target_1d = target_ids.reshape(-1).to(
            device=logits.device, dtype=torch.int64
        ).contiguous()

        num_tokens = logits_2d.shape[0]
        output = torch.empty(num_tokens, device=logits.device, dtype=torch.float32)
        lse = torch.empty(num_tokens, device=logits.device, dtype=torch.float32)

        # An empty batch has nothing to compute; skip the zero-sized launch.
        if num_tokens > 0:
            # Launch on the device that owns the data, not the current one.
            with torch.cuda.device_of(logits_2d):
                _batch_invariant_logp_kernel[(num_tokens,)](
                    logits_2d,
                    target_1d,
                    output,
                    lse,
                    num_tokens,
                    vocab_size,
                    logits_2d.stride(0),
                    ignore_index=ignore_index,
                    BLOCK_V=_BLOCK_V,
                )

        ctx.save_for_backward(logits_2d, target_1d, lse)
        ctx.ignore_index = ignore_index
        ctx.lead_shape = lead_shape
        ctx.vocab_size = vocab_size

        return output.reshape(lead_shape)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_logits, None, None)`` for the three forward inputs.

        The gradient has the flattened logits' dtype and is reshaped back to
        the original ``[*lead, V]`` shape. ``target_ids`` and ``ignore_index``
        receive no gradient.
        """
        logits_2d, target_1d, lse = ctx.saved_tensors
        ignore_index = ctx.ignore_index
        vocab_size = ctx.vocab_size
        num_tokens = logits_2d.shape[0]

        grad_flat = grad_output.reshape(-1).contiguous().to(torch.float32)
        grad_logits = torch.empty_like(logits_2d)

        if num_tokens > 0:
            grid = (num_tokens, triton.cdiv(vocab_size, _BLOCK_V))
            # Same device guard as in forward.
            with torch.cuda.device_of(logits_2d):
                _batch_invariant_logp_bwd_kernel[grid](
                    logits_2d,
                    target_1d,
                    lse,
                    grad_flat,
                    grad_logits,
                    num_tokens,
                    vocab_size,
                    logits_2d.stride(0),
                    ignore_index=ignore_index,
                    BLOCK_V=_BLOCK_V,
                )

        grad_logits = grad_logits.reshape(ctx.lead_shape + (vocab_size,))

        return grad_logits, None, None
+ ) + if logits.dim() < 2: + raise ValueError( + f"logits must be at least 2-D ([*lead, V]), got shape " + f"{tuple(logits.shape)}" + ) + if logits.shape[:-1] != target_ids.shape: + raise ValueError( + f"logits leading shape {tuple(logits.shape[:-1])} must match " + f"target_ids shape {tuple(target_ids.shape)}" + ) + + if validate: + vocab_size = logits.size(-1) + target_flat = target_ids.reshape(-1) + valid_targets = target_flat[target_flat != ignore_index] + if valid_targets.numel() > 0 and ( + (valid_targets < 0).any() or (valid_targets >= vocab_size).any() + ): + bad = valid_targets[ + (valid_targets < 0) | (valid_targets >= vocab_size) + ] + raise ValueError( + f"target_ids contains values outside [0, {vocab_size}): " + f"{bad.tolist()}" + ) + + return _BatchInvariantLogpFunction.apply(logits, target_ids, ignore_index) diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 6780157..e4d4b2a 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -49,6 +49,17 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): TRITON_RATIO_KL = "rl_engine.kernels.ops.triton.loss.ratio_kl.TritonRatioKLOp" PYTORCH_RATIO_KL = "rl_engine.kernels.ops.pytorch.loss.ratio_kl.NativeRatioKLOp" + # Batch-invariant selected-logprob (WS1 #148: locked reduction order) + TRITON_BATCH_INVARIANT_LOGP = ( + "rl_engine.kernels.ops.triton.loss.batch_invariant_logp.TritonBatchInvariantLogpOp" + ) + PYTORCH_BATCH_INVARIANT_LOGP = ( + "rl_engine.kernels.ops.pytorch.loss.batch_invariant_logp.NativeBatchInvariantLogpOp" + ) + CUDA_BATCH_INVARIANT_LOGP_SM90 = ( + "rl_engine.kernels.ops.cuda.loss.batch_invariant_logp.BatchInvariantLogpSM90Op" + ) + # Generic fallback TRITON_GENERIC = "rl_engine.kernels.ops.triton.generic.TritonOp" PYTORCH_ATTN = "rl_engine.kernels.ops.pytorch.attention.NativeAttentionOp" @@ -89,6 +100,10 @@ def __init__(self): "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": 
[OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "batch_invariant_logp": [ + OpBackend.TRITON_BATCH_INVARIANT_LOGP, + OpBackend.PYTORCH_BATCH_INVARIANT_LOGP, + ], # Default dispatch logic for new operators }, "rocm": { @@ -101,6 +116,10 @@ def __init__(self): "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "batch_invariant_logp": [ + OpBackend.TRITON_BATCH_INVARIANT_LOGP, + OpBackend.PYTORCH_BATCH_INVARIANT_LOGP, + ], }, "cpu": { "logp": [OpBackend.PYTORCH_NATIVE], @@ -108,6 +127,7 @@ def __init__(self): "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.PYTORCH_RATIO_KL], + "batch_invariant_logp": [OpBackend.PYTORCH_BATCH_INVARIANT_LOGP], }, } logger.info(f"KernelRegistry initialized for {device_ctx.device_type}") @@ -164,6 +184,13 @@ def _adjust_priority_for_hardware(self): ll_list = self._priority_map["cuda"]["linear_logp"] if OpBackend.CUDA_FUSED_LINEAR_LOGP_SM90 not in ll_list: ll_list.insert(0, OpBackend.CUDA_FUSED_LINEAR_LOGP_SM90) + + # Batch-invariant logp SM90 kernel: same sm_90a TMA gating (Hopper only). 
+ batch_inv_compiled = _EXT_AVAILABLE and hasattr(_C, "batch_invariant_logp_sm90") + if batch_inv_compiled and cc_major == 9: + bi_list = self._priority_map["cuda"]["batch_invariant_logp"] + if OpBackend.CUDA_BATCH_INVARIANT_LOGP_SM90 not in bi_list: + bi_list.insert(0, OpBackend.CUDA_BATCH_INVARIANT_LOGP_SM90) elif cc >= 90: logger.debug( f"SM{cc}: fused TMA LogP kernel not compiled into _C; " diff --git a/setup.py b/setup.py index d5ddb89..8ba45e9 100644 --- a/setup.py +++ b/setup.py @@ -113,6 +113,7 @@ def get_extensions(): sm90_srcs = [ "csrc/cuda/fused_logp_sm90.cu", "csrc/cuda/fused_linear_logp_sm90.cu", # TMA + WGMMA fused linear log-prob + "csrc/cuda/batch_invariant_logp_kernel_sm90.cu", # TMA batch-invariant logp ] enable_sm90 = os.environ.get("KERNEL_ALIGN_FORCE_SM90") == "1" present_sm90 = [s for s in sm90_srcs if os.path.exists(s)] diff --git a/tests/test_batch_invariant_logp.py b/tests/test_batch_invariant_logp.py new file mode 100644 index 0000000..e718872 --- /dev/null +++ b/tests/test_batch_invariant_logp.py @@ -0,0 +1,1040 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Tests for batch-invariant selected-logprob (issue #148). + +The test suite validates two orthogonal properties: +1. **Correctness** - output matches ``log_softmax + gather`` reference. +2. **Batch-invariance** - the result for a given row is bitwise identical + regardless of batch size, batch position, padding, or mixed-batch layout. 
def _sm90_kernel_available() -> bool:
    """Report whether the compiled SM90 batch-invariant logp kernel is usable.

    The kernel counts as usable only when CUDA is available, the extension
    module imports, it exposes ``batch_invariant_logp_sm90``, and the current
    device's major compute capability is 9. Any failure while importing the
    extension is treated as "not available".
    """
    if torch.cuda.is_available():
        try:
            from rl_engine.kernels.ops.base import _C, _EXT_AVAILABLE
        except Exception:
            return False
        kernel_built = _EXT_AVAILABLE and hasattr(_C, "batch_invariant_logp_sm90")
        return kernel_built and torch.cuda.get_device_capability()[0] == 9
    return False
def _reference_logp(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    """Reference implementation: FP32 ``log_softmax`` followed by a gather.

    The logits are flattened to two dimensions and cast to float32, the
    targets are flattened and cast to int64, and the selected
    log-probabilities are returned in the shape of ``target_ids``.
    """
    vocab = logits.size(-1)
    flat_logits = logits.reshape(-1, vocab).float()
    flat_ids = target_ids.reshape(-1).long()
    picked = (
        torch.log_softmax(flat_logits, dim=-1)
        .gather(dim=-1, index=flat_ids.unsqueeze(1))
        .squeeze(1)
    )
    return picked.reshape(target_ids.shape)
torch.randn(8, _V, dtype=torch.bfloat16) + target = torch.randint(0, _V, (8,)) + out = op(logits, target) + assert out.dtype == torch.float32 + ref = _reference_logp(logits.float(), target) + assert torch.allclose(out, ref, atol=1e-5) + + def test_fp16_input_fp32_output(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(8, _V, dtype=torch.float16) + target = torch.randint(0, _V, (8,)) + out = op(logits, target) + assert out.dtype == torch.float32 + ref = _reference_logp(logits.float(), target) + assert torch.allclose(out, ref, atol=1e-5) + + def test_single_token(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(1, _V) + target = torch.randint(0, _V, (1,)) + out = op(logits, target) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-6) + + def test_vocab_size_1(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(4, 1) + target = torch.zeros(4, dtype=torch.long) + out = op(logits, target) + assert torch.allclose(out, torch.zeros(4), atol=1e-6) + + def test_large_vocab(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(4, 128256) + target = torch.randint(0, 128256, (4,)) + out = op(logits, target) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-5) + + +# --------------------------------------------------------------------------- +# 2. 
class TestBatchInvariance:
    """Same row must produce bitwise-identical output regardless of batch context.

    Every comparison uses exact float equality (``==`` on ``.item()`` values
    or ``torch.equal``), never a tolerance.
    """

    def _get_row_result_in_batch(self, row_logits, row_target, batch_size, position):
        """Embed ``row_logits`` at *position* in a random batch of *batch_size*
        and return the selected logprob for that row.

        The remaining rows and targets are freshly drawn from the global RNG.
        """
        op = NativeBatchInvariantLogpOp()
        vocab = row_logits.size(-1)
        batch_logits = torch.randn(batch_size, vocab)
        batch_target = torch.randint(0, vocab, (batch_size,))
        batch_logits[position] = row_logits.squeeze(0)
        batch_target[position] = row_target.squeeze(0)
        out = op(batch_logits, batch_target)
        return out[position]

    def test_batch_size_1_vs_n(self):
        """Same row in batch=1 vs batch=N must be bitwise equal."""
        op = NativeBatchInvariantLogpOp()
        row = _make_row(42)
        target = torch.tensor([7])

        result_alone = op(row, target).item()

        for batch_size in [2, 4, 8, 16, 32, 64, 128]:
            result_in_batch = self._get_row_result_in_batch(
                row, target, batch_size, position=0
            ).item()
            assert result_alone == result_in_batch, (
                f"Drift at batch_size={batch_size}: "
                f"alone={result_alone}, in_batch={result_in_batch}"
            )

    def test_different_positions_in_batch(self):
        """Same row at different positions in the same batch must be bitwise equal."""
        # The op is built inside the helper, so no local instance is needed.
        row = _make_row(99)
        target = torch.tensor([13])

        batch_size = 16
        results = [
            self._get_row_result_in_batch(row, target, batch_size, pos).item()
            for pos in range(batch_size)
        ]

        assert all(r == results[0] for r in results), (
            f"Position-dependent drift detected: unique values = {set(results)}"
        )

    def test_mixed_batch_content(self):
        """Changing *other* rows in the batch must not affect our row's result."""
        op = NativeBatchInvariantLogpOp()
        row = _make_row(77)
        target = torch.tensor([25])

        batch_size = 8
        results = []
        for trial_seed in range(20):
            # A different seed per trial gives different neighbouring rows.
            torch.manual_seed(trial_seed * 1000)
            batch_logits = torch.randn(batch_size, _V)
            batch_target = torch.randint(0, _V, (batch_size,))
            batch_logits[3] = row.squeeze(0)
            batch_target[3] = target.squeeze(0)
            out = op(batch_logits, batch_target)
            results.append(out[3].item())

        assert all(r == results[0] for r in results), (
            f"Mixed-batch drift: unique values = {set(results)}"
        )

    def test_padding_layout_invariance(self):
        """Left-padding vs right-padding must not affect real rows."""
        op = NativeBatchInvariantLogpOp()
        row = _make_row(55)
        target = torch.tensor([42])

        pad_logits = torch.zeros(1, _V)
        pad_target = torch.tensor([0])

        batch_left = torch.cat([pad_logits, pad_logits, row], dim=0)
        target_left = torch.cat([pad_target, pad_target, target], dim=0)

        batch_right = torch.cat([row, pad_logits, pad_logits], dim=0)
        target_right = torch.cat([target, pad_target, pad_target], dim=0)

        out_left = op(batch_left, target_left)
        out_right = op(batch_right, target_right)

        assert out_left[2].item() == out_right[0].item(), (
            "Padding layout changed the result"
        )

    def test_repeated_runs_deterministic(self):
        """Same input repeated N times must produce bitwise-identical output."""
        op = NativeBatchInvariantLogpOp()
        logits = torch.randn(16, _V)
        target = torch.randint(0, _V, (16,))

        results = [op(logits, target) for _ in range(50)]
        for i, r in enumerate(results[1:], 1):
            assert torch.equal(r, results[0]), f"Run {i} differs from run 0"

    def test_batch_invariance_with_ignore_index(self):
        """Ignored positions must not affect other rows and must output 0.0."""
        op = NativeBatchInvariantLogpOp()
        row = _make_row(33)
        target_val = 10

        batch_a = torch.cat([row, torch.randn(3, _V)], dim=0)
        target_a = torch.tensor([target_val, 5, 8, 2])
        out_a = op(batch_a, target_a)

        # Same logits, but every row except row 0 is now ignored.
        target_b = torch.tensor([target_val, -100, -100, -100])
        out_b = op(batch_a, target_b)

        assert out_a[0].item() == out_b[0].item(), (
            "ignore_index on other rows changed row 0"
        )
        assert out_b[1].item() == 0.0
        assert out_b[2].item() == 0.0
        assert out_b[3].item() == 0.0
class TestBackward:
    """Gradients of the native op, checked against the reference and across batch sizes."""

    def test_backward_matches_reference(self):
        """Gradient of the summed output agrees with the reference within 1e-6."""
        logp_op = NativeBatchInvariantLogpOp()
        logits = torch.randn(4, _V, requires_grad=True)
        target = torch.randint(0, _V, (4,))

        logp_op(logits, target).sum().backward()
        op_grad = logits.grad.detach().clone()

        # Independent leaf holding the same values for the reference path.
        ref_logits = logits.detach().clone().requires_grad_(True)
        _reference_logp(ref_logits, target).sum().backward()

        assert torch.allclose(op_grad, ref_logits.grad, atol=1e-6)

    def test_gradient_batch_invariance(self):
        """Same row's gradient must be bitwise equal in batch=1 vs batch=N."""
        logp_op = NativeBatchInvariantLogpOp()
        probe_row = _make_row(42)
        probe_target = torch.tensor([7])

        solo_logits = probe_row.clone().requires_grad_(True)
        logp_op(solo_logits, probe_target).sum().backward()
        solo_grad = solo_logits.grad.detach().clone()

        for batch_size in (4, 16, 64):
            batch_logits = torch.randn(batch_size, _V)
            batch_logits[0] = probe_row.squeeze(0)
            batch_logits.requires_grad_(True)
            batch_target = torch.randint(0, _V, (batch_size,))
            batch_target[0] = probe_target.squeeze(0)

            logp_op(batch_logits, batch_target).sum().backward()
            row_grad = batch_logits.grad[0:1].detach().clone()

            assert torch.equal(solo_grad, row_grad), (
                f"Gradient drift at batch_size={batch_size}"
            )
class TestIgnoreEdgeCases:
    """Ignore handling: a fully ignored batch and a non-default ``ignore_index``."""

    def test_all_ignore_index_outputs_zero(self):
        """When every target is ``-100`` the output is exactly all zeros."""
        logp_op = NativeBatchInvariantLogpOp()
        result = logp_op(torch.randn(4, _V), torch.full((4,), -100))
        assert torch.equal(result, torch.zeros_like(result))

    def test_custom_ignore_index(self):
        """With ``ignore_index=-1``, a ``-1`` target is ignored, not rejected."""
        logp_op = NativeBatchInvariantLogpOp()
        logits = torch.randn(4, _V)
        target = torch.tensor([0, -1, 2, 3])

        result = logp_op(logits, target, ignore_index=-1)
        assert result[1].item() == 0.0

        # The remaining rows must still match the reference.
        kept = [0, 2, 3]
        expected = _reference_logp(logits[kept], target[kept])
        assert torch.allclose(result[kept], expected, atol=1e-6)
@requires_cuda
class TestCUDABatchInvariance:
    """Batch-invariance of the native op when its inputs live on a CUDA device."""

    def test_batch_size_1_vs_n_cuda(self):
        """A row evaluated alone equals the same row at index 0 of larger batches."""
        logp_op = NativeBatchInvariantLogpOp()
        probe_row = _make_row(42, device="cuda")
        probe_target = torch.tensor([7], device="cuda")
        result_alone = logp_op(probe_row, probe_target).item()

        for batch_size in (2, 4, 8, 16, 32, 64, 128):
            filler_logits = torch.randn(batch_size, _V, device="cuda")
            filler_target = torch.randint(0, _V, (batch_size,), device="cuda")
            filler_logits[0] = probe_row.squeeze(0)
            filler_target[0] = probe_target.squeeze(0)
            result_in_batch = logp_op(filler_logits, filler_target)[0].item()
            assert result_alone == result_in_batch, (
                f"CUDA drift at batch_size={batch_size}: "
                f"alone={result_alone}, in_batch={result_in_batch}"
            )

    def test_different_positions_cuda(self):
        """The row's result is the same at every index of a 16-row batch."""
        logp_op = NativeBatchInvariantLogpOp()
        probe_row = _make_row(99, device="cuda")
        probe_target = torch.tensor([13], device="cuda")
        batch_size = 16

        results = []
        for slot in range(batch_size):
            filler_logits = torch.randn(batch_size, _V, device="cuda")
            filler_target = torch.randint(0, _V, (batch_size,), device="cuda")
            filler_logits[slot] = probe_row.squeeze(0)
            filler_target[slot] = probe_target.squeeze(0)
            results.append(logp_op(filler_logits, filler_target)[slot].item())

        first = results[0]
        assert all(value == first for value in results), (
            f"CUDA position drift: unique = {set(results)}"
        )

    def test_repeated_runs_cuda(self):
        """Fifty evaluations of one input are bitwise identical."""
        logp_op = NativeBatchInvariantLogpOp()
        logits = torch.randn(16, _V, device="cuda")
        target = torch.randint(0, _V, (16,), device="cuda")

        baseline = logp_op(logits, target)
        for i in range(1, 50):
            rerun = logp_op(logits, target)
            assert torch.equal(rerun, baseline), f"CUDA run {i} differs from run 0"

    def test_cpu_gpu_cross_check(self):
        """Same input on CPU vs CUDA should match within tolerance."""
        logp_op = NativeBatchInvariantLogpOp()
        cpu_logits = torch.randn(8, _V)
        cpu_target = torch.randint(0, _V, (8,))

        cpu_result = logp_op(cpu_logits, cpu_target)
        gpu_result = logp_op(cpu_logits.cuda(), cpu_target.cuda())

        assert torch.allclose(cpu_result, gpu_result.cpu(), atol=1e-6, rtol=1e-6), (
            "CPU vs CUDA result mismatch"
        )
@requires_triton_cuda
class TestTritonBatchInvariance:
    """The Triton op's result for a row must not depend on the rest of the batch."""

    def _get_op(self):
        """Import lazily so collection works without Triton, and build the op."""
        from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import (
            TritonBatchInvariantLogpOp,
        )
        return TritonBatchInvariantLogpOp()

    def test_batch_size_1_vs_n(self):
        """A row evaluated alone equals the same row at index 0 of larger batches."""
        logp_op = self._get_op()
        probe_row = _make_row(42, device="cuda")
        probe_target = torch.tensor([7], device="cuda")
        result_alone = logp_op(probe_row, probe_target).item()

        for batch_size in (2, 4, 8, 16, 32, 64, 128):
            filler_logits = torch.randn(batch_size, _V, device="cuda")
            filler_target = torch.randint(0, _V, (batch_size,), device="cuda")
            filler_logits[0] = probe_row.squeeze(0)
            filler_target[0] = probe_target.squeeze(0)
            result_in_batch = logp_op(filler_logits, filler_target)[0].item()
            assert result_alone == result_in_batch, (
                f"Triton drift at batch_size={batch_size}: "
                f"alone={result_alone}, in_batch={result_in_batch}"
            )

    def test_different_positions(self):
        """The row's result is the same at every index of a 16-row batch."""
        logp_op = self._get_op()
        probe_row = _make_row(99, device="cuda")
        probe_target = torch.tensor([13], device="cuda")
        batch_size = 16

        results = []
        for slot in range(batch_size):
            filler_logits = torch.randn(batch_size, _V, device="cuda")
            filler_target = torch.randint(0, _V, (batch_size,), device="cuda")
            filler_logits[slot] = probe_row.squeeze(0)
            filler_target[slot] = probe_target.squeeze(0)
            results.append(logp_op(filler_logits, filler_target)[slot].item())

        first = results[0]
        assert all(value == first for value in results), (
            f"Triton position drift: unique = {set(results)}"
        )

    def test_repeated_runs(self):
        """Fifty evaluations of one input are bitwise identical."""
        logp_op = self._get_op()
        logits = torch.randn(16, _V, device="cuda")
        target = torch.randint(0, _V, (16,), device="cuda")

        baseline = logp_op(logits, target)
        for i in range(1, 50):
            rerun = logp_op(logits, target)
            assert torch.equal(rerun, baseline), f"Triton run {i} differs from run 0"

    def test_mixed_batch_content(self):
        """Replacing the neighbouring rows leaves the probe row's result unchanged."""
        logp_op = self._get_op()
        probe_row = _make_row(77, device="cuda")
        probe_target = torch.tensor([25], device="cuda")
        batch_size = 8

        results = []
        for trial_seed in range(20):
            # Reseeding changes the seven filler rows from trial to trial.
            torch.manual_seed(trial_seed * 1000)
            filler_logits = torch.randn(batch_size, _V, device="cuda")
            filler_target = torch.randint(0, _V, (batch_size,), device="cuda")
            filler_logits[3] = probe_row.squeeze(0)
            filler_target[3] = probe_target.squeeze(0)
            results.append(logp_op(filler_logits, filler_target)[3].item())

        first = results[0]
        assert all(value == first for value in results), (
            f"Triton mixed-batch drift: unique = {set(results)}"
        )
--------------------------------------------------------------------------- + + +@requires_triton_cuda +class TestTritonBackward: + """Gradient through the Triton op must match reference.""" + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_backward_matches_reference(self): + op = self._get_op() + logits = torch.randn(4, _V, device="cuda", requires_grad=True) + target = torch.randint(0, _V, (4,), device="cuda") + + out = op(logits, target).sum() + out.backward() + grad = logits.grad.detach().clone() + + ref_logits = logits.detach().clone().requires_grad_(True) + ref = _reference_logp(ref_logits, target).sum() + ref.backward() + + assert torch.allclose(grad, ref_logits.grad, atol=1e-5) + + def test_gradient_batch_invariance(self): + op = self._get_op() + row = _make_row(42, device="cuda") + target = torch.tensor([7], device="cuda") + + logits_alone = row.clone().requires_grad_(True) + op(logits_alone, target).sum().backward() + grad_alone = logits_alone.grad.detach().clone() + + for batch_size in [4, 16, 64]: + batch_logits = torch.randn(batch_size, _V, device="cuda") + batch_logits[0] = row.squeeze(0) + batch_logits.requires_grad_(True) + batch_target = torch.randint(0, _V, (batch_size,), device="cuda") + batch_target[0] = target.squeeze(0) + op(batch_logits, batch_target).sum().backward() + grad_in_batch = batch_logits.grad[0:1].detach().clone() + assert torch.allclose(grad_alone, grad_in_batch, atol=1e-5), ( + f"Triton gradient drift at batch_size={batch_size}" + ) + + def test_ignored_row_grad_is_zero(self): + """Ignored rows must have zero gradient across the entire vocab.""" + op = self._get_op() + logits = torch.randn(4, _V, device="cuda", requires_grad=True) + target = torch.tensor([0, -100, 2, -100], device="cuda") + op(logits, target).sum().backward() + assert torch.equal(logits.grad[1], torch.zeros(_V, device="cuda")) + assert 
torch.equal(logits.grad[3], torch.zeros(_V, device="cuda")) + + def test_backward_bf16_input(self): + """Backward with bf16 logits should match fp32 reference within tolerance.""" + op = self._get_op() + logits = torch.randn(8, _V, device="cuda", dtype=torch.bfloat16, requires_grad=True) + target = torch.randint(0, _V, (8,), device="cuda") + + op(logits, target).sum().backward() + grad = logits.grad.detach().clone() + + ref_logits = logits.detach().float().requires_grad_(True) + _reference_logp(ref_logits, target).sum().backward() + + assert torch.allclose(grad.float(), ref_logits.grad, atol=1e-2) + + def test_backward_fp16_input(self): + """Backward with fp16 logits should match fp32 reference within tolerance.""" + op = self._get_op() + logits = torch.randn(8, _V, device="cuda", dtype=torch.float16, requires_grad=True) + target = torch.randint(0, _V, (8,), device="cuda") + + op(logits, target).sum().backward() + grad = logits.grad.detach().clone() + + ref_logits = logits.detach().float().requires_grad_(True) + _reference_logp(ref_logits, target).sum().backward() + + assert torch.allclose(grad.float(), ref_logits.grad, atol=1e-2) + + +# --------------------------------------------------------------------------- +# ignore_index handling +# --------------------------------------------------------------------------- + + +@requires_triton_cuda +class TestTritonIgnoreIndex: + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_ignore_outputs_zero(self): + op = self._get_op() + logits = torch.randn(4, _V, device="cuda") + target = torch.tensor([0, -100, 2, -100], device="cuda") + out = op(logits, target) + assert out[1].item() == 0.0 + assert out[3].item() == 0.0 + ref = _reference_logp(logits[[0, 2]], target[[0, 2]]) + assert torch.allclose(out[[0, 2]], ref, atol=1e-5) + + def test_all_ignore(self): + op = self._get_op() + logits = torch.randn(4, 
_V, device="cuda") + target = torch.full((4,), -100, device="cuda") + out = op(logits, target) + assert torch.equal(out, torch.zeros(4, device="cuda")) + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +@requires_triton +class TestTritonCPUValidation: + """Tests that only need Triton importable, not a GPU.""" + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_rejects_cpu_tensor(self): + op = self._get_op() + with pytest.raises(RuntimeError, match="requires a GPU"): + op(torch.randn(4, _V), torch.randint(0, _V, (4,))) + + +@requires_triton_cuda +class TestTritonValidation: + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_rejects_1d_logits(self): + op = self._get_op() + with pytest.raises(ValueError, match="at least 2-D"): + op(torch.randn(10, device="cuda"), torch.tensor([0], device="cuda")) + + def test_rejects_invalid_target(self): + op = self._get_op() + logits = torch.randn(4, _V, device="cuda") + target = torch.tensor([0, -1, 2, 3], device="cuda") + with pytest.raises(ValueError, match="outside"): + op(logits, target, validate=True) + + +# --------------------------------------------------------------------------- +# 7b. 
CUDA SM90 TMA kernel backend +# --------------------------------------------------------------------------- + + +@requires_sm90 +class TestCudaSM90Correctness: + """Compiled SM90 TMA kernel output must match log_softmax + gather.""" + + def _get_op(self): + from rl_engine.kernels.ops.cuda.loss.batch_invariant_logp import ( + BatchInvariantLogpSM90Op, + ) + return BatchInvariantLogpSM90Op() + + def test_matches_reference_fp32(self): + op = self._get_op() + logits = torch.randn(8, _VC, device="cuda") + target = torch.randint(0, _VC, (8,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits, target) + assert out.dtype == torch.float32 + assert torch.allclose(out, ref, atol=1e-4) + + def test_matches_reference_bf16(self): + op = self._get_op() + logits = torch.randn(8, _VC, device="cuda", dtype=torch.bfloat16) + target = torch.randint(0, _VC, (8,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits.float(), target) + assert out.dtype == torch.float32 + assert torch.allclose(out, ref, atol=1e-3) + + def test_large_vocab(self): + op = self._get_op() + logits = torch.randn(4, 128256, device="cuda", dtype=torch.bfloat16) + target = torch.randint(0, 128256, (4,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits.float(), target) + assert torch.allclose(out, ref, atol=2e-3) + + def test_unaligned_vocab(self): + # V not a multiple of the TMA box: exercises the global-read tail path. 
+ op = self._get_op() + logits = torch.randn(8, 50257, device="cuda", dtype=torch.bfloat16) + target = torch.randint(0, 50257, (8,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits.float(), target) + assert torch.allclose(out, ref, atol=2e-3) + + def test_single_token(self): + op = self._get_op() + logits = torch.randn(1, _VC, device="cuda") + target = torch.randint(0, _VC, (1,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-4) + + def test_3d_logits(self): + op = self._get_op() + logits = torch.randn(2, 3, _VC, device="cuda", dtype=torch.bfloat16) + target = torch.randint(0, _VC, (2, 3), device="cuda") + out = op(logits, target) + assert out.shape == (2, 3) + ref = _reference_logp(logits.float(), target) + assert torch.allclose(out, ref, atol=1e-3) + + def test_matches_pytorch_op(self): + op = self._get_op() + pytorch_op = NativeBatchInvariantLogpOp() + logits = torch.randn(16, _VC, device="cuda") + target = torch.randint(0, _VC, (16,), device="cuda") + assert torch.allclose(op(logits, target), pytorch_op(logits, target), atol=1e-4) + + +@requires_sm90 +class TestCudaSM90BatchInvariance: + """SM90 kernel must be bitwise batch-invariant (one CTA per row).""" + + def _get_op(self): + from rl_engine.kernels.ops.cuda.loss.batch_invariant_logp import ( + BatchInvariantLogpSM90Op, + ) + return BatchInvariantLogpSM90Op() + + def test_batch_size_1_vs_n(self): + op = self._get_op() + row = _make_row(42, vocab=_VC, device="cuda") + target = torch.tensor([7], device="cuda") + result_alone = op(row, target).item() + + for batch_size in [2, 4, 8, 16, 32, 64, 128]: + batch_logits = torch.randn(batch_size, _VC, device="cuda") + batch_target = torch.randint(0, _VC, (batch_size,), device="cuda") + batch_logits[0] = row.squeeze(0) + batch_target[0] = target.squeeze(0) + result_in_batch = op(batch_logits, batch_target)[0].item() + assert result_alone == result_in_batch, ( + f"SM90 
drift at batch_size={batch_size}: " + f"alone={result_alone}, in_batch={result_in_batch}" + ) + + def test_different_positions(self): + op = self._get_op() + row = _make_row(99, vocab=_VC, device="cuda") + target = torch.tensor([13], device="cuda") + batch_size = 16 + results = [] + for pos in range(batch_size): + batch_logits = torch.randn(batch_size, _VC, device="cuda") + batch_target = torch.randint(0, _VC, (batch_size,), device="cuda") + batch_logits[pos] = row.squeeze(0) + batch_target[pos] = target.squeeze(0) + results.append(op(batch_logits, batch_target)[pos].item()) + assert all(r == results[0] for r in results), ( + f"SM90 position drift: unique = {set(results)}" + ) + + def test_repeated_runs(self): + op = self._get_op() + logits = torch.randn(16, _VC, device="cuda", dtype=torch.bfloat16) + target = torch.randint(0, _VC, (16,), device="cuda") + results = [op(logits, target) for _ in range(50)] + for i, r in enumerate(results[1:], 1): + assert torch.equal(r, results[0]), f"SM90 run {i} differs from run 0" + + +@requires_sm90 +class TestCudaSM90Backward: + """Gradient through the SM90 op must match reference.""" + + def _get_op(self): + from rl_engine.kernels.ops.cuda.loss.batch_invariant_logp import ( + BatchInvariantLogpSM90Op, + ) + return BatchInvariantLogpSM90Op() + + def test_backward_matches_reference(self): + op = self._get_op() + logits = torch.randn(4, _VC, device="cuda", requires_grad=True) + target = torch.randint(0, _VC, (4,), device="cuda") + op(logits, target).sum().backward() + grad = logits.grad.detach().clone() + + ref_logits = logits.detach().clone().requires_grad_(True) + _reference_logp(ref_logits, target).sum().backward() + assert torch.allclose(grad, ref_logits.grad, atol=1e-4) + + def test_ignored_row_grad_is_zero(self): + op = self._get_op() + logits = torch.randn(4, _VC, device="cuda", requires_grad=True) + target = torch.tensor([0, -100, 2, -100], device="cuda") + op(logits, target).sum().backward() + assert 
torch.equal(logits.grad[1], torch.zeros(_VC, device="cuda")) + assert torch.equal(logits.grad[3], torch.zeros(_VC, device="cuda")) + + +@requires_sm90 +class TestCudaSM90IgnoreIndex: + + def _get_op(self): + from rl_engine.kernels.ops.cuda.loss.batch_invariant_logp import ( + BatchInvariantLogpSM90Op, + ) + return BatchInvariantLogpSM90Op() + + def test_ignore_outputs_zero(self): + op = self._get_op() + logits = torch.randn(4, _VC, device="cuda") + target = torch.tensor([0, -100, 2, -100], device="cuda") + out = op(logits, target) + assert out[1].item() == 0.0 + assert out[3].item() == 0.0 + ref = _reference_logp(logits[[0, 2]], target[[0, 2]]) + assert torch.allclose(out[[0, 2]], ref, atol=1e-4) + + +@requires_sm90 +class TestCudaSM90Fallback: + """Inputs the TMA path can't take must silently fall back and stay correct.""" + + def _get_op(self): + from rl_engine.kernels.ops.cuda.loss.batch_invariant_logp import ( + BatchInvariantLogpSM90Op, + ) + return BatchInvariantLogpSM90Op() + + def test_fp16_falls_back(self): + op = self._get_op() + logits = torch.randn(8, _VC, device="cuda", dtype=torch.float16) + target = torch.randint(0, _VC, (8,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits.float(), target) + assert torch.allclose(out, ref, atol=1e-3) + + +# --------------------------------------------------------------------------- +# 8. 
Registry dispatch test +# --------------------------------------------------------------------------- + + +def test_registry_dispatches_correctly(): + from rl_engine.kernels.registry import kernel_registry + + op = kernel_registry.get_op("batch_invariant_logp") + assert ( + isinstance(op, NativeBatchInvariantLogpOp) + or type(op).__name__ == "TritonBatchInvariantLogpOp" + or type(op).__name__ == "BatchInvariantLogpSM90Op" + ) + logits = torch.randn(4, _V, device="cuda" if torch.cuda.is_available() else "cpu") + target = torch.randint(0, _V, (4,), device=logits.device) + out = op(logits, target) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-6)