From d829be209b7e6dc574c1d5b3a71865bc60a305c8 Mon Sep 17 00:00:00 2001 From: hihaluemen <1596916766@qq.com> Date: Sun, 28 Jun 2026 20:05:55 +0800 Subject: [PATCH 1/4] [WS1][kernels] Batch-invariant logprob (selected, locked reduction) Implements batch_invariant_logp for selected-token log probabilities from materialized logits with row-local, batch-invariant semantics. - PyTorch NativeBatchInvariantLogpOp: FP32 row-wise reference with ignore_index handling and target validation. - Triton TritonBatchInvariantLogpOp: online-softmax forward with fixed vocab tiling and tile-wise backward using saved per-row lse. - Registry dispatch, PyTorch/Triton tests, and operator docs. --- .gitignore | 3 + docs/.nav.yml | 1 + docs/operators/README.md | 1 + docs/operators/batch-invariant-logp.md | 172 +++++++ .../ops/pytorch/loss/batch_invariant_logp.py | 119 +++++ .../ops/triton/loss/batch_invariant_logp.py | 244 +++++++++ rl_engine/kernels/registry.py | 17 + tests/test_batch_invariant_logp.py | 478 ++++++++++++++++++ tests/test_triton_batch_invariant_logp.py | 352 +++++++++++++ 9 files changed, 1387 insertions(+) create mode 100644 docs/operators/batch-invariant-logp.md create mode 100644 rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py create mode 100644 rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py create mode 100644 tests/test_batch_invariant_logp.py create mode 100644 tests/test_triton_batch_invariant_logp.py diff --git a/.gitignore b/.gitignore index ae89c0d..edc4b00 100644 --- a/.gitignore +++ b/.gitignore @@ -206,3 +206,6 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +# Local dev notes (not for upstream) +_dev_notes/ diff --git a/docs/.nav.yml b/docs/.nav.yml index 60525c2..3a8ddc0 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -13,6 +13,7 @@ nav: - operators/README.md - operators/fused-logp.md - operators/linear-logp.md + - operators/batch-invariant-logp.md - operators/grpo-loss.md - operators/ratio-kl.md - operators/sampling.md diff 
--git a/docs/operators/README.md b/docs/operators/README.md index c4eae60..d0174b4 100644 --- a/docs/operators/README.md +++ b/docs/operators/README.md @@ -20,6 +20,7 @@ Every operator page should include: - [Fused LogP](fused-logp.md) - [Fused Linear LogP](linear-logp.md) +- [Batch-Invariant LogP](batch-invariant-logp.md) - [GRPO Loss](grpo-loss.md) - [Policy Ratio + KL Penalty](ratio-kl.md) - [Sampling](sampling.md) diff --git a/docs/operators/batch-invariant-logp.md b/docs/operators/batch-invariant-logp.md new file mode 100644 index 0000000..2cb38cb --- /dev/null +++ b/docs/operators/batch-invariant-logp.md @@ -0,0 +1,172 @@ +# Batch-Invariant LogP + +Batch-Invariant LogP computes selected token log-probabilities from already +materialized logits: + +```text +out[row] = logits[row, target_ids[row]] - logsumexp(logits[row, :]) +``` + +It targets RL post-training paths where policy log-probs are compared across +different packing, padding, and batch layouts. The key contract is +batch-invariance: for a fixed row of logits and target id, the result must not +change when that row is evaluated alone, at a different batch position, or with +different neighboring rows. + +Unlike `linear_logp`, this operator does not fuse the LM-head projection. It +takes `[*, V]` logits as input and returns one selected log-probability per row. + +## Entry Point + +```python +from rl_engine.kernels.registry import kernel_registry + +batch_invariant_logp = kernel_registry.get_op("batch_invariant_logp") + +logp = batch_invariant_logp( + logits, # [B, T, V] or [N, V], differentiable + target_ids, # [B, T] or [N], int + ignore_index=-100, +) # -> [B, T] or [N], float32 + +logp.sum().backward() # gradients flow into logits only +``` + +## Backends + +| Backend | Wrapper | Status | +| --- | --- | --- | +| CUDA / ROCm (Triton) | `TritonBatchInvariantLogpOp` | Triton online-softmax forward and tile-wise backward. Requires a GPU tensor. 
| +| PyTorch native | `NativeBatchInvariantLogpOp` | FP32 reference path; CPU fallback and Triton-less fallback. | + +Current dispatch: + +```text +CUDA / ROCm: Triton -> PyTorch +CPU: PyTorch +``` + +A compiled CUDA backend and benchmark suite are planned follow-up work. + +## Tensor Contract + +| Argument | Shape | Dtype | Requirements | +| --- | --- | --- | --- | +| `logits` | `[N, V]` / `[B, T, V]` / `[*lead, V]` | fp32 / fp16 / bf16 | Differentiable input; last dimension is vocab. | +| `target_ids` | `[N]` / `[B, T]` / `[*lead]` | int | Same leading shape as `logits`; non-ignored values in `[0, V)`. | +| `ignore_index` | scalar int | Python int | Default `-100`. Ignored rows output zero and receive zero gradient. | +| Output | `[N]` / `[B, T]` / `[*lead]` | float32 | Selected log-probability per row. | + +`target_ids` is integer and non-differentiable. Gradients flow only into +`logits`. + +## Reference Semantics + +For non-ignored rows: + +```python +logits_2d = logits.reshape(-1, logits.size(-1)).float() +target_1d = target_ids.reshape(-1).long() + +log_probs = torch.log_softmax(logits_2d, dim=-1) +selected = torch.gather( + log_probs, + dim=-1, + index=target_1d.unsqueeze(-1), +).squeeze(-1) + +out = selected.reshape(target_ids.shape) +``` + +For ignored rows: + +```text +target_ids[row] == ignore_index +out[row] = 0.0 +grad_logits[row, :] = 0.0 +``` + +Non-ignored target ids outside `[0, V)` raise `ValueError`. In particular, +`target=-1` is invalid unless `ignore_index=-1`. + +## Batch-Invariance + +The operator is designed so each row is computed independently: + +- The PyTorch path reshapes to `[N, V]` and applies row-wise reductions. +- The Triton forward uses `grid=(num_tokens,)`, so one program owns exactly one + row. +- Triton vocab traversal uses a fixed `_BLOCK_V=1024` and does not autotune by + batch size. +- Triton forward scans vocab tiles left-to-right using online logsumexp. 
+- Triton backward uses `grid=(num_tokens, vocab_tiles)` and writes one row tile + per program. It reuses the forward-saved per-row `lse`, so no backward + reduction crosses row boundaries. +- No atomic writes are used. + +These constraints ensure the result for a row depends only on that row's logits +and target id, not on batch size, row position, or neighboring rows. + +## Accuracy + +Both backends accumulate reductions in float32 and return float32 outputs. Tests +compare against `torch.log_softmax(...).gather(...)` with dtype-appropriate +tolerances: + +```text +fp32 forward: atol around 1e-5 +fp16/bf16 forward: atol around 1e-4 +fp16/bf16 backward: checked against fp32 reference with relaxed tolerance +``` + +CPU-vs-CUDA comparisons use tolerance-based checks; batch-invariance checks +within the same backend use exact equality where appropriate. + +## Minimal Example + +```python +import torch + +from rl_engine.kernels.registry import kernel_registry + +op = kernel_registry.get_op("batch_invariant_logp") + +logits = torch.randn(2, 4, 300, device="cuda", dtype=torch.bfloat16, requires_grad=True) +target_ids = torch.randint(0, 300, (2, 4), device="cuda") +target_ids[0, 0] = -100 + +out = op(logits, target_ids, ignore_index=-100) +assert out.shape == target_ids.shape +assert out.dtype == torch.float32 +assert out[0, 0].item() == 0.0 + +out.sum().backward() +``` + +## Tests + +```bash +python -m pytest tests/test_batch_invariant_logp.py -q -rs +python -m pytest tests/test_triton_batch_invariant_logp.py -q -rs +python -m pytest tests/test_batch_invariant_logp.py tests/test_triton_batch_invariant_logp.py -q -rs +``` + +The PyTorch tests cover correctness, leading-shape preservation, +batch-invariance, validation, ignore-index behavior, backward correctness, CUDA +smoke cases, and registry dispatch. 
+ +The Triton tests cover fp32/fp16/bf16 correctness, large vocab, 3D leading +shapes, batch-size and position invariance, repeated-run determinism, backward +correctness, gradient batch-invariance, ignored-row zero gradients, and invalid +input rejection. + +Triton tests skip when Triton or CUDA is unavailable. On Windows, run the Triton +suite from WSL/Linux with CUDA. + +## Implementation Files + +- `rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py` +- `rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py` +- `rl_engine/kernels/registry.py` +- `tests/test_batch_invariant_logp.py` +- `tests/test_triton_batch_invariant_logp.py` diff --git a/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py b/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py new file mode 100644 index 0000000..47d1a58 --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import torch + + +class NativeBatchInvariantLogpOp: + """Batch-invariant selected-token log-probability. + + ``selected_logprob[t] = logits[t, target_ids[t]] - logsumexp(logits[t, :])`` + + All reductions run in FP32. The row-wise max -> subtract -> exp -> sum -> log + pipeline is fully independent per row, so the result for any row depends + only on that row's logits and target - never on batch size or layout. 
+ """ + + def __init__(self) -> None: + pass + + def __call__( + self, + logits: torch.Tensor, + target_ids: torch.Tensor, + ignore_index: int = -100, + ) -> torch.Tensor: + return self.apply(logits, target_ids, ignore_index=ignore_index) + + def apply( + self, + logits: torch.Tensor, + target_ids: torch.Tensor, + ignore_index: int = -100, + ) -> torch.Tensor: + self._validate_shapes(logits, target_ids) + + lead_shape = logits.shape[:-1] + vocab_size = logits.size(-1) + + logits_2d = logits.reshape(-1, vocab_size).float() + target_1d = target_ids.reshape(-1).to(logits.device, dtype=torch.long) + + selected_logp = self._row_wise_selected_logprob( + logits_2d, target_1d, ignore_index=ignore_index + ) + + return selected_logp.reshape(lead_shape) + + # ---------------------------------------------------------------------- # + # Core Computation + # ---------------------------------------------------------------------- # + @staticmethod + def _row_wise_selected_logprob( + logits_2d: torch.Tensor, + target_1d: torch.Tensor, + *, + ignore_index: int, + ) -> torch.Tensor: + """Per-row selected logprob with locked reduction order. + + The three reduction steps (max, sum-exp, gather) operate on each row + independently. PyTorch's ``max(dim=-1)`` and ``sum(dim=-1)`` iterate + the vocab dimension in a fixed, deterministic order for a given row + length, and that order does **not** change when more rows are added + to the batch. This is the property that makes the op batch-invariant. + + Accumulation is done entirely in FP32 to avoid half-precision + catastrophic cancellation during the ``exp(logit - max)`` step. + """ + vocab_size = logits_2d.size(1) + + valid_mask = target_1d != ignore_index + + valid_targets = target_1d[valid_mask] + # Check if target_ids contains values outside the valid range. 
+ if valid_targets.numel() > 0 and ( + (valid_targets < 0).any() or (valid_targets >= vocab_size).any() + ): + bad = valid_targets[(valid_targets < 0) | (valid_targets >= vocab_size)] + raise ValueError( + f"target_ids contains values outside [0, {vocab_size}): {bad.tolist()}" + ) + + safe_target = target_1d.clone() + safe_target[~valid_mask] = 0 + + # logsumexp(z) = log(sum(exp(z - max(z)))) + max(z) + row_max = logits_2d.max(dim=-1).values + shifted = logits_2d - row_max.unsqueeze(-1) + exp_shifted = shifted.exp() + sum_exp = exp_shifted.sum(dim=-1) + log_sum_exp = sum_exp.log() + row_max + + row_indices = torch.arange(logits_2d.size(0), device=logits_2d.device) + selected_logit = logits_2d[row_indices, safe_target] + + selected_logp = selected_logit - log_sum_exp + + selected_logp = selected_logp.where( + valid_mask, torch.zeros_like(selected_logp) + ) + + return selected_logp + + # ---------------------------------------------------------------------- # + # Helper + # ---------------------------------------------------------------------- # + @staticmethod + def _validate_shapes(logits: torch.Tensor, target_ids: torch.Tensor) -> None: + if logits.dim() < 2: + raise ValueError( + f"logits must be at least 2-D ([*lead, V]), got shape {tuple(logits.shape)}" + ) + if logits.shape[:-1] != target_ids.shape: + raise ValueError( + f"logits leading shape {tuple(logits.shape[:-1])} must match " + f"target_ids shape {tuple(target_ids.shape)}" + ) diff --git a/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py b/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py new file mode 100644 index 0000000..2bd77b8 --- /dev/null +++ b/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +_BLOCK_V: int = 1024 + + +@triton.jit +def _batch_invariant_logp_kernel( + 
logits_ptr, # logits [N, V] + target_ptr, # target_ids [N] + output_ptr, # selected logprob output [N] + lse_ptr, # per-row log-sum-exp, saved for backward [N] + num_tokens, # N + vocab_size, # V + stride_row, # stride between consecutive rows in logits + ignore_index: tl.constexpr, + BLOCK_V: tl.constexpr, +): + """One program = one token row. Computes selected logprob via online softmax. + + Algorithm (one-pass online log-sum-exp): + m = -inf, s = 0, z_target = 0 + for each vocab tile [v0, v0+BLOCK_V): + load tile, cast to fp32 + collect target logit from tile (if target falls in this tile) + online softmax update: m, s + lse = log(s) + m + output = z_target - lse + + Also stores the per-row lse for use by the backward kernel. + """ + row_id = tl.program_id(0) + if row_id >= num_tokens: + return + + target_id = tl.load(target_ptr + row_id) + is_ignored = target_id == ignore_index + safe_target = tl.where(is_ignored, 0, target_id) + + m = tl.full((), float("-inf"), dtype=tl.float32) + s = tl.zeros((), dtype=tl.float32) + z_target = tl.zeros((), dtype=tl.float32) + + row_base = row_id.to(tl.int64) * stride_row + + for v0 in range(0, vocab_size, BLOCK_V): + cols = v0 + tl.arange(0, BLOCK_V) + mask = cols < vocab_size + + tile = tl.load( + logits_ptr + row_base + cols, + mask=mask, + other=float("-inf"), + ).to(tl.float32) + + is_target = (cols == safe_target) & mask + z_target += tl.sum(tl.where(is_target, tile, 0.0)) + + tile_max = tl.max(tile) + new_m = tl.maximum(m, tile_max) + s = s * tl.exp(m - new_m) + tl.sum(tl.exp(tile - new_m)) + m = new_m + + lse = m + tl.log(s) + result = z_target - lse + result = tl.where(is_ignored, 0.0, result) + + tl.store(output_ptr + row_id, result) + tl.store(lse_ptr + row_id, lse) + + +@triton.jit +def _batch_invariant_logp_bwd_kernel( + logits_ptr, # logits [N, V] + target_ptr, # target_ids [N] + lse_ptr, # per-row log-sum-exp from forward [N] + grad_out_ptr, # upstream gradient [N] + grad_logits_ptr, # gradient output for logits 
[N, V] + num_tokens, # N + vocab_size, # V + stride_row, # stride between consecutive rows in logits / grad_logits + ignore_index: tl.constexpr, + BLOCK_V: tl.constexpr, +): + row_id = tl.program_id(0) + tile_id = tl.program_id(1) + + cols = tile_id * BLOCK_V + tl.arange(0, BLOCK_V) + mask = cols < vocab_size + + target = tl.load(target_ptr + row_id) + ignored = target == ignore_index + + row_base = row_id.to(tl.int64) * stride_row + + logits = tl.load( + logits_ptr + row_base + cols, + mask=mask, + other=0.0, + ).to(tl.float32) + lse = tl.load(lse_ptr + row_id) + grad_out = tl.load(grad_out_ptr + row_id).to(tl.float32) + + soft = tl.exp(logits - lse) + onehot = tl.where(cols == target, 1.0, 0.0) + grad = grad_out * (onehot - soft) + grad = tl.where(ignored, 0.0, grad) + + tl.store(grad_logits_ptr + row_base + cols, grad, mask=mask) + + +class _BatchInvariantLogpFunction(torch.autograd.Function): + """Autograd wrapper for the Triton batch-invariant logp kernel.""" + + @staticmethod + def forward(ctx, logits, target_ids, ignore_index): + lead_shape = logits.shape[:-1] + vocab_size = logits.size(-1) + + logits_2d = logits.reshape(-1, vocab_size).contiguous() + target_1d = target_ids.reshape(-1).to( + device=logits.device, dtype=torch.int64 + ).contiguous() + + num_tokens = logits_2d.shape[0] + output = torch.empty(num_tokens, device=logits.device, dtype=torch.float32) + lse = torch.empty(num_tokens, device=logits.device, dtype=torch.float32) + + grid = (num_tokens,) + _batch_invariant_logp_kernel[grid]( + logits_2d, + target_1d, + output, + lse, + num_tokens, + vocab_size, + logits_2d.stride(0), + ignore_index=ignore_index, + BLOCK_V=_BLOCK_V, + ) + + ctx.save_for_backward(logits_2d, target_1d, lse) + ctx.ignore_index = ignore_index + ctx.lead_shape = lead_shape + ctx.vocab_size = vocab_size + + return output.reshape(lead_shape) + + @staticmethod + def backward(ctx, grad_output): + logits_2d, target_1d, lse = ctx.saved_tensors + ignore_index = ctx.ignore_index + 
vocab_size = ctx.vocab_size + num_tokens = logits_2d.shape[0] + + grad_flat = grad_output.reshape(-1).contiguous().to(torch.float32) + grad_logits = torch.empty_like(logits_2d, dtype=torch.float32) + + grid = (num_tokens, triton.cdiv(vocab_size, _BLOCK_V)) + _batch_invariant_logp_bwd_kernel[grid]( + logits_2d, + target_1d, + lse, + grad_flat, + grad_logits, + num_tokens, + vocab_size, + logits_2d.stride(0), + ignore_index=ignore_index, + BLOCK_V=_BLOCK_V, + ) + + grad_logits = grad_logits.to(logits_2d.dtype).reshape( + ctx.lead_shape + (vocab_size,) + ) + + return grad_logits, None, None + + +class TritonBatchInvariantLogpOp: + """Triton fused batch-invariant selected-token log-probability. + + Computes ``logits[t, target_ids[t]] - logsumexp(logits[t, :])`` using a + one-pass online softmax Triton kernel with locked reduction order. + + Requires a GPU tensor (CUDA / ROCm). + """ + + def __init__(self) -> None: + pass + + def __call__( + self, + logits: torch.Tensor, + target_ids: torch.Tensor, + ignore_index: int = -100, + ) -> torch.Tensor: + return self.apply(logits, target_ids, ignore_index=ignore_index) + + def apply( + self, + logits: torch.Tensor, + target_ids: torch.Tensor, + ignore_index: int = -100, + ) -> torch.Tensor: + if logits.device.type not in ("cuda", "xpu", "hip"): + raise RuntimeError( + "TritonBatchInvariantLogpOp requires a GPU tensor " + f"(CUDA / ROCm / XPU), got device '{logits.device}'." 
+ ) + + if logits.dim() < 2: + raise ValueError( + f"logits must be at least 2-D ([*lead, V]), got shape " + f"{tuple(logits.shape)}" + ) + + if logits.shape[:-1] != target_ids.shape: + raise ValueError( + f"logits leading shape {tuple(logits.shape[:-1])} must match " + f"target_ids shape {tuple(target_ids.shape)}" + ) + + vocab_size = logits.size(-1) + target_flat = target_ids.reshape(-1) + valid_targets = target_flat[target_flat != ignore_index] + if valid_targets.numel() > 0 and ( + (valid_targets < 0).any() or (valid_targets >= vocab_size).any() + ): + bad = valid_targets[ + (valid_targets < 0) | (valid_targets >= vocab_size) + ] + raise ValueError( + f"target_ids contains values outside [0, {vocab_size}): " + f"{bad.tolist()}" + ) + + return _BatchInvariantLogpFunction.apply(logits, target_ids, ignore_index) diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 6780157..76c39ab 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -49,6 +49,14 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): TRITON_RATIO_KL = "rl_engine.kernels.ops.triton.loss.ratio_kl.TritonRatioKLOp" PYTORCH_RATIO_KL = "rl_engine.kernels.ops.pytorch.loss.ratio_kl.NativeRatioKLOp" + # Batch-invariant selected-logprob (WS1 #148: locked reduction order) + TRITON_BATCH_INVARIANT_LOGP = ( + "rl_engine.kernels.ops.triton.loss.batch_invariant_logp.TritonBatchInvariantLogpOp" + ) + PYTORCH_BATCH_INVARIANT_LOGP = ( + "rl_engine.kernels.ops.pytorch.loss.batch_invariant_logp.NativeBatchInvariantLogpOp" + ) + # Generic fallback TRITON_GENERIC = "rl_engine.kernels.ops.triton.generic.TritonOp" PYTORCH_ATTN = "rl_engine.kernels.ops.pytorch.attention.NativeAttentionOp" @@ -89,6 +97,10 @@ def __init__(self): "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "batch_invariant_logp": [ 
+ OpBackend.TRITON_BATCH_INVARIANT_LOGP, + OpBackend.PYTORCH_BATCH_INVARIANT_LOGP, + ], # Default dispatch logic for new operators }, "rocm": { @@ -101,6 +113,10 @@ def __init__(self): "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "batch_invariant_logp": [ + OpBackend.TRITON_BATCH_INVARIANT_LOGP, + OpBackend.PYTORCH_BATCH_INVARIANT_LOGP, + ], }, "cpu": { "logp": [OpBackend.PYTORCH_NATIVE], @@ -108,6 +124,7 @@ def __init__(self): "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.PYTORCH_RATIO_KL], + "batch_invariant_logp": [OpBackend.PYTORCH_BATCH_INVARIANT_LOGP], }, } logger.info(f"KernelRegistry initialized for {device_ctx.device_type}") diff --git a/tests/test_batch_invariant_logp.py b/tests/test_batch_invariant_logp.py new file mode 100644 index 0000000..504c15a --- /dev/null +++ b/tests/test_batch_invariant_logp.py @@ -0,0 +1,478 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Tests for batch-invariant selected-logprob (issue #148). + +The test suite validates two orthogonal properties: +1. **Correctness** - output matches ``log_softmax + gather`` reference. +2. **Batch-invariance** - the result for a given row is bitwise identical + regardless of batch size, batch position, padding, or mixed-batch layout. 
+""" + +import pytest +import torch + +from rl_engine.kernels.ops.pytorch.loss.batch_invariant_logp import ( + NativeBatchInvariantLogpOp, +) +from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp + + +_V = 300 + +requires_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA device required.", +) + + +def _reference_logp(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor: + """Canonical reference: log_softmax(fp32) + gather.""" + logits_2d = logits.reshape(-1, logits.size(-1)).float() + target_1d = target_ids.reshape(-1).long() + log_probs = torch.log_softmax(logits_2d, dim=-1) + selected = torch.gather(log_probs, dim=-1, index=target_1d.unsqueeze(1)).squeeze(1) + return selected.reshape(target_ids.shape) + + +def _make_row(seed: int, vocab: int = _V, device: str = "cpu") -> torch.Tensor: + """Generate a single deterministic logit row from a seed.""" + gen = torch.Generator(device=device).manual_seed(seed) + return torch.randn(1, vocab, generator=gen, device=device) + + +# --------------------------------------------------------------------------- +# 1. 
Correctness tests +# --------------------------------------------------------------------------- + + +class TestCorrectness: + """Output must match the canonical ``log_softmax + gather`` reference.""" + + def test_matches_reference_basic(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(8, _V) + target = torch.randint(0, _V, (8,)) + out = op(logits, target) + ref = _reference_logp(logits, target) + assert out.dtype == torch.float32 + assert torch.allclose(out, ref, atol=1e-6) + + def test_matches_native_logp_op(self): + bi_op = NativeBatchInvariantLogpOp() + native_op = NativeLogpOp() + logits = torch.randn(16, _V) + target = torch.randint(0, _V, (16,)) + out_bi = bi_op(logits, target) + out_native = native_op.apply_fp32(logits, target) + assert torch.allclose(out_bi, out_native, atol=1e-6) + + def test_leading_shape_preserved(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(3, 5, _V) + target = torch.randint(0, _V, (3, 5)) + out = op(logits, target) + assert out.shape == (3, 5) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-6) + + def test_bf16_input_fp32_output(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(8, _V, dtype=torch.bfloat16) + target = torch.randint(0, _V, (8,)) + out = op(logits, target) + assert out.dtype == torch.float32 + ref = _reference_logp(logits.float(), target) + assert torch.allclose(out, ref, atol=1e-5) + + def test_fp16_input_fp32_output(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(8, _V, dtype=torch.float16) + target = torch.randint(0, _V, (8,)) + out = op(logits, target) + assert out.dtype == torch.float32 + ref = _reference_logp(logits.float(), target) + assert torch.allclose(out, ref, atol=1e-5) + + def test_single_token(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(1, _V) + target = torch.randint(0, _V, (1,)) + out = op(logits, target) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, 
atol=1e-6) + + def test_vocab_size_1(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(4, 1) + target = torch.zeros(4, dtype=torch.long) + out = op(logits, target) + assert torch.allclose(out, torch.zeros(4), atol=1e-6) + + def test_large_vocab(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(4, 128256) + target = torch.randint(0, 128256, (4,)) + out = op(logits, target) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-5) + + +# --------------------------------------------------------------------------- +# 2. Batch-invariance sweep tests - the core of issue #148 +# --------------------------------------------------------------------------- + + +class TestBatchInvariance: + """Same row must produce bitwise-identical output regardless of batch context.""" + + def _get_row_result_in_batch(self, row_logits, row_target, batch_size, position): + """Embed ``row_logits`` at *position* in a random batch of *batch_size* + and return the selected logprob for that row.""" + op = NativeBatchInvariantLogpOp() + vocab = row_logits.size(-1) + batch_logits = torch.randn(batch_size, vocab) + batch_target = torch.randint(0, vocab, (batch_size,)) + batch_logits[position] = row_logits.squeeze(0) + batch_target[position] = row_target.squeeze(0) + out = op(batch_logits, batch_target) + return out[position] + + def test_batch_size_1_vs_n(self): + """Same row in batch=1 vs batch=N must be bitwise equal.""" + op = NativeBatchInvariantLogpOp() + row = _make_row(42) + target = torch.tensor([7]) + + result_alone = op(row, target).item() + + for batch_size in [2, 4, 8, 16, 32, 64, 128]: + result_in_batch = self._get_row_result_in_batch( + row, target, batch_size, position=0 + ).item() + assert result_alone == result_in_batch, ( + f"Drift at batch_size={batch_size}: " + f"alone={result_alone}, in_batch={result_in_batch}" + ) + + def test_different_positions_in_batch(self): + """Same row at different positions in the same batch 
must be bitwise equal.""" + op = NativeBatchInvariantLogpOp() + row = _make_row(99) + target = torch.tensor([13]) + + batch_size = 16 + results = [] + for pos in range(batch_size): + val = self._get_row_result_in_batch(row, target, batch_size, pos).item() + results.append(val) + + assert all(r == results[0] for r in results), ( + f"Position-dependent drift detected: unique values = {set(results)}" + ) + + def test_mixed_batch_content(self): + """Changing *other* rows in the batch must not affect our row's result.""" + op = NativeBatchInvariantLogpOp() + row = _make_row(77) + target = torch.tensor([25]) + + batch_size = 8 + results = [] + for trial_seed in range(20): + torch.manual_seed(trial_seed * 1000) + batch_logits = torch.randn(batch_size, _V) + batch_target = torch.randint(0, _V, (batch_size,)) + batch_logits[3] = row.squeeze(0) + batch_target[3] = target.squeeze(0) + out = op(batch_logits, batch_target) + results.append(out[3].item()) + + assert all(r == results[0] for r in results), ( + f"Mixed-batch drift: unique values = {set(results)}" + ) + + def test_padding_layout_invariance(self): + """Left-padding vs right-padding must not affect real rows.""" + op = NativeBatchInvariantLogpOp() + row = _make_row(55) + target = torch.tensor([42]) + + pad_logits = torch.zeros(1, _V) + pad_target = torch.tensor([0]) + + batch_left = torch.cat([pad_logits, pad_logits, row], dim=0) + target_left = torch.cat([pad_target, pad_target, target], dim=0) + + batch_right = torch.cat([row, pad_logits, pad_logits], dim=0) + target_right = torch.cat([target, pad_target, pad_target], dim=0) + + out_left = op(batch_left, target_left) + out_right = op(batch_right, target_right) + + assert out_left[2].item() == out_right[0].item(), ( + "Padding layout changed the result" + ) + + def test_repeated_runs_deterministic(self): + """Same input repeated N times must produce bitwise-identical output.""" + op = NativeBatchInvariantLogpOp() + logits = torch.randn(16, _V) + target = 
torch.randint(0, _V, (16,)) + + results = [op(logits, target) for _ in range(50)] + for i, r in enumerate(results[1:], 1): + assert torch.equal(r, results[0]), f"Run {i} differs from run 0" + + def test_batch_invariance_with_ignore_index(self): + """Ignored positions must not affect other rows and must output 0.0.""" + op = NativeBatchInvariantLogpOp() + row = _make_row(33) + target_val = 10 + + batch_a = torch.cat([row, torch.randn(3, _V)], dim=0) + target_a = torch.tensor([target_val, 5, 8, 2]) + out_a = op(batch_a, target_a) + + target_b = torch.tensor([target_val, -100, -100, -100]) + out_b = op(batch_a, target_b) + + assert out_a[0].item() == out_b[0].item(), ( + "ignore_index on other rows changed row 0" + ) + assert out_b[1].item() == 0.0 + assert out_b[2].item() == 0.0 + assert out_b[3].item() == 0.0 + + +# --------------------------------------------------------------------------- +# 3. Shape / validation tests +# --------------------------------------------------------------------------- + + +class TestValidation: + + def test_rejects_1d_logits(self): + op = NativeBatchInvariantLogpOp() + with pytest.raises(ValueError, match="at least 2-D"): + op(torch.randn(10), torch.tensor([0])) + + def test_rejects_shape_mismatch(self): + op = NativeBatchInvariantLogpOp() + with pytest.raises(ValueError, match="must match"): + op(torch.randn(4, _V), torch.randint(0, _V, (5,))) + + def test_rejects_negative_target(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(4, _V) + target = torch.tensor([0, -1, 2, 3]) + with pytest.raises(ValueError, match="outside"): + op(logits, target) + + def test_rejects_target_ge_vocab(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(4, _V) + target = torch.tensor([0, 1, _V, 3]) + with pytest.raises(ValueError, match="outside"): + op(logits, target) + + def test_negative_target_with_ignore_index_ok(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(4, _V) + target = torch.tensor([0, -100, 2, 
3]) + out = op(logits, target) + assert out[1].item() == 0.0 + + def test_3d_logits(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(2, 3, _V) + target = torch.randint(0, _V, (2, 3)) + out = op(logits, target) + assert out.shape == (2, 3) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-6) + + +# --------------------------------------------------------------------------- +# 4. Backward / gradient tests +# --------------------------------------------------------------------------- + + +class TestBackward: + """Gradient must match the reference log_softmax + gather backward.""" + + def test_backward_matches_reference(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(4, _V, requires_grad=True) + target = torch.randint(0, _V, (4,)) + + out = op(logits, target).sum() + out.backward() + grad = logits.grad.detach().clone() + + ref_logits = logits.detach().clone().requires_grad_(True) + ref = _reference_logp(ref_logits, target).sum() + ref.backward() + + assert torch.allclose(grad, ref_logits.grad, atol=1e-6) + + def test_gradient_batch_invariance(self): + """Same row's gradient must be bitwise equal in batch=1 vs batch=N.""" + op = NativeBatchInvariantLogpOp() + row = _make_row(42) + target = torch.tensor([7]) + + logits_alone = row.clone().requires_grad_(True) + op(logits_alone, target).sum().backward() + grad_alone = logits_alone.grad.detach().clone() + + for batch_size in [4, 16, 64]: + batch_logits = torch.randn(batch_size, _V) + batch_logits[0] = row.squeeze(0) + batch_logits.requires_grad_(True) + batch_target = torch.randint(0, _V, (batch_size,)) + batch_target[0] = target.squeeze(0) + op(batch_logits, batch_target).sum().backward() + grad_in_batch = batch_logits.grad[0:1].detach().clone() + assert torch.equal(grad_alone, grad_in_batch), ( + f"Gradient drift at batch_size={batch_size}" + ) + + +# --------------------------------------------------------------------------- +# 5. 
Edge cases: all-ignore and custom ignore_index +# --------------------------------------------------------------------------- + + +class TestIgnoreEdgeCases: + + def test_all_ignore_index_outputs_zero(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(4, _V) + target = torch.full((4,), -100) + out = op(logits, target) + assert torch.equal(out, torch.zeros_like(out)) + + def test_custom_ignore_index(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(4, _V) + target = torch.tensor([0, -1, 2, 3]) + out = op(logits, target, ignore_index=-1) + assert out[1].item() == 0.0 + valid_idx = [0, 2, 3] + ref = _reference_logp(logits[valid_idx], target[valid_idx]) + assert torch.allclose(out[valid_idx], ref, atol=1e-6) + + +# --------------------------------------------------------------------------- +# 6. CUDA tests - same logic on GPU +# --------------------------------------------------------------------------- + + +@requires_cuda +class TestCUDACorrectness: + """Correctness on CUDA device.""" + + def test_matches_reference_cuda(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(8, _V, device="cuda") + target = torch.randint(0, _V, (8,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits, target) + assert out.device.type == "cuda" + assert out.dtype == torch.float32 + assert torch.allclose(out, ref, atol=1e-6) + + def test_bf16_cuda(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(8, _V, device="cuda", dtype=torch.bfloat16) + target = torch.randint(0, _V, (8,), device="cuda") + out = op(logits, target) + assert out.dtype == torch.float32 + ref = _reference_logp(logits.float(), target) + assert torch.allclose(out, ref, atol=1e-5) + + def test_large_vocab_cuda(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(4, 128256, device="cuda") + target = torch.randint(0, 128256, (4,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits, target) + assert 
torch.allclose(out, ref, atol=1e-5) + + +@requires_cuda +class TestCUDABatchInvariance: + """Batch-invariance on CUDA - the most important GPU validation.""" + + def test_batch_size_1_vs_n_cuda(self): + op = NativeBatchInvariantLogpOp() + row = _make_row(42, device="cuda") + target = torch.tensor([7], device="cuda") + result_alone = op(row, target).item() + + for batch_size in [2, 4, 8, 16, 32, 64, 128]: + batch_logits = torch.randn(batch_size, _V, device="cuda") + batch_target = torch.randint(0, _V, (batch_size,), device="cuda") + batch_logits[0] = row.squeeze(0) + batch_target[0] = target.squeeze(0) + result_in_batch = op(batch_logits, batch_target)[0].item() + assert result_alone == result_in_batch, ( + f"CUDA drift at batch_size={batch_size}: " + f"alone={result_alone}, in_batch={result_in_batch}" + ) + + def test_different_positions_cuda(self): + op = NativeBatchInvariantLogpOp() + row = _make_row(99, device="cuda") + target = torch.tensor([13], device="cuda") + batch_size = 16 + results = [] + for pos in range(batch_size): + batch_logits = torch.randn(batch_size, _V, device="cuda") + batch_target = torch.randint(0, _V, (batch_size,), device="cuda") + batch_logits[pos] = row.squeeze(0) + batch_target[pos] = target.squeeze(0) + results.append(op(batch_logits, batch_target)[pos].item()) + assert all(r == results[0] for r in results), ( + f"CUDA position drift: unique = {set(results)}" + ) + + def test_repeated_runs_cuda(self): + op = NativeBatchInvariantLogpOp() + logits = torch.randn(16, _V, device="cuda") + target = torch.randint(0, _V, (16,), device="cuda") + results = [op(logits, target) for _ in range(50)] + for i, r in enumerate(results[1:], 1): + assert torch.equal(r, results[0]), f"CUDA run {i} differs from run 0" + + def test_cpu_gpu_cross_check(self): + """Same input on CPU vs CUDA should match within tolerance.""" + op = NativeBatchInvariantLogpOp() + logits_cpu = torch.randn(8, _V) + target_cpu = torch.randint(0, _V, (8,)) + out_cpu = op(logits_cpu, 
target_cpu) + out_cuda = op(logits_cpu.cuda(), target_cpu.cuda()) + assert torch.allclose(out_cpu, out_cuda.cpu(), atol=1e-6, rtol=1e-6), ( + "CPU vs CUDA result mismatch" + ) + + +# --------------------------------------------------------------------------- +# 7. Registry dispatch test +# --------------------------------------------------------------------------- + + +def test_registry_dispatches_correctly(): + from rl_engine.kernels.registry import kernel_registry + + op = kernel_registry.get_op("batch_invariant_logp") + assert ( + isinstance(op, NativeBatchInvariantLogpOp) + or type(op).__name__ == "TritonBatchInvariantLogpOp" + ) + logits = torch.randn(4, _V, device="cuda" if torch.cuda.is_available() else "cpu") + target = torch.randint(0, _V, (4,), device=logits.device) + out = op(logits, target) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-6) diff --git a/tests/test_triton_batch_invariant_logp.py b/tests/test_triton_batch_invariant_logp.py new file mode 100644 index 0000000..064acd9 --- /dev/null +++ b/tests/test_triton_batch_invariant_logp.py @@ -0,0 +1,352 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Tests for the Triton batch-invariant selected-logprob kernel (issue #148). + +These tests validate that the Triton kernel produces results that: +1. Match the PyTorch reference implementation (correctness). +2. Are bitwise identical across different batch sizes / positions (batch-invariance). +3. Support backward pass (gradient correctness). + +All tests are skipped when Triton or CUDA is unavailable (e.g. on Windows or CPU-only). 
+""" + +import pytest +import torch + +try: + import triton # noqa: F401 + + _HAS_TRITON = True +except ImportError: + _HAS_TRITON = False + +requires_triton_cuda = pytest.mark.skipif( + not (_HAS_TRITON and torch.cuda.is_available()), + reason="Triton batch-invariant logp requires CUDA device and Triton.", +) + +_V = 300 + + +def _reference_logp(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor: + """Canonical reference: log_softmax(fp32) + gather.""" + logits_2d = logits.reshape(-1, logits.size(-1)).float() + target_1d = target_ids.reshape(-1).long() + log_probs = torch.log_softmax(logits_2d, dim=-1) + selected = torch.gather(log_probs, dim=-1, index=target_1d.unsqueeze(1)).squeeze(1) + return selected.reshape(target_ids.shape) + + +def _make_row(seed: int, vocab: int = _V, device: str = "cuda") -> torch.Tensor: + """Generate a single deterministic logit row from a seed.""" + gen = torch.Generator(device=device).manual_seed(seed) + return torch.randn(1, vocab, generator=gen, device=device) + + +# --------------------------------------------------------------------------- +# 1. 
Correctness: Triton vs reference +# --------------------------------------------------------------------------- + + +@requires_triton_cuda +class TestTritonCorrectness: + """Triton kernel output must match log_softmax + gather reference.""" + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_matches_reference_fp32(self): + op = self._get_op() + logits = torch.randn(8, _V, device="cuda") + target = torch.randint(0, _V, (8,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits, target) + assert out.dtype == torch.float32 + assert torch.allclose(out, ref, atol=1e-5) + + def test_matches_reference_bf16(self): + op = self._get_op() + logits = torch.randn(8, _V, device="cuda", dtype=torch.bfloat16) + target = torch.randint(0, _V, (8,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits.float(), target) + assert out.dtype == torch.float32 + assert torch.allclose(out, ref, atol=1e-4) + + def test_matches_reference_fp16(self): + op = self._get_op() + logits = torch.randn(8, _V, device="cuda", dtype=torch.float16) + target = torch.randint(0, _V, (8,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits.float(), target) + assert out.dtype == torch.float32 + assert torch.allclose(out, ref, atol=1e-4) + + def test_large_vocab(self): + op = self._get_op() + logits = torch.randn(4, 128256, device="cuda") + target = torch.randint(0, 128256, (4,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-5) + + def test_single_token(self): + op = self._get_op() + logits = torch.randn(1, _V, device="cuda") + target = torch.randint(0, _V, (1,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-5) + + def test_3d_logits(self): + op = self._get_op() + logits = 
torch.randn(2, 3, _V, device="cuda") + target = torch.randint(0, _V, (2, 3), device="cuda") + out = op(logits, target) + assert out.shape == (2, 3) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-5) + + def test_matches_pytorch_op(self): + """Triton and PyTorch ops should agree within tolerance.""" + from rl_engine.kernels.ops.pytorch.loss.batch_invariant_logp import ( + NativeBatchInvariantLogpOp, + ) + triton_op = self._get_op() + pytorch_op = NativeBatchInvariantLogpOp() + logits = torch.randn(16, _V, device="cuda") + target = torch.randint(0, _V, (16,), device="cuda") + out_triton = triton_op(logits, target) + out_pytorch = pytorch_op(logits, target) + assert torch.allclose(out_triton, out_pytorch, atol=1e-5) + + +# --------------------------------------------------------------------------- +# 2. Batch-invariance on GPU via Triton +# --------------------------------------------------------------------------- + + +@requires_triton_cuda +class TestTritonBatchInvariance: + """Triton kernel must be bitwise batch-invariant.""" + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_batch_size_1_vs_n(self): + op = self._get_op() + row = _make_row(42) + target = torch.tensor([7], device="cuda") + result_alone = op(row, target).item() + + for batch_size in [2, 4, 8, 16, 32, 64, 128]: + batch_logits = torch.randn(batch_size, _V, device="cuda") + batch_target = torch.randint(0, _V, (batch_size,), device="cuda") + batch_logits[0] = row.squeeze(0) + batch_target[0] = target.squeeze(0) + result_in_batch = op(batch_logits, batch_target)[0].item() + assert result_alone == result_in_batch, ( + f"Triton drift at batch_size={batch_size}: " + f"alone={result_alone}, in_batch={result_in_batch}" + ) + + def test_different_positions(self): + op = self._get_op() + row = _make_row(99) + target = torch.tensor([13], device="cuda") + 
batch_size = 16 + results = [] + for pos in range(batch_size): + batch_logits = torch.randn(batch_size, _V, device="cuda") + batch_target = torch.randint(0, _V, (batch_size,), device="cuda") + batch_logits[pos] = row.squeeze(0) + batch_target[pos] = target.squeeze(0) + results.append(op(batch_logits, batch_target)[pos].item()) + assert all(r == results[0] for r in results), ( + f"Triton position drift: unique = {set(results)}" + ) + + def test_repeated_runs(self): + op = self._get_op() + logits = torch.randn(16, _V, device="cuda") + target = torch.randint(0, _V, (16,), device="cuda") + results = [op(logits, target) for _ in range(50)] + for i, r in enumerate(results[1:], 1): + assert torch.equal(r, results[0]), f"Triton run {i} differs from run 0" + + def test_mixed_batch_content(self): + op = self._get_op() + row = _make_row(77) + target = torch.tensor([25], device="cuda") + batch_size = 8 + results = [] + for trial_seed in range(20): + torch.manual_seed(trial_seed * 1000) + batch_logits = torch.randn(batch_size, _V, device="cuda") + batch_target = torch.randint(0, _V, (batch_size,), device="cuda") + batch_logits[3] = row.squeeze(0) + batch_target[3] = target.squeeze(0) + results.append(op(batch_logits, batch_target)[3].item()) + assert all(r == results[0] for r in results), ( + f"Triton mixed-batch drift: unique = {set(results)}" + ) + + +# --------------------------------------------------------------------------- +# 3. 
Backward / gradient +# --------------------------------------------------------------------------- + + +@requires_triton_cuda +class TestTritonBackward: + """Gradient through the Triton op must match reference.""" + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_backward_matches_reference(self): + op = self._get_op() + logits = torch.randn(4, _V, device="cuda", requires_grad=True) + target = torch.randint(0, _V, (4,), device="cuda") + + out = op(logits, target).sum() + out.backward() + grad = logits.grad.detach().clone() + + ref_logits = logits.detach().clone().requires_grad_(True) + ref = _reference_logp(ref_logits, target).sum() + ref.backward() + + assert torch.allclose(grad, ref_logits.grad, atol=1e-5) + + def test_gradient_batch_invariance(self): + op = self._get_op() + row = _make_row(42) + target = torch.tensor([7], device="cuda") + + logits_alone = row.clone().requires_grad_(True) + op(logits_alone, target).sum().backward() + grad_alone = logits_alone.grad.detach().clone() + + for batch_size in [4, 16, 64]: + batch_logits = torch.randn(batch_size, _V, device="cuda") + batch_logits[0] = row.squeeze(0) + batch_logits.requires_grad_(True) + batch_target = torch.randint(0, _V, (batch_size,), device="cuda") + batch_target[0] = target.squeeze(0) + op(batch_logits, batch_target).sum().backward() + grad_in_batch = batch_logits.grad[0:1].detach().clone() + assert torch.allclose(grad_alone, grad_in_batch, atol=1e-5), ( + f"Triton gradient drift at batch_size={batch_size}" + ) + + def test_ignored_row_grad_is_zero(self): + """Ignored rows must have zero gradient across the entire vocab.""" + op = self._get_op() + logits = torch.randn(4, _V, device="cuda", requires_grad=True) + target = torch.tensor([0, -100, 2, -100], device="cuda") + op(logits, target).sum().backward() + assert torch.equal(logits.grad[1], torch.zeros(_V, device="cuda")) + 
assert torch.equal(logits.grad[3], torch.zeros(_V, device="cuda")) + + def test_backward_bf16_input(self): + """Backward with bf16 logits should match fp32 reference within tolerance.""" + op = self._get_op() + logits = torch.randn(8, _V, device="cuda", dtype=torch.bfloat16, requires_grad=True) + target = torch.randint(0, _V, (8,), device="cuda") + + op(logits, target).sum().backward() + grad = logits.grad.detach().clone() + + ref_logits = logits.detach().float().requires_grad_(True) + _reference_logp(ref_logits, target).sum().backward() + + assert torch.allclose(grad.float(), ref_logits.grad, atol=1e-2) + + def test_backward_fp16_input(self): + """Backward with fp16 logits should match fp32 reference within tolerance.""" + op = self._get_op() + logits = torch.randn(8, _V, device="cuda", dtype=torch.float16, requires_grad=True) + target = torch.randint(0, _V, (8,), device="cuda") + + op(logits, target).sum().backward() + grad = logits.grad.detach().clone() + + ref_logits = logits.detach().float().requires_grad_(True) + _reference_logp(ref_logits, target).sum().backward() + + assert torch.allclose(grad.float(), ref_logits.grad, atol=1e-2) + + +# --------------------------------------------------------------------------- +# 4. 
ignore_index handling +# --------------------------------------------------------------------------- + + +@requires_triton_cuda +class TestTritonIgnoreIndex: + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_ignore_outputs_zero(self): + op = self._get_op() + logits = torch.randn(4, _V, device="cuda") + target = torch.tensor([0, -100, 2, -100], device="cuda") + out = op(logits, target) + assert out[1].item() == 0.0 + assert out[3].item() == 0.0 + ref = _reference_logp(logits[[0, 2]], target[[0, 2]]) + assert torch.allclose(out[[0, 2]], ref, atol=1e-5) + + def test_all_ignore(self): + op = self._get_op() + logits = torch.randn(4, _V, device="cuda") + target = torch.full((4,), -100, device="cuda") + out = op(logits, target) + assert torch.equal(out, torch.zeros(4, device="cuda")) + + +# --------------------------------------------------------------------------- +# 5. 
Validation +# --------------------------------------------------------------------------- + + +@requires_triton_cuda +class TestTritonValidation: + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_rejects_cpu_tensor(self): + op = self._get_op() + with pytest.raises(RuntimeError, match="requires a GPU"): + op(torch.randn(4, _V), torch.randint(0, _V, (4,))) + + def test_rejects_1d_logits(self): + op = self._get_op() + with pytest.raises(ValueError, match="at least 2-D"): + op(torch.randn(10, device="cuda"), torch.tensor([0], device="cuda")) + + def test_rejects_invalid_target(self): + op = self._get_op() + logits = torch.randn(4, _V, device="cuda") + target = torch.tensor([0, -1, 2, 3], device="cuda") + with pytest.raises(ValueError, match="outside"): + op(logits, target) From d748a8fe516621ff361a335c87d2edaa3374c69e Mon Sep 17 00:00:00 2001 From: hihaluemen <1596916766@qq.com> Date: Mon, 29 Jun 2026 03:25:21 +0800 Subject: [PATCH 2/4] [REVIEW][kernels] address batch-invariant logprob comments --- docs/operators/batch-invariant-logp.md | 32 +- .../ops/pytorch/loss/batch_invariant_logp.py | 27 +- .../ops/triton/loss/batch_invariant_logp.py | 35 +- tests/test_batch_invariant_logp.py | 330 +++++++++++++++- tests/test_triton_batch_invariant_logp.py | 352 ------------------ 5 files changed, 379 insertions(+), 397 deletions(-) delete mode 100644 tests/test_triton_batch_invariant_logp.py diff --git a/docs/operators/batch-invariant-logp.md b/docs/operators/batch-invariant-logp.md index 2cb38cb..e6c2702 100644 --- a/docs/operators/batch-invariant-logp.md +++ b/docs/operators/batch-invariant-logp.md @@ -27,6 +27,7 @@ logp = batch_invariant_logp( logits, # [B, T, V] or [N, V], differentiable target_ids, # [B, T] or [N], int ignore_index=-100, + validate=False, # opt-in target range check (syncs CUDA stream) ) # -> [B, T] or [N], float32 
logp.sum().backward() # gradients flow into logits only @@ -47,6 +48,8 @@ CPU: PyTorch ``` A compiled CUDA backend and benchmark suite are planned follow-up work. +Benchmarks are not included in this PR; they will be added alongside the CUDA +backend in a subsequent PR. ## Tensor Contract @@ -86,8 +89,13 @@ out[row] = 0.0 grad_logits[row, :] = 0.0 ``` -Non-ignored target ids outside `[0, V)` raise `ValueError`. In particular, -`target=-1` is invalid unless `ignore_index=-1`. +Non-ignored target ids outside `[0, V)` raise `ValueError` when +`validate=True`. In particular, `target=-1` is invalid unless +`ignore_index=-1`. + +`validate=False` (default) skips the target range check to avoid CUDA stream +synchronization in training hot paths. Use `validate=True` during debugging or +in tests. ## Batch-Invariance @@ -147,21 +155,16 @@ out.sum().backward() ```bash python -m pytest tests/test_batch_invariant_logp.py -q -rs -python -m pytest tests/test_triton_batch_invariant_logp.py -q -rs -python -m pytest tests/test_batch_invariant_logp.py tests/test_triton_batch_invariant_logp.py -q -rs ``` -The PyTorch tests cover correctness, leading-shape preservation, -batch-invariance, validation, ignore-index behavior, backward correctness, CUDA -smoke cases, and registry dispatch. - -The Triton tests cover fp32/fp16/bf16 correctness, large vocab, 3D leading -shapes, batch-size and position invariance, repeated-run determinism, backward -correctness, gradient batch-invariance, ignored-row zero gradients, and invalid -input rejection. +All backends (Native, Triton) are tested in a single file. Coverage includes: +correctness, leading-shape preservation, batch-invariance (bitwise), validation, +ignore-index behavior, backward correctness, CUDA smoke cases, registry +dispatch, and Triton-specific fp32/fp16/bf16 correctness, large vocab, backward +gradient batch-invariance, and ignored-row zero gradients. -Triton tests skip when Triton or CUDA is unavailable. 
On Windows, run the Triton -suite from WSL/Linux with CUDA. +Triton tests skip when Triton or CUDA is unavailable. On Windows, run via +WSL/Linux with CUDA. ## Implementation Files @@ -169,4 +172,3 @@ suite from WSL/Linux with CUDA. - `rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py` - `rl_engine/kernels/registry.py` - `tests/test_batch_invariant_logp.py` -- `tests/test_triton_batch_invariant_logp.py` diff --git a/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py b/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py index 47d1a58..ef9270b 100644 --- a/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py +++ b/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py @@ -24,14 +24,18 @@ def __call__( logits: torch.Tensor, target_ids: torch.Tensor, ignore_index: int = -100, + *, + validate: bool = False, ) -> torch.Tensor: - return self.apply(logits, target_ids, ignore_index=ignore_index) + return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate) def apply( self, logits: torch.Tensor, target_ids: torch.Tensor, ignore_index: int = -100, + *, + validate: bool = False, ) -> torch.Tensor: self._validate_shapes(logits, target_ids) @@ -42,7 +46,7 @@ def apply( target_1d = target_ids.reshape(-1).to(logits.device, dtype=torch.long) selected_logp = self._row_wise_selected_logprob( - logits_2d, target_1d, ignore_index=ignore_index + logits_2d, target_1d, ignore_index=ignore_index, validate=validate ) return selected_logp.reshape(lead_shape) @@ -56,6 +60,7 @@ def _row_wise_selected_logprob( target_1d: torch.Tensor, *, ignore_index: int, + validate: bool = False, ) -> torch.Tensor: """Per-row selected logprob with locked reduction order. @@ -72,15 +77,15 @@ def _row_wise_selected_logprob( valid_mask = target_1d != ignore_index - valid_targets = target_1d[valid_mask] - # Check if target_ids contains values outside the valid range. 
- if valid_targets.numel() > 0 and ( - (valid_targets < 0).any() or (valid_targets >= vocab_size).any() - ): - bad = valid_targets[(valid_targets < 0) | (valid_targets >= vocab_size)] - raise ValueError( - f"target_ids contains values outside [0, {vocab_size}): {bad.tolist()}" - ) + if validate: + valid_targets = target_1d[valid_mask] + if valid_targets.numel() > 0 and ( + (valid_targets < 0).any() or (valid_targets >= vocab_size).any() + ): + bad = valid_targets[(valid_targets < 0) | (valid_targets >= vocab_size)] + raise ValueError( + f"target_ids contains values outside [0, {vocab_size}): {bad.tolist()}" + ) safe_target = target_1d.clone() safe_target[~valid_mask] = 0 diff --git a/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py b/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py index 2bd77b8..f32017c 100644 --- a/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py +++ b/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py @@ -200,45 +200,48 @@ def __call__( logits: torch.Tensor, target_ids: torch.Tensor, ignore_index: int = -100, + *, + validate: bool = False, ) -> torch.Tensor: - return self.apply(logits, target_ids, ignore_index=ignore_index) + return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate) def apply( self, logits: torch.Tensor, target_ids: torch.Tensor, ignore_index: int = -100, + *, + validate: bool = False, ) -> torch.Tensor: if logits.device.type not in ("cuda", "xpu", "hip"): raise RuntimeError( "TritonBatchInvariantLogpOp requires a GPU tensor " f"(CUDA / ROCm / XPU), got device '{logits.device}'." 
) - if logits.dim() < 2: raise ValueError( f"logits must be at least 2-D ([*lead, V]), got shape " f"{tuple(logits.shape)}" ) - if logits.shape[:-1] != target_ids.shape: raise ValueError( f"logits leading shape {tuple(logits.shape[:-1])} must match " f"target_ids shape {tuple(target_ids.shape)}" ) - vocab_size = logits.size(-1) - target_flat = target_ids.reshape(-1) - valid_targets = target_flat[target_flat != ignore_index] - if valid_targets.numel() > 0 and ( - (valid_targets < 0).any() or (valid_targets >= vocab_size).any() - ): - bad = valid_targets[ - (valid_targets < 0) | (valid_targets >= vocab_size) - ] - raise ValueError( - f"target_ids contains values outside [0, {vocab_size}): " - f"{bad.tolist()}" - ) + if validate: + vocab_size = logits.size(-1) + target_flat = target_ids.reshape(-1) + valid_targets = target_flat[target_flat != ignore_index] + if valid_targets.numel() > 0 and ( + (valid_targets < 0).any() or (valid_targets >= vocab_size).any() + ): + bad = valid_targets[ + (valid_targets < 0) | (valid_targets >= vocab_size) + ] + raise ValueError( + f"target_ids contains values outside [0, {vocab_size}): " + f"{bad.tolist()}" + ) return _BatchInvariantLogpFunction.apply(logits, target_ids, ignore_index) diff --git a/tests/test_batch_invariant_logp.py b/tests/test_batch_invariant_logp.py index 504c15a..d747e42 100644 --- a/tests/test_batch_invariant_logp.py +++ b/tests/test_batch_invariant_logp.py @@ -25,6 +25,17 @@ reason="CUDA device required.", ) +try: + import triton # noqa: F401 + _HAS_TRITON = True +except ImportError: + _HAS_TRITON = False + +requires_triton_cuda = pytest.mark.skipif( + not (_HAS_TRITON and torch.cuda.is_available()), + reason="Triton batch-invariant logp requires CUDA device and Triton.", +) + def _reference_logp(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor: """Canonical reference: log_softmax(fp32) + gather.""" @@ -267,14 +278,14 @@ def test_rejects_negative_target(self): logits = torch.randn(4, _V) target = 
torch.tensor([0, -1, 2, 3]) with pytest.raises(ValueError, match="outside"): - op(logits, target) + op(logits, target, validate=True) def test_rejects_target_ge_vocab(self): op = NativeBatchInvariantLogpOp() logits = torch.randn(4, _V) target = torch.tensor([0, 1, _V, 3]) with pytest.raises(ValueError, match="outside"): - op(logits, target) + op(logits, target, validate=True) def test_negative_target_with_ignore_index_ok(self): op = NativeBatchInvariantLogpOp() @@ -458,8 +469,321 @@ def test_cpu_gpu_cross_check(self): ) + +# --------------------------------------------------------------------------- +# 7. Triton backend tests +# --------------------------------------------------------------------------- + +# --------------------------------------------------------------------------- +# Correctness: Triton vs reference +# --------------------------------------------------------------------------- + + +@requires_triton_cuda +class TestTritonCorrectness: + """Triton kernel output must match log_softmax + gather reference.""" + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_matches_reference_fp32(self): + op = self._get_op() + logits = torch.randn(8, _V, device="cuda") + target = torch.randint(0, _V, (8,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits, target) + assert out.dtype == torch.float32 + assert torch.allclose(out, ref, atol=1e-5) + + def test_matches_reference_bf16(self): + op = self._get_op() + logits = torch.randn(8, _V, device="cuda", dtype=torch.bfloat16) + target = torch.randint(0, _V, (8,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits.float(), target) + assert out.dtype == torch.float32 + assert torch.allclose(out, ref, atol=1e-4) + + def test_matches_reference_fp16(self): + op = self._get_op() + logits = torch.randn(8, _V, device="cuda", dtype=torch.float16) + target = 
torch.randint(0, _V, (8,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits.float(), target) + assert out.dtype == torch.float32 + assert torch.allclose(out, ref, atol=1e-4) + + def test_large_vocab(self): + op = self._get_op() + logits = torch.randn(4, 128256, device="cuda") + target = torch.randint(0, 128256, (4,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-5) + + def test_single_token(self): + op = self._get_op() + logits = torch.randn(1, _V, device="cuda") + target = torch.randint(0, _V, (1,), device="cuda") + out = op(logits, target) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-5) + + def test_3d_logits(self): + op = self._get_op() + logits = torch.randn(2, 3, _V, device="cuda") + target = torch.randint(0, _V, (2, 3), device="cuda") + out = op(logits, target) + assert out.shape == (2, 3) + ref = _reference_logp(logits, target) + assert torch.allclose(out, ref, atol=1e-5) + + def test_matches_pytorch_op(self): + """Triton and PyTorch ops should agree within tolerance.""" + from rl_engine.kernels.ops.pytorch.loss.batch_invariant_logp import ( + NativeBatchInvariantLogpOp, + ) + triton_op = self._get_op() + pytorch_op = NativeBatchInvariantLogpOp() + logits = torch.randn(16, _V, device="cuda") + target = torch.randint(0, _V, (16,), device="cuda") + out_triton = triton_op(logits, target) + out_pytorch = pytorch_op(logits, target) + assert torch.allclose(out_triton, out_pytorch, atol=1e-5) + + +# --------------------------------------------------------------------------- +# Batch-invariance on GPU via Triton +# --------------------------------------------------------------------------- + + +@requires_triton_cuda +class TestTritonBatchInvariance: + """Triton kernel must be bitwise batch-invariant.""" + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) 
+ return TritonBatchInvariantLogpOp() + + def test_batch_size_1_vs_n(self): + op = self._get_op() + row = _make_row(42, device="cuda") + target = torch.tensor([7], device="cuda") + result_alone = op(row, target).item() + + for batch_size in [2, 4, 8, 16, 32, 64, 128]: + batch_logits = torch.randn(batch_size, _V, device="cuda") + batch_target = torch.randint(0, _V, (batch_size,), device="cuda") + batch_logits[0] = row.squeeze(0) + batch_target[0] = target.squeeze(0) + result_in_batch = op(batch_logits, batch_target)[0].item() + assert result_alone == result_in_batch, ( + f"Triton drift at batch_size={batch_size}: " + f"alone={result_alone}, in_batch={result_in_batch}" + ) + + def test_different_positions(self): + op = self._get_op() + row = _make_row(99, device="cuda") + target = torch.tensor([13], device="cuda") + batch_size = 16 + results = [] + for pos in range(batch_size): + batch_logits = torch.randn(batch_size, _V, device="cuda") + batch_target = torch.randint(0, _V, (batch_size,), device="cuda") + batch_logits[pos] = row.squeeze(0) + batch_target[pos] = target.squeeze(0) + results.append(op(batch_logits, batch_target)[pos].item()) + assert all(r == results[0] for r in results), ( + f"Triton position drift: unique = {set(results)}" + ) + + def test_repeated_runs(self): + op = self._get_op() + logits = torch.randn(16, _V, device="cuda") + target = torch.randint(0, _V, (16,), device="cuda") + results = [op(logits, target) for _ in range(50)] + for i, r in enumerate(results[1:], 1): + assert torch.equal(r, results[0]), f"Triton run {i} differs from run 0" + + def test_mixed_batch_content(self): + op = self._get_op() + row = _make_row(77, device="cuda") + target = torch.tensor([25], device="cuda") + batch_size = 8 + results = [] + for trial_seed in range(20): + torch.manual_seed(trial_seed * 1000) + batch_logits = torch.randn(batch_size, _V, device="cuda") + batch_target = torch.randint(0, _V, (batch_size,), device="cuda") + batch_logits[3] = row.squeeze(0) + 
batch_target[3] = target.squeeze(0) + results.append(op(batch_logits, batch_target)[3].item()) + assert all(r == results[0] for r in results), ( + f"Triton mixed-batch drift: unique = {set(results)}" + ) + + +# --------------------------------------------------------------------------- +# Backward / gradient +# --------------------------------------------------------------------------- + + +@requires_triton_cuda +class TestTritonBackward: + """Gradient through the Triton op must match reference.""" + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_backward_matches_reference(self): + op = self._get_op() + logits = torch.randn(4, _V, device="cuda", requires_grad=True) + target = torch.randint(0, _V, (4,), device="cuda") + + out = op(logits, target).sum() + out.backward() + grad = logits.grad.detach().clone() + + ref_logits = logits.detach().clone().requires_grad_(True) + ref = _reference_logp(ref_logits, target).sum() + ref.backward() + + assert torch.allclose(grad, ref_logits.grad, atol=1e-5) + + def test_gradient_batch_invariance(self): + op = self._get_op() + row = _make_row(42, device="cuda") + target = torch.tensor([7], device="cuda") + + logits_alone = row.clone().requires_grad_(True) + op(logits_alone, target).sum().backward() + grad_alone = logits_alone.grad.detach().clone() + + for batch_size in [4, 16, 64]: + batch_logits = torch.randn(batch_size, _V, device="cuda") + batch_logits[0] = row.squeeze(0) + batch_logits.requires_grad_(True) + batch_target = torch.randint(0, _V, (batch_size,), device="cuda") + batch_target[0] = target.squeeze(0) + op(batch_logits, batch_target).sum().backward() + grad_in_batch = batch_logits.grad[0:1].detach().clone() + assert torch.allclose(grad_alone, grad_in_batch, atol=1e-5), ( + f"Triton gradient drift at batch_size={batch_size}" + ) + + def test_ignored_row_grad_is_zero(self): + """Ignored rows 
must have zero gradient across the entire vocab.""" + op = self._get_op() + logits = torch.randn(4, _V, device="cuda", requires_grad=True) + target = torch.tensor([0, -100, 2, -100], device="cuda") + op(logits, target).sum().backward() + assert torch.equal(logits.grad[1], torch.zeros(_V, device="cuda")) + assert torch.equal(logits.grad[3], torch.zeros(_V, device="cuda")) + + def test_backward_bf16_input(self): + """Backward with bf16 logits should match fp32 reference within tolerance.""" + op = self._get_op() + logits = torch.randn(8, _V, device="cuda", dtype=torch.bfloat16, requires_grad=True) + target = torch.randint(0, _V, (8,), device="cuda") + + op(logits, target).sum().backward() + grad = logits.grad.detach().clone() + + ref_logits = logits.detach().float().requires_grad_(True) + _reference_logp(ref_logits, target).sum().backward() + + assert torch.allclose(grad.float(), ref_logits.grad, atol=1e-2) + + def test_backward_fp16_input(self): + """Backward with fp16 logits should match fp32 reference within tolerance.""" + op = self._get_op() + logits = torch.randn(8, _V, device="cuda", dtype=torch.float16, requires_grad=True) + target = torch.randint(0, _V, (8,), device="cuda") + + op(logits, target).sum().backward() + grad = logits.grad.detach().clone() + + ref_logits = logits.detach().float().requires_grad_(True) + _reference_logp(ref_logits, target).sum().backward() + + assert torch.allclose(grad.float(), ref_logits.grad, atol=1e-2) + + +# --------------------------------------------------------------------------- +# ignore_index handling +# --------------------------------------------------------------------------- + + +@requires_triton_cuda +class TestTritonIgnoreIndex: + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_ignore_outputs_zero(self): + op = self._get_op() + logits = torch.randn(4, _V, device="cuda") + target = 
torch.tensor([0, -100, 2, -100], device="cuda") + out = op(logits, target) + assert out[1].item() == 0.0 + assert out[3].item() == 0.0 + ref = _reference_logp(logits[[0, 2]], target[[0, 2]]) + assert torch.allclose(out[[0, 2]], ref, atol=1e-5) + + def test_all_ignore(self): + op = self._get_op() + logits = torch.randn(4, _V, device="cuda") + target = torch.full((4,), -100, device="cuda") + out = op(logits, target) + assert torch.equal(out, torch.zeros(4, device="cuda")) + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +@requires_triton_cuda +class TestTritonValidation: + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + + def test_rejects_cpu_tensor(self): + op = self._get_op() + with pytest.raises(RuntimeError, match="requires a GPU"): + op(torch.randn(4, _V), torch.randint(0, _V, (4,))) + + def test_rejects_1d_logits(self): + op = self._get_op() + with pytest.raises(ValueError, match="at least 2-D"): + op(torch.randn(10, device="cuda"), torch.tensor([0], device="cuda")) + + def test_rejects_invalid_target(self): + op = self._get_op() + logits = torch.randn(4, _V, device="cuda") + target = torch.tensor([0, -1, 2, 3], device="cuda") + with pytest.raises(ValueError, match="outside"): + op(logits, target, validate=True) + + # --------------------------------------------------------------------------- -# 7. Registry dispatch test +# 8. 
Registry dispatch test # --------------------------------------------------------------------------- diff --git a/tests/test_triton_batch_invariant_logp.py b/tests/test_triton_batch_invariant_logp.py deleted file mode 100644 index 064acd9..0000000 --- a/tests/test_triton_batch_invariant_logp.py +++ /dev/null @@ -1,352 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2026 RL-Kernel Contributors - -"""Tests for the Triton batch-invariant selected-logprob kernel (issue #148). - -These tests validate that the Triton kernel produces results that: -1. Match the PyTorch reference implementation (correctness). -2. Are bitwise identical across different batch sizes / positions (batch-invariance). -3. Support backward pass (gradient correctness). - -All tests are skipped when Triton or CUDA is unavailable (e.g. on Windows or CPU-only). -""" - -import pytest -import torch - -try: - import triton # noqa: F401 - - _HAS_TRITON = True -except ImportError: - _HAS_TRITON = False - -requires_triton_cuda = pytest.mark.skipif( - not (_HAS_TRITON and torch.cuda.is_available()), - reason="Triton batch-invariant logp requires CUDA device and Triton.", -) - -_V = 300 - - -def _reference_logp(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor: - """Canonical reference: log_softmax(fp32) + gather.""" - logits_2d = logits.reshape(-1, logits.size(-1)).float() - target_1d = target_ids.reshape(-1).long() - log_probs = torch.log_softmax(logits_2d, dim=-1) - selected = torch.gather(log_probs, dim=-1, index=target_1d.unsqueeze(1)).squeeze(1) - return selected.reshape(target_ids.shape) - - -def _make_row(seed: int, vocab: int = _V, device: str = "cuda") -> torch.Tensor: - """Generate a single deterministic logit row from a seed.""" - gen = torch.Generator(device=device).manual_seed(seed) - return torch.randn(1, vocab, generator=gen, device=device) - - -# --------------------------------------------------------------------------- -# 1. 
Correctness: Triton vs reference -# --------------------------------------------------------------------------- - - -@requires_triton_cuda -class TestTritonCorrectness: - """Triton kernel output must match log_softmax + gather reference.""" - - def _get_op(self): - from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( - TritonBatchInvariantLogpOp, - ) - return TritonBatchInvariantLogpOp() - - def test_matches_reference_fp32(self): - op = self._get_op() - logits = torch.randn(8, _V, device="cuda") - target = torch.randint(0, _V, (8,), device="cuda") - out = op(logits, target) - ref = _reference_logp(logits, target) - assert out.dtype == torch.float32 - assert torch.allclose(out, ref, atol=1e-5) - - def test_matches_reference_bf16(self): - op = self._get_op() - logits = torch.randn(8, _V, device="cuda", dtype=torch.bfloat16) - target = torch.randint(0, _V, (8,), device="cuda") - out = op(logits, target) - ref = _reference_logp(logits.float(), target) - assert out.dtype == torch.float32 - assert torch.allclose(out, ref, atol=1e-4) - - def test_matches_reference_fp16(self): - op = self._get_op() - logits = torch.randn(8, _V, device="cuda", dtype=torch.float16) - target = torch.randint(0, _V, (8,), device="cuda") - out = op(logits, target) - ref = _reference_logp(logits.float(), target) - assert out.dtype == torch.float32 - assert torch.allclose(out, ref, atol=1e-4) - - def test_large_vocab(self): - op = self._get_op() - logits = torch.randn(4, 128256, device="cuda") - target = torch.randint(0, 128256, (4,), device="cuda") - out = op(logits, target) - ref = _reference_logp(logits, target) - assert torch.allclose(out, ref, atol=1e-5) - - def test_single_token(self): - op = self._get_op() - logits = torch.randn(1, _V, device="cuda") - target = torch.randint(0, _V, (1,), device="cuda") - out = op(logits, target) - ref = _reference_logp(logits, target) - assert torch.allclose(out, ref, atol=1e-5) - - def test_3d_logits(self): - op = self._get_op() - logits = 
torch.randn(2, 3, _V, device="cuda") - target = torch.randint(0, _V, (2, 3), device="cuda") - out = op(logits, target) - assert out.shape == (2, 3) - ref = _reference_logp(logits, target) - assert torch.allclose(out, ref, atol=1e-5) - - def test_matches_pytorch_op(self): - """Triton and PyTorch ops should agree within tolerance.""" - from rl_engine.kernels.ops.pytorch.loss.batch_invariant_logp import ( - NativeBatchInvariantLogpOp, - ) - triton_op = self._get_op() - pytorch_op = NativeBatchInvariantLogpOp() - logits = torch.randn(16, _V, device="cuda") - target = torch.randint(0, _V, (16,), device="cuda") - out_triton = triton_op(logits, target) - out_pytorch = pytorch_op(logits, target) - assert torch.allclose(out_triton, out_pytorch, atol=1e-5) - - -# --------------------------------------------------------------------------- -# 2. Batch-invariance on GPU via Triton -# --------------------------------------------------------------------------- - - -@requires_triton_cuda -class TestTritonBatchInvariance: - """Triton kernel must be bitwise batch-invariant.""" - - def _get_op(self): - from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( - TritonBatchInvariantLogpOp, - ) - return TritonBatchInvariantLogpOp() - - def test_batch_size_1_vs_n(self): - op = self._get_op() - row = _make_row(42) - target = torch.tensor([7], device="cuda") - result_alone = op(row, target).item() - - for batch_size in [2, 4, 8, 16, 32, 64, 128]: - batch_logits = torch.randn(batch_size, _V, device="cuda") - batch_target = torch.randint(0, _V, (batch_size,), device="cuda") - batch_logits[0] = row.squeeze(0) - batch_target[0] = target.squeeze(0) - result_in_batch = op(batch_logits, batch_target)[0].item() - assert result_alone == result_in_batch, ( - f"Triton drift at batch_size={batch_size}: " - f"alone={result_alone}, in_batch={result_in_batch}" - ) - - def test_different_positions(self): - op = self._get_op() - row = _make_row(99) - target = torch.tensor([13], device="cuda") - 
batch_size = 16 - results = [] - for pos in range(batch_size): - batch_logits = torch.randn(batch_size, _V, device="cuda") - batch_target = torch.randint(0, _V, (batch_size,), device="cuda") - batch_logits[pos] = row.squeeze(0) - batch_target[pos] = target.squeeze(0) - results.append(op(batch_logits, batch_target)[pos].item()) - assert all(r == results[0] for r in results), ( - f"Triton position drift: unique = {set(results)}" - ) - - def test_repeated_runs(self): - op = self._get_op() - logits = torch.randn(16, _V, device="cuda") - target = torch.randint(0, _V, (16,), device="cuda") - results = [op(logits, target) for _ in range(50)] - for i, r in enumerate(results[1:], 1): - assert torch.equal(r, results[0]), f"Triton run {i} differs from run 0" - - def test_mixed_batch_content(self): - op = self._get_op() - row = _make_row(77) - target = torch.tensor([25], device="cuda") - batch_size = 8 - results = [] - for trial_seed in range(20): - torch.manual_seed(trial_seed * 1000) - batch_logits = torch.randn(batch_size, _V, device="cuda") - batch_target = torch.randint(0, _V, (batch_size,), device="cuda") - batch_logits[3] = row.squeeze(0) - batch_target[3] = target.squeeze(0) - results.append(op(batch_logits, batch_target)[3].item()) - assert all(r == results[0] for r in results), ( - f"Triton mixed-batch drift: unique = {set(results)}" - ) - - -# --------------------------------------------------------------------------- -# 3. 
Backward / gradient -# --------------------------------------------------------------------------- - - -@requires_triton_cuda -class TestTritonBackward: - """Gradient through the Triton op must match reference.""" - - def _get_op(self): - from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( - TritonBatchInvariantLogpOp, - ) - return TritonBatchInvariantLogpOp() - - def test_backward_matches_reference(self): - op = self._get_op() - logits = torch.randn(4, _V, device="cuda", requires_grad=True) - target = torch.randint(0, _V, (4,), device="cuda") - - out = op(logits, target).sum() - out.backward() - grad = logits.grad.detach().clone() - - ref_logits = logits.detach().clone().requires_grad_(True) - ref = _reference_logp(ref_logits, target).sum() - ref.backward() - - assert torch.allclose(grad, ref_logits.grad, atol=1e-5) - - def test_gradient_batch_invariance(self): - op = self._get_op() - row = _make_row(42) - target = torch.tensor([7], device="cuda") - - logits_alone = row.clone().requires_grad_(True) - op(logits_alone, target).sum().backward() - grad_alone = logits_alone.grad.detach().clone() - - for batch_size in [4, 16, 64]: - batch_logits = torch.randn(batch_size, _V, device="cuda") - batch_logits[0] = row.squeeze(0) - batch_logits.requires_grad_(True) - batch_target = torch.randint(0, _V, (batch_size,), device="cuda") - batch_target[0] = target.squeeze(0) - op(batch_logits, batch_target).sum().backward() - grad_in_batch = batch_logits.grad[0:1].detach().clone() - assert torch.allclose(grad_alone, grad_in_batch, atol=1e-5), ( - f"Triton gradient drift at batch_size={batch_size}" - ) - - def test_ignored_row_grad_is_zero(self): - """Ignored rows must have zero gradient across the entire vocab.""" - op = self._get_op() - logits = torch.randn(4, _V, device="cuda", requires_grad=True) - target = torch.tensor([0, -100, 2, -100], device="cuda") - op(logits, target).sum().backward() - assert torch.equal(logits.grad[1], torch.zeros(_V, device="cuda")) - 
assert torch.equal(logits.grad[3], torch.zeros(_V, device="cuda")) - - def test_backward_bf16_input(self): - """Backward with bf16 logits should match fp32 reference within tolerance.""" - op = self._get_op() - logits = torch.randn(8, _V, device="cuda", dtype=torch.bfloat16, requires_grad=True) - target = torch.randint(0, _V, (8,), device="cuda") - - op(logits, target).sum().backward() - grad = logits.grad.detach().clone() - - ref_logits = logits.detach().float().requires_grad_(True) - _reference_logp(ref_logits, target).sum().backward() - - assert torch.allclose(grad.float(), ref_logits.grad, atol=1e-2) - - def test_backward_fp16_input(self): - """Backward with fp16 logits should match fp32 reference within tolerance.""" - op = self._get_op() - logits = torch.randn(8, _V, device="cuda", dtype=torch.float16, requires_grad=True) - target = torch.randint(0, _V, (8,), device="cuda") - - op(logits, target).sum().backward() - grad = logits.grad.detach().clone() - - ref_logits = logits.detach().float().requires_grad_(True) - _reference_logp(ref_logits, target).sum().backward() - - assert torch.allclose(grad.float(), ref_logits.grad, atol=1e-2) - - -# --------------------------------------------------------------------------- -# 4. 
ignore_index handling -# --------------------------------------------------------------------------- - - -@requires_triton_cuda -class TestTritonIgnoreIndex: - - def _get_op(self): - from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( - TritonBatchInvariantLogpOp, - ) - return TritonBatchInvariantLogpOp() - - def test_ignore_outputs_zero(self): - op = self._get_op() - logits = torch.randn(4, _V, device="cuda") - target = torch.tensor([0, -100, 2, -100], device="cuda") - out = op(logits, target) - assert out[1].item() == 0.0 - assert out[3].item() == 0.0 - ref = _reference_logp(logits[[0, 2]], target[[0, 2]]) - assert torch.allclose(out[[0, 2]], ref, atol=1e-5) - - def test_all_ignore(self): - op = self._get_op() - logits = torch.randn(4, _V, device="cuda") - target = torch.full((4,), -100, device="cuda") - out = op(logits, target) - assert torch.equal(out, torch.zeros(4, device="cuda")) - - -# --------------------------------------------------------------------------- -# 5. 
Validation -# --------------------------------------------------------------------------- - - -@requires_triton_cuda -class TestTritonValidation: - - def _get_op(self): - from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( - TritonBatchInvariantLogpOp, - ) - return TritonBatchInvariantLogpOp() - - def test_rejects_cpu_tensor(self): - op = self._get_op() - with pytest.raises(RuntimeError, match="requires a GPU"): - op(torch.randn(4, _V), torch.randint(0, _V, (4,))) - - def test_rejects_1d_logits(self): - op = self._get_op() - with pytest.raises(ValueError, match="at least 2-D"): - op(torch.randn(10, device="cuda"), torch.tensor([0], device="cuda")) - - def test_rejects_invalid_target(self): - op = self._get_op() - logits = torch.randn(4, _V, device="cuda") - target = torch.tensor([0, -1, 2, 3], device="cuda") - with pytest.raises(ValueError, match="outside"): - op(logits, target) From 60697a518fffdaf2649a7ffcd763ee295c1b0a43 Mon Sep 17 00:00:00 2001 From: hihaluemen <1596916766@qq.com> Date: Mon, 29 Jun 2026 07:55:08 +0800 Subject: [PATCH 3/4] [REVIEW][kernels] tighten batch-invariant logprob validation --- docs/operators/batch-invariant-logp.md | 14 +++++------ .../ops/pytorch/loss/batch_invariant_logp.py | 6 ++--- tests/test_batch_invariant_logp.py | 24 +++++++++++++++---- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/docs/operators/batch-invariant-logp.md b/docs/operators/batch-invariant-logp.md index e6c2702..b05f85a 100644 --- a/docs/operators/batch-invariant-logp.md +++ b/docs/operators/batch-invariant-logp.md @@ -27,7 +27,7 @@ logp = batch_invariant_logp( logits, # [B, T, V] or [N, V], differentiable target_ids, # [B, T] or [N], int ignore_index=-100, - validate=False, # opt-in target range check (syncs CUDA stream) + validate=False, # Triton fast path; use True to debug-check target range ) # -> [B, T] or [N], float32 logp.sum().backward() # gradients flow into logits only @@ -89,13 +89,13 @@ out[row] = 0.0 grad_logits[row, 
:] = 0.0 ``` -Non-ignored target ids outside `[0, V)` raise `ValueError` when -`validate=True`. In particular, `target=-1` is invalid unless -`ignore_index=-1`. +Non-ignored target ids outside `[0, V)` are invalid. In particular, +`target=-1` is invalid unless `ignore_index=-1`. -`validate=False` (default) skips the target range check to avoid CUDA stream -synchronization in training hot paths. Use `validate=True` during debugging or -in tests. +The PyTorch native backend validates target ranges by default. The Triton +backend defaults to `validate=False` to avoid CUDA stream synchronization in +training hot paths. Use `validate=True` during debugging or in tests when +calling the Triton backend with untrusted targets. ## Batch-Invariance diff --git a/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py b/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py index ef9270b..043aa71 100644 --- a/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py +++ b/rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py @@ -25,7 +25,7 @@ def __call__( target_ids: torch.Tensor, ignore_index: int = -100, *, - validate: bool = False, + validate: bool = True, ) -> torch.Tensor: return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate) @@ -35,7 +35,7 @@ def apply( target_ids: torch.Tensor, ignore_index: int = -100, *, - validate: bool = False, + validate: bool = True, ) -> torch.Tensor: self._validate_shapes(logits, target_ids) @@ -60,7 +60,7 @@ def _row_wise_selected_logprob( target_1d: torch.Tensor, *, ignore_index: int, - validate: bool = False, + validate: bool = True, ) -> torch.Tensor: """Per-row selected logprob with locked reduction order. 
diff --git a/tests/test_batch_invariant_logp.py b/tests/test_batch_invariant_logp.py index d747e42..8bca4e8 100644 --- a/tests/test_batch_invariant_logp.py +++ b/tests/test_batch_invariant_logp.py @@ -36,6 +36,11 @@ reason="Triton batch-invariant logp requires CUDA device and Triton.", ) +requires_triton = pytest.mark.skipif( + not _HAS_TRITON, + reason="Triton package required.", +) + def _reference_logp(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor: """Canonical reference: log_softmax(fp32) + gather.""" @@ -278,14 +283,14 @@ def test_rejects_negative_target(self): logits = torch.randn(4, _V) target = torch.tensor([0, -1, 2, 3]) with pytest.raises(ValueError, match="outside"): - op(logits, target, validate=True) + op(logits, target) def test_rejects_target_ge_vocab(self): op = NativeBatchInvariantLogpOp() logits = torch.randn(4, _V) target = torch.tensor([0, 1, _V, 3]) with pytest.raises(ValueError, match="outside"): - op(logits, target, validate=True) + op(logits, target) def test_negative_target_with_ignore_index_ok(self): op = NativeBatchInvariantLogpOp() @@ -755,8 +760,9 @@ def test_all_ignore(self): # --------------------------------------------------------------------------- -@requires_triton_cuda -class TestTritonValidation: +@requires_triton +class TestTritonCPUValidation: + """Tests that only need Triton importable, not a GPU.""" def _get_op(self): from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( @@ -769,6 +775,16 @@ def test_rejects_cpu_tensor(self): with pytest.raises(RuntimeError, match="requires a GPU"): op(torch.randn(4, _V), torch.randint(0, _V, (4,))) + +@requires_triton_cuda +class TestTritonValidation: + + def _get_op(self): + from rl_engine.kernels.ops.triton.loss.batch_invariant_logp import ( + TritonBatchInvariantLogpOp, + ) + return TritonBatchInvariantLogpOp() + def test_rejects_1d_logits(self): op = self._get_op() with pytest.raises(ValueError, match="at least 2-D"): From 
c31cc31a0346a20a11b20c9be48fc69d497e0d00 Mon Sep 17 00:00:00 2001 From: hihaluemen <1596916766@qq.com> Date: Mon, 29 Jun 2026 22:46:37 +0800 Subject: [PATCH 4/4] [REVIEW][kernels] optimize batch-invariant logprob backward storage --- rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py b/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py index f32017c..12e3472 100644 --- a/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py +++ b/rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py @@ -160,7 +160,7 @@ def backward(ctx, grad_output): num_tokens = logits_2d.shape[0] grad_flat = grad_output.reshape(-1).contiguous().to(torch.float32) - grad_logits = torch.empty_like(logits_2d, dtype=torch.float32) + grad_logits = torch.empty_like(logits_2d) grid = (num_tokens, triton.cdiv(vocab_size, _BLOCK_V)) _batch_invariant_logp_bwd_kernel[grid]( @@ -176,9 +176,7 @@ def backward(ctx, grad_output): BLOCK_V=_BLOCK_V, ) - grad_logits = grad_logits.to(logits_2d.dtype).reshape( - ctx.lead_shape + (vocab_size,) - ) + grad_logits = grad_logits.reshape(ctx.lead_shape + (vocab_size,)) return grad_logits, None, None