Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions rl_engine/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Testing helpers for RL-shaped kernel validation."""

from .logprob_parity import compare_selected_logprob_layouts, make_padded_batch_layout
from .reference_ops import (
active_token_count,
compute_policy_ratio,
Expand All @@ -17,8 +18,10 @@
__all__ = [
"SyntheticRLKernelBatch",
"active_token_count",
"compare_selected_logprob_layouts",
"compute_policy_ratio",
"compute_reference_kl",
"make_padded_batch_layout",
"make_synthetic_rl_kernel_batch",
"masked_mean",
"masked_sum",
Expand Down
104 changes: 104 additions & 0 deletions rl_engine/testing/logprob_parity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

from __future__ import annotations

from typing import Any, Optional

import torch

from rl_engine.testing.reference_ops import selected_logprobs_reference, summarize_kernel_drift


def make_padded_batch_layout(
    logits: torch.Tensor,
    token_ids: torch.Tensor,
    mask: torch.Tensor,
    *,
    destination_rows: torch.Tensor,
    padded_batch_size: Optional[int] = None,
    pad_token_id: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Place completion rows into a larger padded batch layout.

    Each source row ``i`` is copied to row ``destination_rows[i]`` of the
    output. Rows that receive no source row are padding: their logits are
    zero, their token ids are ``pad_token_id`` and their mask is ``False``.

    Args:
        logits: Tensor whose first dimension is the batch and whose last
            dimension is the vocabulary.
        token_ids: Tensor whose shape equals ``logits.shape[:-1]``.
        mask: Tensor with the same shape as ``token_ids``; converted to bool
            when copied into the output.
        destination_rows: One unique, non-negative output row index per
            source row. It is flattened and cast to ``torch.long``.
        padded_batch_size: Batch size of the output. Defaults to the source
            batch size when ``None``.
        pad_token_id: Token id written into padding rows; must be a valid
            index into the vocabulary dimension of ``logits``.

    Returns:
        ``(padded_logits, padded_token_ids, padded_mask)``. The logits and
        token ids keep the dtype and device of their inputs; the mask is
        always ``torch.bool`` on the device of ``mask``.

    Raises:
        ValueError: If the input shapes disagree, ``pad_token_id`` is outside
            the vocabulary, ``padded_batch_size`` is smaller than the source
            batch, or ``destination_rows`` has the wrong length, a negative
            entry, a duplicate, or an entry outside the padded batch.
    """
    if logits.ndim < 2:
        raise ValueError("logits must have at least batch and vocab dimensions")
    if logits.shape[:-1] != token_ids.shape:
        raise ValueError("logits leading shape must match token_ids shape")
    if mask.shape != token_ids.shape:
        raise ValueError("mask shape must match token_ids shape")

    vocab_size = int(logits.shape[-1])
    if not 0 <= int(pad_token_id) < vocab_size:
        raise ValueError("pad_token_id must be within the logits vocabulary range")

    source_batch = int(logits.shape[0])
    rows = destination_rows.to(device=logits.device, dtype=torch.long).reshape(-1)
    if rows.numel() != source_batch:
        raise ValueError("destination_rows must contain one destination per source row")
    if rows.numel() and int(rows.min().item()) < 0:
        raise ValueError("destination_rows must be non-negative")
    if rows.unique().numel() != rows.numel():
        raise ValueError("destination_rows must not contain duplicates")

    resolved_batch = int(padded_batch_size) if padded_batch_size is not None else source_batch
    # Check the batch size before the per-row range test. With unique,
    # non-negative rows, an undersized batch always trips the range test too,
    # which would hide the real cause behind a misleading message.
    if resolved_batch < source_batch:
        raise ValueError("padded_batch_size must be at least the source batch size")
    if rows.numel() and int(rows.max().item()) >= resolved_batch:
        raise ValueError("destination_rows contains a row outside padded_batch_size")

    out_shape = (resolved_batch,) + tuple(logits.shape[1:])
    token_shape = (resolved_batch,) + tuple(token_ids.shape[1:])

    # Start from an all-padding layout, then scatter the real rows into it.
    padded_logits = torch.zeros(out_shape, device=logits.device, dtype=logits.dtype)
    padded_token_ids = torch.full(
        token_shape,
        int(pad_token_id),
        device=token_ids.device,
        dtype=token_ids.dtype,
    )
    padded_mask = torch.zeros(token_shape, device=mask.device, dtype=torch.bool)

    padded_logits[rows] = logits
    padded_token_ids[rows] = token_ids
    padded_mask[rows] = mask.to(dtype=torch.bool)
    return padded_logits, padded_token_ids, padded_mask


def compare_selected_logprob_layouts(
    reference_logits: torch.Tensor,
    reference_token_ids: torch.Tensor,
    reference_mask: torch.Tensor,
    candidate_logits: torch.Tensor,
    candidate_token_ids: torch.Tensor,
    candidate_mask: torch.Tensor,
    *,
    candidate_rows: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> dict[str, Any]:
    """Measure selected-logprob drift between two layouts of the same rows.

    Both layouts are run through ``selected_logprobs_reference``. The
    candidate result is then gathered at ``candidate_rows`` so that it lines
    up row-for-row with the reference before the two are compared.

    Args:
        reference_logits: Logits of the reference layout.
        reference_token_ids: Token ids of the reference layout.
        reference_mask: Mask of the reference layout; also the mask handed to
            ``summarize_kernel_drift``.
        candidate_logits: Logits of the candidate layout.
        candidate_token_ids: Token ids of the candidate layout.
        candidate_mask: Mask of the candidate layout.
        candidate_rows: For each reference row, the index of the matching row
            in the candidate batch. It is flattened and cast to
            ``torch.long``.
        output_dtype: Passed through as ``output_dtype`` to both
            ``selected_logprobs_reference`` calls.

    Returns:
        The result of ``summarize_kernel_drift`` for the gathered candidate
        rows against the reference.

    Raises:
        ValueError: If ``candidate_rows`` has the wrong length, a negative
            entry, a duplicate, or an entry outside the candidate batch.
    """
    layouts = (
        (reference_logits, reference_token_ids, reference_mask),
        (candidate_logits, candidate_token_ids, candidate_mask),
    )
    # Reference first, candidate second: same evaluation order for both layouts.
    reference, candidate = (
        selected_logprobs_reference(
            layout_logits,
            layout_token_ids,
            mask=layout_mask,
            output_dtype=output_dtype,
        )
        for layout_logits, layout_token_ids, layout_mask in layouts
    )

    rows = candidate_rows.to(device=candidate.device, dtype=torch.long).reshape(-1)
    row_count = rows.numel()
    has_rows = row_count > 0

    if row_count != int(reference.shape[0]):
        raise ValueError("candidate_rows must contain one candidate row per reference row")
    # min()/max() are undefined on empty tensors, hence the has_rows guards.
    if has_rows and int(rows.min().item()) < 0:
        raise ValueError("candidate_rows must be non-negative")
    if rows.unique().numel() != row_count:
        raise ValueError("candidate_rows must not contain duplicates")
    if has_rows and int(rows.max().item()) >= int(candidate.shape[0]):
        raise ValueError("candidate_rows contains a row outside the candidate batch")

    # Gather candidate rows back into reference order before comparing.
    return summarize_kernel_drift(candidate[rows], reference, reference_mask)
150 changes: 150 additions & 0 deletions tests/test_logprob_parity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

import pytest
import torch

from rl_engine.testing import (
compare_selected_logprob_layouts,
make_padded_batch_layout,
selected_logprobs_reference,
summarize_kernel_drift,
)


def _case(*, device="cpu", dtype=torch.float32):
    """Build a deterministic 4x5 batch over a 17-token vocabulary.

    Returns ``(logits, token_ids, mask)``. The logits and token ids are drawn
    from a generator seeded with 123; the mask is a fixed boolean pattern.
    """
    batch, seq_len, vocab = 4, 5, 17

    rng = torch.Generator(device=device)
    rng.manual_seed(123)

    # Draw order matters: logits first, then token ids, from the same generator.
    logits = torch.randn(batch, seq_len, vocab, device=device, dtype=dtype, generator=rng)
    token_ids = torch.randint(0, vocab, (batch, seq_len), device=device, generator=rng)

    active = [
        [1, 1, 0, 1, 0],
        [1, 0, 1, 1, 1],
        [0, 1, 1, 0, 1],
        [1, 1, 1, 0, 0],
    ]
    mask = torch.tensor(active, device=device, dtype=torch.bool)
    return logits, token_ids, mask


def test_selected_logprob_is_invariant_to_batch_position():
    """Permuting batch rows must leave every row's selected logprobs unchanged."""
    logits, token_ids, mask = _case()
    permutation = torch.tensor([2, 0, 3, 1])

    expected = selected_logprobs_reference(logits, token_ids, mask=mask)
    permuted = selected_logprobs_reference(
        logits[permutation],
        token_ids[permutation],
        mask=mask[permutation],
    )

    # Scatter the permuted rows back to their original positions.
    unshuffled = torch.empty_like(expected)
    unshuffled[permutation] = permuted

    drift = summarize_kernel_drift(unshuffled, expected, mask)
    assert drift["active_count"] == int(mask.sum().item())
    for error_key in ("max_abs_error", "mean_abs_error"):
        assert drift[error_key] == 0.0


def test_selected_logprob_is_invariant_to_padding_layout():
    """Scattering rows into a padded 6-row batch must not change their logprobs."""
    logits, token_ids, mask = _case()
    slots = torch.tensor([4, 0, 2, 5])

    padded = make_padded_batch_layout(
        logits,
        token_ids,
        mask,
        destination_rows=slots,
        padded_batch_size=6,
    )

    # ``padded`` unpacks to (logits, token_ids, mask) of the candidate layout.
    drift = compare_selected_logprob_layouts(
        logits,
        token_ids,
        mask,
        *padded,
        candidate_rows=slots,
    )

    assert drift["active_count"] == int(mask.sum().item())
    for error_key in ("max_abs_error", "mean_abs_error"):
        assert drift[error_key] == 0.0


def test_make_padded_batch_layout_rejects_out_of_range_pad_token_id():
    """A pad token id equal to the vocabulary size is one past the valid range."""
    logits, token_ids, mask = _case()
    first_invalid_id = logits.shape[-1]
    identity_rows = torch.tensor([0, 1, 2, 3])

    with pytest.raises(ValueError, match="pad_token_id"):
        make_padded_batch_layout(
            logits,
            token_ids,
            mask,
            destination_rows=identity_rows,
            padded_batch_size=4,
            pad_token_id=first_invalid_id,
        )


@pytest.mark.parametrize(
    "candidate_rows",
    [
        torch.tensor([0, 1, 1, 3]),  # duplicate row
        torch.tensor([0, -1, 2, 3]),  # negative row
        torch.tensor([0, 1, 2, 6]),  # row beyond the 4-row candidate batch
        torch.tensor([0, 1, 2]),  # one row short of the reference batch
    ],
)
def test_selected_logprob_layout_compare_rejects_bad_candidate_rows(candidate_rows):
    """Every malformed ``candidate_rows`` tensor must raise ``ValueError``."""
    logits, token_ids, mask = _case()

    # Identity layout: the candidate batch is the reference batch, unpadded.
    candidate = make_padded_batch_layout(
        logits,
        token_ids,
        mask,
        destination_rows=torch.tensor([0, 1, 2, 3]),
        padded_batch_size=4,
    )

    with pytest.raises(ValueError):
        compare_selected_logprob_layouts(
            logits,
            token_ids,
            mask,
            *candidate,
            candidate_rows=candidate_rows,
        )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_selected_logprob_padding_layout_cuda_dtype_sweep(dtype):
    """Repeat the padding-layout parity check on CUDA for each logits dtype."""
    logits, token_ids, mask = _case(device="cuda", dtype=dtype)
    slots = torch.tensor([1, 4, 0, 3], device="cuda")

    padded = make_padded_batch_layout(
        logits,
        token_ids,
        mask,
        destination_rows=slots,
        padded_batch_size=5,
    )

    # Logprobs are requested in float32 regardless of the logits dtype.
    drift = compare_selected_logprob_layouts(
        logits,
        token_ids,
        mask,
        *padded,
        candidate_rows=slots,
        output_dtype=torch.float32,
    )

    assert drift["active_count"] == int(mask.sum().item())
    for error_key in ("max_abs_error", "mean_abs_error"):
        assert drift[error_key] == 0.0
Loading