3 changes: 3 additions & 0 deletions .gitignore
@@ -206,3 +206,6 @@ cython_debug/
marimo/_static/
marimo/_lsp/
__marimo__/

# Local dev notes (not for upstream)
_dev_notes/
1 change: 1 addition & 0 deletions docs/.nav.yml
@@ -13,6 +13,7 @@ nav:
- operators/README.md
- operators/fused-logp.md
- operators/linear-logp.md
- operators/batch-invariant-logp.md
- operators/grpo-loss.md
- operators/ratio-kl.md
- operators/sampling.md
1 change: 1 addition & 0 deletions docs/operators/README.md
@@ -20,6 +20,7 @@ Every operator page should include:

- [Fused LogP](fused-logp.md)
- [Fused Linear LogP](linear-logp.md)
- [Batch-Invariant LogP](batch-invariant-logp.md)
- [GRPO Loss](grpo-loss.md)
- [Policy Ratio + KL Penalty](ratio-kl.md)
- [Sampling](sampling.md)
174 changes: 174 additions & 0 deletions docs/operators/batch-invariant-logp.md
@@ -0,0 +1,174 @@
# Batch-Invariant LogP

Batch-Invariant LogP computes selected token log-probabilities from already
materialized logits:

```text
out[row] = logits[row, target_ids[row]] - logsumexp(logits[row, :])
```

It targets RL post-training paths where policy log-probs are compared across
different packing, padding, and batch layouts. The key contract is
batch-invariance: for a fixed row of logits and target id, the result must not
change when that row is evaluated alone, at a different batch position, or with
different neighboring rows.

Unlike `linear_logp`, this operator does not fuse the LM-head projection. It
takes `[*, V]` logits as input and returns one selected log-probability per row.
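
As a rough illustration of the formula and the batch-invariance contract above, the following pure-Python sketch computes the selected log-probability for one row and checks that the row gives the same value alone and inside a larger batch. The helper names `selected_logp` and `batched_selected_logp` are hypothetical and are not part of `rl_engine`; the real operator works on tensors.

```python
import math


def selected_logp(row, target):
    """Return row[target] - logsumexp(row) for one row of logits."""
    row_max = max(row)  # shift by the row max for numerical stability
    sum_exp = sum(math.exp(x - row_max) for x in row)
    return row[target] - (math.log(sum_exp) + row_max)


def batched_selected_logp(rows, targets):
    """Evaluate each row independently; no reduction crosses rows."""
    return [selected_logp(r, t) for r, t in zip(rows, targets)]


row = [2.0, -1.0, 0.5, 3.0]
neighbors = [[0.0, 9.0, 1.0, 1.0], row, [5.0, 5.0, 5.0, 5.0]]

alone = batched_selected_logp([row], [3])[0]
in_batch = batched_selected_logp(neighbors, [1, 3, 0])[1]

# The row's result does not depend on its position or its neighbors.
assert alone == in_batch
```

Because each row goes through the same sequence of floating-point operations in both calls, the comparison can use exact equality rather than a tolerance.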

## Entry Point

```python
from rl_engine.kernels.registry import kernel_registry

batch_invariant_logp = kernel_registry.get_op("batch_invariant_logp")

logp = batch_invariant_logp(
    logits,  # [B, T, V] or [N, V], differentiable
    target_ids,  # [B, T] or [N], int
    ignore_index=-100,
    validate=False,  # Triton fast path; use True to debug-check target range
)  # -> [B, T] or [N], float32

logp.sum().backward() # gradients flow into logits only
```

## Backends

| Backend | Wrapper | Status |
| --- | --- | --- |
| CUDA / ROCm (Triton) | `TritonBatchInvariantLogpOp` | Triton online-softmax forward and tile-wise backward. Requires a GPU tensor. |
| PyTorch native | `NativeBatchInvariantLogpOp` | FP32 reference path; CPU fallback and Triton-less fallback. |

Current dispatch:

```text
CUDA / ROCm: Triton -> PyTorch
CPU: PyTorch
```

A compiled CUDA backend and a benchmark suite are planned as follow-up work.
Benchmarks are not part of this PR; they will land together with the CUDA
backend in a subsequent PR.

## Tensor Contract

| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `logits` | `[N, V]` / `[B, T, V]` / `[*lead, V]` | fp32 / fp16 / bf16 | Differentiable input; last dimension is vocab. |
| `target_ids` | `[N]` / `[B, T]` / `[*lead]` | int | Same leading shape as `logits`; non-ignored values in `[0, V)`. |
| `ignore_index` | scalar int | Python int | Default `-100`. Ignored rows output zero and receive zero gradient. |
| Output | `[N]` / `[B, T]` / `[*lead]` | float32 | Selected log-probability per row. |

`target_ids` is integer and non-differentiable. Gradients flow only into
`logits`.

## Reference Semantics

For non-ignored rows:

```python
logits_2d = logits.reshape(-1, logits.size(-1)).float()
target_1d = target_ids.reshape(-1).long()

log_probs = torch.log_softmax(logits_2d, dim=-1)
selected = torch.gather(
    log_probs,
    dim=-1,
    index=target_1d.unsqueeze(-1),
).squeeze(-1)

out = selected.reshape(target_ids.shape)
```

For ignored rows:

```text
target_ids[row] == ignore_index
out[row] = 0.0
grad_logits[row, :] = 0.0
```
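
The two cases can be combined into one reference function. The sketch below is a minimal illustration, assuming ignored targets are swapped for a safe index before the gather and the output is masked afterwards; `reference_logp` is a hypothetical name, not the backend's actual implementation.

```python
import torch


def reference_logp(logits, target_ids, ignore_index=-100):
    """Hypothetical reference: selected log-prob with ignored rows zeroed."""
    logits_2d = logits.reshape(-1, logits.size(-1)).float()
    target_1d = target_ids.reshape(-1).long()

    valid = target_1d != ignore_index
    # Replace ignored targets with index 0 so the gather stays in range.
    safe_target = torch.where(valid, target_1d, torch.zeros_like(target_1d))

    log_probs = torch.log_softmax(logits_2d, dim=-1)
    selected = log_probs.gather(-1, safe_target.unsqueeze(-1)).squeeze(-1)
    # Masking the output also blocks gradient flow into ignored rows.
    selected = torch.where(valid, selected, torch.zeros_like(selected))
    return selected.reshape(target_ids.shape)


logits = torch.randn(3, 5, requires_grad=True)
target_ids = torch.tensor([2, -100, 4])

out = reference_logp(logits, target_ids)
out.sum().backward()

assert out[1].item() == 0.0            # ignored row outputs zero
assert torch.all(logits.grad[1] == 0)  # and receives zero gradient
```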

Non-ignored target ids outside `[0, V)` are invalid. In particular,
`target=-1` is invalid unless `ignore_index=-1`.

The PyTorch native backend validates target ranges by default. The Triton
backend defaults to `validate=False` to avoid CUDA stream synchronization in
training hot paths. Use `validate=True` during debugging or in tests when
calling the Triton backend with untrusted targets.

## Batch-Invariance

The operator is designed so each row is computed independently:

- The PyTorch path reshapes to `[N, V]` and applies row-wise reductions.
- The Triton forward uses `grid=(num_tokens,)`, so one program owns exactly one
row.
- Triton vocab traversal uses a fixed `_BLOCK_V=1024` and does not autotune by
batch size.
- Triton forward scans vocab tiles left-to-right using online logsumexp.
- Triton backward uses `grid=(num_tokens, vocab_tiles)` and writes one row tile
per program. It reuses the forward-saved per-row `lse`, so no backward
reduction crosses row boundaries.
- No atomic writes are used.

These constraints ensure the result for a row depends only on that row's logits
and target id, not on batch size, row position, or neighboring rows.
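
The left-to-right tile scan can be sketched in pure Python. This is an illustrative stand-in for the online logsumexp idea, not the Triton kernel: `online_logsumexp` is a hypothetical helper, and the small `block_v=4` is chosen only to keep the demo readable (the text above states the kernel uses a fixed `_BLOCK_V=1024`).

```python
import math


def online_logsumexp(row, block_v=4):
    """Scan fixed-size vocab tiles left to right with a running max and sum."""
    running_max = float("-inf")
    running_sum = 0.0
    for start in range(0, len(row), block_v):
        tile = row[start:start + block_v]
        new_max = max(running_max, max(tile))
        # Rescale the old sum to the new max, then add this tile's terms.
        running_sum = running_sum * math.exp(running_max - new_max)
        running_sum += sum(math.exp(x - new_max) for x in tile)
        running_max = new_max
    return running_max + math.log(running_sum)


row = [0.5, -2.0, 3.0, 1.0, 7.5, -1.0, 2.0, 0.0, 4.0]
direct = math.log(sum(math.exp(x) for x in row))

# The tiled scan agrees with the direct formula up to rounding.
assert abs(online_logsumexp(row) - direct) < 1e-9
```

Because the tile size and scan order are fixed per row, a given row always passes through the same sequence of floating-point operations, whatever else is in the batch. Changing the tile size can change the rounding, which is one reason to avoid tuning it by batch size.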

## Accuracy

Both backends accumulate reductions in float32 and return float32 outputs. Tests
compare against `torch.log_softmax(...).gather(...)` with dtype-appropriate
tolerances:

```text
fp32 forward: atol around 1e-5
fp16/bf16 forward: atol around 1e-4
fp16/bf16 backward: checked against fp32 reference with relaxed tolerance
```

CPU-vs-CUDA comparisons use tolerance-based checks; batch-invariance checks
within the same backend use exact equality where appropriate.

## Minimal Example

```python
import torch

from rl_engine.kernels.registry import kernel_registry

op = kernel_registry.get_op("batch_invariant_logp")

logits = torch.randn(2, 4, 300, device="cuda", dtype=torch.bfloat16)
target_ids = torch.randint(0, 300, (2, 4), device="cuda")
target_ids[0, 0] = -100

out = op(logits, target_ids, ignore_index=-100)
assert out.shape == target_ids.shape
assert out.dtype == torch.float32
assert out[0, 0].item() == 0.0

out.sum().backward()
```

## Tests

Collaborator

Please add a note that this kernel currently lacks benchmarks; I will add them in another PR, together with the CUDA version.

Author

Addressed, thanks.
Changes:

  • Renamed PR scope/title to Batch-invariant logprob (Native, Triton); CUDA + benchmarks will be handled in a follow-up PR.
  • Made target range validation opt-in via validate=False by default to avoid CUDA stream sync in the Triton hot path.
  • Merged Native and Triton tests into tests/test_batch_invariant_logp.py.
  • Updated operator docs to mention validate=True and benchmark/CUDA follow-up.


```bash
python -m pytest tests/test_batch_invariant_logp.py -q -rs
```

All backends (Native, Triton) are tested in a single file. Shared coverage
includes correctness, leading-shape preservation, bitwise batch-invariance,
validation, ignore-index behavior, backward correctness, CUDA smoke cases, and
registry dispatch. Triton-specific coverage adds fp32/fp16/bf16 correctness,
large vocabularies, batch-invariance of backward gradients, and zero gradients
for ignored rows.

Triton tests skip when Triton or CUDA is unavailable. On Windows, run via
WSL/Linux with CUDA.

## Implementation Files

- `rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py`
- `rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py`
- `rl_engine/kernels/registry.py`
- `tests/test_batch_invariant_logp.py`
124 changes: 124 additions & 0 deletions rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py
@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

from __future__ import annotations

import torch


class NativeBatchInvariantLogpOp:
    """Batch-invariant selected-token log-probability.

    ``selected_logprob[t] = logits[t, target_ids[t]] - logsumexp(logits[t, :])``

    All reductions run in FP32. The row-wise max -> subtract -> exp -> sum -> log
    pipeline is fully independent per row, so the result for any row depends
    only on that row's logits and target - never on batch size or layout.
    """

    def __init__(self) -> None:
        pass

    def __call__(
        self,
        logits: torch.Tensor,
        target_ids: torch.Tensor,
        ignore_index: int = -100,
        *,
        validate: bool = True,
    ) -> torch.Tensor:
        return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate)

    def apply(
        self,
        logits: torch.Tensor,
        target_ids: torch.Tensor,
        ignore_index: int = -100,
        *,
        validate: bool = True,
    ) -> torch.Tensor:
        self._validate_shapes(logits, target_ids)

        lead_shape = logits.shape[:-1]
        vocab_size = logits.size(-1)

        logits_2d = logits.reshape(-1, vocab_size).float()
        target_1d = target_ids.reshape(-1).to(logits.device, dtype=torch.long)
Comment on lines +45 to +46

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '\n== file outline ==\n'
ast-grep outline rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py --view expanded || true

printf '\n== file contents (relevant slice) ==\n'
sed -n '1,220p' rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py

printf '\n== related tests search ==\n'
rg -n "batch_invariant_logp|NativeBatchInvariantLogpOp|target_ids|ignore_index|dtype" rl_engine -g '*test*' -g '*spec*' || true

printf '\n== pytorch casting probe ==\n'
python3 - <<'PY'
import torch

cases = [
    torch.tensor([1.9, -0.2, 3.7], dtype=torch.float32),
    torch.tensor([True, False, True], dtype=torch.bool),
    torch.tensor([1, 2, 3], dtype=torch.int32),
]
for t in cases:
    out = t.to(dtype=torch.long)
    print(t.dtype, '->', out.dtype, out.tolist())
PY

Repository: RL-Align/RL-Kernel

Length of output: 5090


🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '\n== call sites ==\n'
rg -n "NativeBatchInvariantLogpOp|batch_invariant_logp|selected_logprob|target_ids" rl_engine -g '!**/*.pyc' || true

printf '\n== tests in loss area ==\n'
fd -a "test*.py" rl_engine/tests rl_engine/kernels 2>/dev/null || true

printf '\n== nearby test files mentioning logprob ==\n'
rg -n "logprob|batch_invariant|target_ids|ignore_index" rl_engine/tests rl_engine/kernels/tests -g '*test*.py' || true

Repository: RL-Align/RL-Kernel

Length of output: 9236


Reject non-integer `target_ids` before the cast. `.to(dtype=torch.long)` silently converts float and bool tensors (floats are truncated toward zero, bools become 0/1), so invalid inputs can pick the wrong class instead of failing.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py` around lines 41 -
42, The target conversion in batch_invariant_logp is too permissive because
target_ids is cast to torch.long after reshaping, which can silently truncate
float or bool inputs. Add an explicit validation step in the
batch_invariant_logp path to require target_ids to already be an integer/long
tensor before any cast, and fail fast for non-integer inputs. Keep the reshape
logic for logits_2d and target_1d, but ensure the check happens before
target_ids is converted or used.
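
One possible shape for such a check is sketched below. It is an assumption-laden illustration rather than the fix adopted in this PR: `check_target_dtype` and `_INTEGER_DTYPES` are hypothetical names, and the set of accepted dtypes is a guess at what the operator would want to allow.

```python
import torch

# Assumption: these dtypes count as valid integer target ids.
_INTEGER_DTYPES = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)


def check_target_dtype(target_ids: torch.Tensor) -> None:
    """Hypothetical helper: fail fast on float or bool targets before casting."""
    if target_ids.dtype not in _INTEGER_DTYPES:
        raise TypeError(
            f"target_ids must be an integer tensor, got dtype {target_ids.dtype}"
        )


check_target_dtype(torch.tensor([1, 2, 3], dtype=torch.int32))  # accepted

rejected = []
for bad in (torch.tensor([1.9, -0.2]), torch.tensor([True, False])):
    try:
        check_target_dtype(bad)
    except TypeError:
        rejected.append(bad.dtype)

# Float and bool targets are rejected instead of being truncated.
assert rejected == [torch.float32, torch.bool]
```

The check reads only the dtype, so it does not touch tensor data on the device.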


        selected_logp = self._row_wise_selected_logprob(
            logits_2d, target_1d, ignore_index=ignore_index, validate=validate
        )

        return selected_logp.reshape(lead_shape)

    # ---------------------------------------------------------------------- #
    # Core Computation
    # ---------------------------------------------------------------------- #
    @staticmethod
    def _row_wise_selected_logprob(
        logits_2d: torch.Tensor,
        target_1d: torch.Tensor,
        *,
        ignore_index: int,
        validate: bool = True,
    ) -> torch.Tensor:
        """Per-row selected logprob with locked reduction order.

        The three reduction steps (max, sum-exp, gather) operate on each row
        independently. PyTorch's ``max(dim=-1)`` and ``sum(dim=-1)`` iterate
        the vocab dimension in a fixed, deterministic order for a given row
        length, and that order does **not** change when more rows are added
        to the batch. This is the property that makes the op batch-invariant.

        Accumulation is done entirely in FP32 to avoid half-precision
        catastrophic cancellation during the ``exp(logit - max)`` step.
        """
        vocab_size = logits_2d.size(1)

        valid_mask = target_1d != ignore_index

        if validate:
            valid_targets = target_1d[valid_mask]
            if valid_targets.numel() > 0 and (
                (valid_targets < 0).any() or (valid_targets >= vocab_size).any()
            ):
                bad = valid_targets[(valid_targets < 0) | (valid_targets >= vocab_size)]
                raise ValueError(
                    f"target_ids contains values outside [0, {vocab_size}): {bad.tolist()}"
                )

        safe_target = target_1d.clone()
        safe_target[~valid_mask] = 0

        # logsumexp(z) = log(sum(exp(z - max(z)))) + max(z)
        row_max = logits_2d.max(dim=-1).values
        shifted = logits_2d - row_max.unsqueeze(-1)
        exp_shifted = shifted.exp()
        sum_exp = exp_shifted.sum(dim=-1)
        log_sum_exp = sum_exp.log() + row_max

        row_indices = torch.arange(logits_2d.size(0), device=logits_2d.device)
        selected_logit = logits_2d[row_indices, safe_target]

        selected_logp = selected_logit - log_sum_exp

        selected_logp = selected_logp.where(
            valid_mask, torch.zeros_like(selected_logp)
        )

        return selected_logp

    # ---------------------------------------------------------------------- #
    # Helper
    # ---------------------------------------------------------------------- #
    @staticmethod
    def _validate_shapes(logits: torch.Tensor, target_ids: torch.Tensor) -> None:
        if logits.dim() < 2:
            raise ValueError(
                f"logits must be at least 2-D ([*lead, V]), got shape {tuple(logits.shape)}"
            )
        if logits.shape[:-1] != target_ids.shape:
            raise ValueError(
                f"logits leading shape {tuple(logits.shape[:-1])} must match "
                f"target_ids shape {tuple(target_ids.shape)}"
            )