-
Notifications
You must be signed in to change notification settings - Fork 42
[WS1][kernels] Batch-invariant logprob (Native, Triton) #199
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d829be2
d748a8f
60697a5
c31cc31
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -206,3 +206,6 @@ cython_debug/ | |
| marimo/_static/ | ||
| marimo/_lsp/ | ||
| __marimo__/ | ||
|
|
||
| # Local dev notes (not for upstream) | ||
| _dev_notes/ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,174 @@ | ||
| # Batch-Invariant LogP | ||
|
|
||
| Batch-Invariant LogP computes selected token log-probabilities from already | ||
| materialized logits: | ||
|
|
||
| ```text | ||
| out[row] = logits[row, target_ids[row]] - logsumexp(logits[row, :]) | ||
| ``` | ||
|
|
||
| It targets RL post-training paths where policy log-probs are compared across | ||
| different packing, padding, and batch layouts. The key contract is | ||
| batch-invariance: for a fixed row of logits and target id, the result must not | ||
| change when that row is evaluated alone, at a different batch position, or with | ||
| different neighboring rows. | ||
|
|
||
| Unlike `linear_logp`, this operator does not fuse the LM-head projection. It | ||
| takes `[*, V]` logits as input and returns one selected log-probability per row. | ||
|
|
||
| ## Entry Point | ||
|
|
||
| ```python | ||
| from rl_engine.kernels.registry import kernel_registry | ||
|
|
||
| batch_invariant_logp = kernel_registry.get_op("batch_invariant_logp") | ||
|
|
||
| logp = batch_invariant_logp( | ||
| logits, # [B, T, V] or [N, V], differentiable | ||
| target_ids, # [B, T] or [N], int | ||
| ignore_index=-100, | ||
| validate=False, # Triton fast path; use True to debug-check target range | ||
| ) # -> [B, T] or [N], float32 | ||
|
|
||
| logp.sum().backward() # gradients flow into logits only | ||
| ``` | ||
|
|
||
| ## Backends | ||
|
|
||
| | Backend | Wrapper | Status | | ||
| | --- | --- | --- | | ||
| | CUDA / ROCm (Triton) | `TritonBatchInvariantLogpOp` | Triton online-softmax forward and tile-wise backward. Requires a GPU tensor. | | ||
| | PyTorch native | `NativeBatchInvariantLogpOp` | FP32 reference path; CPU fallback and Triton-less fallback. | | ||
|
|
||
| Current dispatch: | ||
|
|
||
| ```text | ||
| CUDA / ROCm: Triton -> PyTorch | ||
| CPU: PyTorch | ||
| ``` | ||
|
|
||
| A compiled CUDA backend and benchmark suite are planned follow-up work. | ||
| Benchmarks are not included in this PR; they will be added alongside the CUDA | ||
| backend in a subsequent PR. | ||
|
|
||
| ## Tensor Contract | ||
|
|
||
| | Argument | Shape | Dtype | Requirements | | ||
| | --- | --- | --- | --- | | ||
| | `logits` | `[N, V]` / `[B, T, V]` / `[*lead, V]` | fp32 / fp16 / bf16 | Differentiable input; last dimension is vocab. | | ||
| | `target_ids` | `[N]` / `[B, T]` / `[*lead]` | int | Same leading shape as `logits`; non-ignored values in `[0, V)`. | | ||
| | `ignore_index` | scalar int | Python int | Default `-100`. Ignored rows output zero and receive zero gradient. | | ||
| | Output | `[N]` / `[B, T]` / `[*lead]` | float32 | Selected log-probability per row. | | ||
|
|
||
| `target_ids` is integer and non-differentiable. Gradients flow only into | ||
| `logits`. | ||
|
|
||
| ## Reference Semantics | ||
|
|
||
| For non-ignored rows: | ||
|
|
||
| ```python | ||
| logits_2d = logits.reshape(-1, logits.size(-1)).float() | ||
| target_1d = target_ids.reshape(-1).long() | ||
|
|
||
| log_probs = torch.log_softmax(logits_2d, dim=-1) | ||
| selected = torch.gather( | ||
| log_probs, | ||
| dim=-1, | ||
| index=target_1d.unsqueeze(-1), | ||
| ).squeeze(-1) | ||
|
|
||
| out = selected.reshape(target_ids.shape) | ||
| ``` | ||
|
|
||
| For ignored rows: | ||
|
|
||
| ```text | ||
| target_ids[row] == ignore_index | ||
| out[row] = 0.0 | ||
| grad_logits[row, :] = 0.0 | ||
| ``` | ||
|
|
||
| Non-ignored target ids outside `[0, V)` are invalid. In particular, | ||
| `target=-1` is invalid unless `ignore_index=-1`. | ||
|
|
||
| The PyTorch native backend validates target ranges by default. The Triton | ||
| backend defaults to `validate=False` to avoid CUDA stream synchronization in | ||
| training hot paths. Use `validate=True` during debugging or in tests when | ||
| calling the Triton backend with untrusted targets. | ||
|
|
||
| ## Batch-Invariance | ||
|
|
||
| The operator is designed so each row is computed independently: | ||
|
|
||
| - The PyTorch path reshapes to `[N, V]` and applies row-wise reductions. | ||
| - The Triton forward uses `grid=(num_tokens,)`, so one program owns exactly one | ||
| row. | ||
| - Triton vocab traversal uses a fixed `_BLOCK_V=1024` and does not autotune by | ||
| batch size. | ||
| - Triton forward scans vocab tiles left-to-right using online logsumexp. | ||
| - Triton backward uses `grid=(num_tokens, vocab_tiles)` and writes one row tile | ||
| per program. It reuses the forward-saved per-row `lse`, so no backward | ||
| reduction crosses row boundaries. | ||
| - No atomic writes are used. | ||
|
|
||
| These constraints ensure the result for a row depends only on that row's logits | ||
| and target id, not on batch size, row position, or neighboring rows. | ||
|
|
||
| ## Accuracy | ||
|
|
||
| Both backends accumulate reductions in float32 and return float32 outputs. Tests | ||
| compare against `torch.log_softmax(...).gather(...)` with dtype-appropriate | ||
| tolerances: | ||
|
|
||
| ```text | ||
| fp32 forward: atol around 1e-5 | ||
| fp16/bf16 forward: atol around 1e-4 | ||
| fp16/bf16 backward: checked against fp32 reference with relaxed tolerance | ||
| ``` | ||
|
|
||
| CPU-vs-CUDA comparisons use tolerance-based checks; batch-invariance checks | ||
| within the same backend use exact equality where appropriate. | ||
|
|
||
| ## Minimal Example | ||
|
|
||
| ```python | ||
| import torch | ||
|
|
||
| from rl_engine.kernels.registry import kernel_registry | ||
|
|
||
| op = kernel_registry.get_op("batch_invariant_logp") | ||
|
|
||
| logits = torch.randn(2, 4, 300, device="cuda", dtype=torch.bfloat16) | ||
| target_ids = torch.randint(0, 300, (2, 4), device="cuda") | ||
| target_ids[0, 0] = -100 | ||
|
|
||
| out = op(logits, target_ids, ignore_index=-100) | ||
| assert out.shape == target_ids.shape | ||
| assert out.dtype == torch.float32 | ||
| assert out[0, 0].item() == 0.0 | ||
|
|
||
| out.sum().backward() | ||
| ``` | ||
|
|
||
| ## Tests | ||
|
|
||
| ```bash | ||
| python -m pytest tests/test_batch_invariant_logp.py -q -rs | ||
| ``` | ||
|
|
||
| All backends (Native, Triton) are tested in a single file. Coverage includes: | ||
| correctness, leading-shape preservation, batch-invariance (bitwise), validation, | ||
| ignore-index behavior, backward correctness, CUDA smoke cases, registry | ||
| dispatch, and Triton-specific fp32/fp16/bf16 correctness, large vocab, backward | ||
| gradient batch-invariance, and ignored-row zero gradients. | ||
|
|
||
| Triton tests skip when Triton or CUDA is unavailable. On Windows, run via | ||
| WSL/Linux with CUDA. | ||
|
|
||
| ## Implementation Files | ||
|
|
||
| - `rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py` | ||
| - `rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py` | ||
| - `rl_engine/kernels/registry.py` | ||
| - `tests/test_batch_invariant_logp.py` | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2026 RL-Kernel Contributors | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| class NativeBatchInvariantLogpOp: | ||
| """Batch-invariant selected-token log-probability. | ||
|
|
||
| ``selected_logprob[t] = logits[t, target_ids[t]] - logsumexp(logits[t, :])`` | ||
|
|
||
| All reductions run in FP32. The row-wise max -> subtract -> exp -> sum -> log | ||
| pipeline is fully independent per row, so the result for any row depends | ||
| only on that row's logits and target - never on batch size or layout. | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| pass | ||
|
|
||
| def __call__( | ||
| self, | ||
| logits: torch.Tensor, | ||
| target_ids: torch.Tensor, | ||
| ignore_index: int = -100, | ||
| *, | ||
| validate: bool = True, | ||
| ) -> torch.Tensor: | ||
| return self.apply(logits, target_ids, ignore_index=ignore_index, validate=validate) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| def apply( | ||
| self, | ||
| logits: torch.Tensor, | ||
| target_ids: torch.Tensor, | ||
| ignore_index: int = -100, | ||
| *, | ||
| validate: bool = True, | ||
| ) -> torch.Tensor: | ||
| self._validate_shapes(logits, target_ids) | ||
|
|
||
| lead_shape = logits.shape[:-1] | ||
| vocab_size = logits.size(-1) | ||
|
|
||
| logits_2d = logits.reshape(-1, vocab_size).float() | ||
| target_1d = target_ids.reshape(-1).to(logits.device, dtype=torch.long) | ||
|
Comment on lines
+45
to
+46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win 🧩 Analysis chain🏁 Script executed: #!/bin/bash
set -euo pipefail
printf '\n== file outline ==\n'
ast-grep outline rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py --view expanded || true
printf '\n== file contents (relevant slice) ==\n'
sed -n '1,220p' rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py
printf '\n== related tests search ==\n'
rg -n "batch_invariant_logp|NativeBatchInvariantLogpOp|target_ids|ignore_index|dtype" rl_engine -g '*test*' -g '*spec*' || true
printf '\n== pytorch casting probe ==\n'
python3 - <<'PY'
import torch
cases = [
torch.tensor([1.9, -0.2, 3.7], dtype=torch.float32),
torch.tensor([True, False, True], dtype=torch.bool),
torch.tensor([1, 2, 3], dtype=torch.int32),
]
for t in cases:
out = t.to(dtype=torch.long)
print(t.dtype, '->', out.dtype, out.tolist())
PYRepository: RL-Align/RL-Kernel Length of output: 5090 🏁 Script executed: #!/bin/bash
set -euo pipefail
printf '\n== call sites ==\n'
rg -n "NativeBatchInvariantLogpOp|batch_invariant_logp|selected_logprob|target_ids" rl_engine -g '!**/*.pyc' || true
printf '\n== tests in loss area ==\n'
fd -a "test*.py" rl_engine/tests rl_engine/kernels 2>/dev/null || true
printf '\n== nearby test files mentioning logprob ==\n'
rg -n "logprob|batch_invariant|target_ids|ignore_index" rl_engine/tests rl_engine/kernels/tests -g '*test*.py' || trueRepository: RL-Align/RL-Kernel Length of output: 9236 Reject non-integer 🤖 Prompt for AI Agents |
||
|
|
||
| selected_logp = self._row_wise_selected_logprob( | ||
| logits_2d, target_1d, ignore_index=ignore_index, validate=validate | ||
| ) | ||
|
|
||
| return selected_logp.reshape(lead_shape) | ||
|
|
||
| # ---------------------------------------------------------------------- # | ||
| # Core Computation | ||
| # ---------------------------------------------------------------------- # | ||
| @staticmethod | ||
| def _row_wise_selected_logprob( | ||
| logits_2d: torch.Tensor, | ||
| target_1d: torch.Tensor, | ||
| *, | ||
| ignore_index: int, | ||
| validate: bool = True, | ||
| ) -> torch.Tensor: | ||
| """Per-row selected logprob with locked reduction order. | ||
|
|
||
| The three reduction steps (max, sum-exp, gather) operate on each row | ||
| independently. PyTorch's ``max(dim=-1)`` and ``sum(dim=-1)`` iterate | ||
| the vocab dimension in a fixed, deterministic order for a given row | ||
| length, and that order does **not** change when more rows are added | ||
| to the batch. This is the property that makes the op batch-invariant. | ||
|
|
||
| Accumulation is done entirely in FP32 to avoid half-precision | ||
| catastrophic cancellation during the ``exp(logit - max)`` step. | ||
| """ | ||
| vocab_size = logits_2d.size(1) | ||
|
|
||
| valid_mask = target_1d != ignore_index | ||
|
|
||
| if validate: | ||
| valid_targets = target_1d[valid_mask] | ||
| if valid_targets.numel() > 0 and ( | ||
| (valid_targets < 0).any() or (valid_targets >= vocab_size).any() | ||
| ): | ||
| bad = valid_targets[(valid_targets < 0) | (valid_targets >= vocab_size)] | ||
| raise ValueError( | ||
| f"target_ids contains values outside [0, {vocab_size}): {bad.tolist()}" | ||
| ) | ||
|
|
||
| safe_target = target_1d.clone() | ||
| safe_target[~valid_mask] = 0 | ||
|
|
||
| # logsumexp(z) = log(sum(exp(z - max(z)))) + max(z) | ||
| row_max = logits_2d.max(dim=-1).values | ||
| shifted = logits_2d - row_max.unsqueeze(-1) | ||
| exp_shifted = shifted.exp() | ||
| sum_exp = exp_shifted.sum(dim=-1) | ||
| log_sum_exp = sum_exp.log() + row_max | ||
|
|
||
| row_indices = torch.arange(logits_2d.size(0), device=logits_2d.device) | ||
| selected_logit = logits_2d[row_indices, safe_target] | ||
|
|
||
| selected_logp = selected_logit - log_sum_exp | ||
|
|
||
| selected_logp = selected_logp.where( | ||
| valid_mask, torch.zeros_like(selected_logp) | ||
| ) | ||
|
|
||
| return selected_logp | ||
|
|
||
| # ---------------------------------------------------------------------- # | ||
| # Helper | ||
| # ---------------------------------------------------------------------- # | ||
| @staticmethod | ||
| def _validate_shapes(logits: torch.Tensor, target_ids: torch.Tensor) -> None: | ||
| if logits.dim() < 2: | ||
| raise ValueError( | ||
| f"logits must be at least 2-D ([*lead, V]), got shape {tuple(logits.shape)}" | ||
| ) | ||
| if logits.shape[:-1] != target_ids.shape: | ||
| raise ValueError( | ||
| f"logits leading shape {tuple(logits.shape[:-1])} must match " | ||
| f"target_ids shape {tuple(target_ids.shape)}" | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just make a note that this kernel is currently lacking benchmarks, I will support it in another PR (cuda ver. for this).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressed, thanks.
Changes: