diff --git a/rl_engine/testing/__init__.py b/rl_engine/testing/__init__.py index 42be8c1..6b6a600 100644 --- a/rl_engine/testing/__init__.py +++ b/rl_engine/testing/__init__.py @@ -3,6 +3,7 @@ """Testing helpers for RL-shaped kernel validation.""" +from .logprob_parity import compare_selected_logprob_layouts, make_padded_batch_layout from .reference_ops import ( active_token_count, compute_policy_ratio, @@ -17,8 +18,10 @@ __all__ = [ "SyntheticRLKernelBatch", "active_token_count", + "compare_selected_logprob_layouts", "compute_policy_ratio", "compute_reference_kl", + "make_padded_batch_layout", "make_synthetic_rl_kernel_batch", "masked_mean", "masked_sum", diff --git a/rl_engine/testing/logprob_parity.py b/rl_engine/testing/logprob_parity.py new file mode 100644 index 0000000..c52b4fc --- /dev/null +++ b/rl_engine/testing/logprob_parity.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +from typing import Any, Optional + +import torch + +from rl_engine.testing.reference_ops import selected_logprobs_reference, summarize_kernel_drift + + +def make_padded_batch_layout( + logits: torch.Tensor, + token_ids: torch.Tensor, + mask: torch.Tensor, + *, + destination_rows: torch.Tensor, + padded_batch_size: Optional[int] = None, + pad_token_id: int = 0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Place completion rows into a larger padded batch layout.""" + + if logits.ndim < 2: + raise ValueError("logits must have at least batch and vocab dimensions") + if logits.shape[:-1] != token_ids.shape: + raise ValueError("logits leading shape must match token_ids shape") + if mask.shape != token_ids.shape: + raise ValueError("mask shape must match token_ids shape") + + vocab_size = int(logits.shape[-1]) + if not 0 <= int(pad_token_id) < vocab_size: + raise ValueError("pad_token_id must be within the logits vocabulary range") + + source_batch = int(logits.shape[0]) + rows = destination_rows.to(device=logits.device, dtype=torch.long).reshape(-1) + if rows.numel() != source_batch: + raise ValueError("destination_rows must contain one destination per source row") + if rows.numel() and int(rows.min().item()) < 0: + raise ValueError("destination_rows must be non-negative") + if rows.unique().numel() != rows.numel(): + raise ValueError("destination_rows must not contain duplicates") + + resolved_batch = int(padded_batch_size) if padded_batch_size is not None else source_batch + if rows.numel() and int(rows.max().item()) >= resolved_batch: + raise ValueError("destination_rows contains a row outside padded_batch_size") + if resolved_batch < source_batch: + raise ValueError("padded_batch_size must be at least the source batch size") + + out_shape = (resolved_batch,) + tuple(logits.shape[1:]) + token_shape = (resolved_batch,) + tuple(token_ids.shape[1:]) + + padded_logits = torch.zeros(out_shape, device=logits.device, dtype=logits.dtype) + padded_token_ids = torch.full( + token_shape, + int(pad_token_id), + device=token_ids.device, + dtype=token_ids.dtype, + ) + padded_mask = torch.zeros(token_shape, device=mask.device, dtype=torch.bool) + + padded_logits[rows] = logits + padded_token_ids[rows] = token_ids + padded_mask[rows] = mask.to(dtype=torch.bool) + return padded_logits, padded_token_ids, padded_mask + + +def compare_selected_logprob_layouts( + reference_logits: torch.Tensor, + reference_token_ids: torch.Tensor, + reference_mask: torch.Tensor, + candidate_logits: torch.Tensor, + candidate_token_ids: torch.Tensor, + candidate_mask: torch.Tensor, + *, + candidate_rows: torch.Tensor, + output_dtype: torch.dtype = torch.float32, +) -> dict[str, Any]: + """Compare selected logprobs for identical rows under different batch layouts.""" + + reference = selected_logprobs_reference( + reference_logits, + reference_token_ids, + mask=reference_mask, + output_dtype=output_dtype, + ) + candidate = selected_logprobs_reference( + candidate_logits, + candidate_token_ids, + mask=candidate_mask, + output_dtype=output_dtype, + ) + rows = candidate_rows.to(device=candidate.device, dtype=torch.long).reshape(-1) + if rows.numel() != int(reference.shape[0]): + raise ValueError("candidate_rows must contain one candidate row per reference row") + if rows.numel() and int(rows.min().item()) < 0: + raise ValueError("candidate_rows must be non-negative") + if rows.unique().numel() != rows.numel(): + raise ValueError("candidate_rows must not contain duplicates") + if rows.numel() and int(rows.max().item()) >= int(candidate.shape[0]): + raise ValueError("candidate_rows contains a row outside the candidate batch") + + restored = candidate[rows] + return summarize_kernel_drift(restored, reference, reference_mask) diff --git a/tests/test_logprob_parity.py b/tests/test_logprob_parity.py new file mode 100644 index 0000000..30be584 --- /dev/null +++ b/tests/test_logprob_parity.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +import pytest +import torch + +from rl_engine.testing import ( + compare_selected_logprob_layouts, + make_padded_batch_layout, + selected_logprobs_reference, + summarize_kernel_drift, +) + + +def _case(*, device="cpu", dtype=torch.float32): + generator = torch.Generator(device=device).manual_seed(123) + logits = torch.randn(4, 5, 17, device=device, dtype=dtype, generator=generator) + token_ids = torch.randint(0, 17, (4, 5), device=device, generator=generator) + mask = torch.tensor( + [ + [True, True, False, True, False], + [True, False, True, True, True], + [False, True, True, False, True], + [True, True, True, False, False], + ], + device=device, + dtype=torch.bool, + ) + return logits, token_ids, mask + + +def test_selected_logprob_is_invariant_to_batch_position(): + logits, token_ids, mask = _case() + row_order = torch.tensor([2, 0, 3, 1]) + + base = selected_logprobs_reference(logits, token_ids, mask=mask) + shuffled = selected_logprobs_reference( + logits[row_order], + token_ids[row_order], + mask=mask[row_order], + ) + restored = torch.empty_like(base) + restored[row_order] = shuffled + + summary = summarize_kernel_drift(restored, base, mask) + assert summary["active_count"] == int(mask.sum().item()) + assert summary["max_abs_error"] == 0.0 + assert summary["mean_abs_error"] == 0.0 + + +def test_selected_logprob_is_invariant_to_padding_layout(): + logits, token_ids, mask = _case() + destination_rows = torch.tensor([4, 0, 2, 5]) + + padded_logits, padded_token_ids, padded_mask = make_padded_batch_layout( + logits, + token_ids, + mask, + destination_rows=destination_rows, + padded_batch_size=6, + ) + + summary = compare_selected_logprob_layouts( + logits, + token_ids, + mask, + padded_logits, + padded_token_ids, + padded_mask, + candidate_rows=destination_rows, + ) + + assert summary["active_count"] == int(mask.sum().item()) + assert summary["max_abs_error"] == 0.0 + assert summary["mean_abs_error"] == 0.0 + + +def test_make_padded_batch_layout_rejects_out_of_range_pad_token_id(): + logits, token_ids, mask = _case() + + with pytest.raises(ValueError, match="pad_token_id"): + make_padded_batch_layout( + logits, + token_ids, + mask, + destination_rows=torch.tensor([0, 1, 2, 3]), + padded_batch_size=4, + pad_token_id=logits.shape[-1], + ) + + +@pytest.mark.parametrize( + "candidate_rows", + [ + torch.tensor([0, 1, 1, 3]), + torch.tensor([0, -1, 2, 3]), + torch.tensor([0, 1, 2, 6]), + torch.tensor([0, 1, 2]), + ], +) +def test_selected_logprob_layout_compare_rejects_bad_candidate_rows(candidate_rows): + logits, token_ids, mask = _case() + padded_logits, padded_token_ids, padded_mask = make_padded_batch_layout( + logits, + token_ids, + mask, + destination_rows=torch.tensor([0, 1, 2, 3]), + padded_batch_size=4, + ) + + with pytest.raises(ValueError): + compare_selected_logprob_layouts( + logits, + token_ids, + mask, + padded_logits, + padded_token_ids, + padded_mask, + candidate_rows=candidate_rows, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_selected_logprob_padding_layout_cuda_dtype_sweep(dtype): + logits, token_ids, mask = _case(device="cuda", dtype=dtype) + destination_rows = torch.tensor([1, 4, 0, 3], device="cuda") + + padded_logits, padded_token_ids, padded_mask = make_padded_batch_layout( + logits, + token_ids, + mask, + destination_rows=destination_rows, + padded_batch_size=5, + ) + + summary = compare_selected_logprob_layouts( + logits, + token_ids, + mask, + padded_logits, + padded_token_ids, + padded_mask, + candidate_rows=destination_rows, + output_dtype=torch.float32, + ) + + assert summary["active_count"] == int(mask.sum().item()) + assert summary["max_abs_error"] == 0.0 + assert summary["mean_abs_error"] == 0.0