diff --git a/docs/contributing/issue-108-session-log.md b/docs/contributing/issue-108-session-log.md new file mode 100644 index 0000000..a6aa9c1 --- /dev/null +++ b/docs/contributing/issue-108-session-log.md @@ -0,0 +1,483 @@ +# ISSUE-108 Session Log + +This document records the engineering decisions made while building the ISSUE-108 kernel correctness checker. It is intentionally concise and review-oriented: it explains what was added, why it was added, how to use it, and what is still out of scope. + +## Logging Rules + +- Record the reason for each meaningful change, not only the files touched. +- Keep changes minimal and independently verifiable. +- Be explicit when a path is only a smoke test or an experimental path. +- Do not present failed CUDA paths as supported capabilities. +- Gold implementations must come from `rl_engine.kernels.ops.pytorch`. + +## Goal + +The goal of this work is to add a minimal, reusable operator correctness framework for post-training kernels. + +The framework should: + +- Generate deterministic operator inputs. +- Run PyTorch gold implementations and backend candidates on the same inputs. +- Compare every tensor output with dtype/operator-class tolerances. +- Report absolute error, relative error, pass rate, and final pass/fail status. +- Expose a CLI so a developer can validate a backend candidate without editing test files. + +## Final Layout + +```text +rl_engine/kernels/gtest/ + __init__.py + op_checks.py + operator_inputs.py + operator_specs.py + tolerance.py + tolerance_contract.json + +scripts/check_operator.py + +tests/test_op_checks.py +tests/test_operator_inputs.py +tests/test_tolerance_contract.py +``` + +## Key Design Decisions + +### Tolerance Contract + +Files: + +```text +rl_engine/kernels/gtest/tolerance.py +rl_engine/kernels/gtest/tolerance_contract.json +tests/test_tolerance_contract.py +``` + +Decision: + +- Store tolerance values in a small contract file rather than hard-coding them inside tests. +- Resolve tolerance by `op_class + dtype`, with optional `arch_key` overrides. +- Treat `default` as the generic fallback, not as CPU-specific tolerance. + +Current accuracy classes: + +```text +elementwise +reduction +logprob +``` + +### Operator Check Runner + +Files: + +```text +rl_engine/kernels/gtest/op_checks.py +tests/test_op_checks.py +``` + +Decision: + +- `OperatorCase` describes one deterministic test case: name, op class, dtype, inputs, and gold function. +- `CandidateSpec` describes one implementation under test: name, function, backend, and optional arch key. +- `run_operator_suite()` runs candidates against gold outputs and returns structured reports. +- The runner compares forward outputs only in this minimal version. + +Review follow-up: + +- `op_checks.py` includes a TODO for optional gradient checks on differentiable operators. +- Gradient checks require additional metadata and input cloning rules, so they are intentionally tracked as follow-up work instead of being silently implied by this PR. + +### Operator Inputs + +Files: + +```text +rl_engine/kernels/gtest/operator_inputs.py +tests/test_operator_inputs.py +``` + +Decision: + +- Build standard semantic inputs for each operator. +- Support both `random` and `constant` initialization. +- Make random inputs reproducible with `--seed`. +- Preserve semantic shapes such as `[B, S, V]`; do not flatten inputs for backend-specific kernels inside input generation. + +Current input builders cover: + +```text +rms_norm +matmul +attention +logp +rope +silu +swiglu +embedding +lm_head +kv_cache_attention +``` + +### Operator Specs + +File: + +```text +rl_engine/kernels/gtest/operator_specs.py +``` + +Decision: + +- Keep operator-specific registration outside `scripts/check_operator.py`. +- Register PyTorch gold paths and backend candidate paths in one place. +- Require `gold_path` to point into `rl_engine.kernels.ops.pytorch`. + +Current minimal registered operator: + +```text +op: logp +op_class: logprob +gold: rl_engine.kernels.ops.pytorch.loss.logp.NativeLogpOp +candidates: + pytorch -> NativeLogpOp + cuda -> FusedLogpGenericOp + cuda-generic -> FusedLogpGenericOp + cuda-sm90 -> FusedLogpSM90Op +``` + +Important note: + +- `candidate=pytorch` is only a smoke test for the checker itself. +- CUDA, Triton, ROCm, and future hardware-specific implementations are candidates. +- Registry dispatch is tested separately from this accuracy harness because it is hardware-dependent. +- Do not compare two operators that implement different math, such as ordinary `logp` and `linear_logp`. + +### SM90 Adapter Exception + +Current code contains `_LogpSM90CandidateAdapter` in `operator_specs.py`. + +Reason: + +- The existing SM90 logp wrapper accepts flattened inputs, while the checker standard input for `logp` is `[B, S, V]` logits and `[B, S]` token ids. +- The adapter exists only to validate the checker path against the current SM90 wrapper. + +Long-term rule: + +- Backend wrappers should align with the standard operator interface whenever possible. +- New operators should not rely on permanent test-side adapters for ordinary shape or parameter-name differences. + +## CLI Usage + +CPU smoke check against the PyTorch candidate: + +```bash +python scripts/check_operator.py --op logp --candidate pytorch --device cpu --dtype fp32 --batch 1 --seq 2 --vocab 17 +``` + +CUDA candidate check against the PyTorch gold path: + +```bash +python scripts/check_operator.py --op logp --candidate cuda --device cuda --dtype bf16 --arch-key sm90 --batch 1 --seq 1 --vocab 4096 +``` + +JSON report: + +```bash +python scripts/check_operator.py --op logp --candidate pytorch --device cpu --dtype fp32 --batch 1 --seq 2 --vocab 17 --json +``` + +Supported key options: + +```text +--op Operator name. The minimal version supports logp and linear_logp. +--candidate Candidate backend, for example pytorch, cuda, cuda-generic, cuda-sm90, triton. +--dtype fp32, bf16, or fp16. +--device auto, cpu, cuda, or another torch device string. +--arch-key Optional tolerance override key such as sm90. +--batch Batch size. +--seq Sequence length. +--vocab Vocabulary size. +--input-mode random or constant. +--constant-value Floating-point value for constant mode. +--token-value Token id for constant mode, reduced modulo vocab. +--seed Random seed for reproducible random inputs. +--check-grad Also compare gradients for inputs declared by the operator spec. +--grad-mode Upstream gradient mode for --check-grad: random by default; ones for smoke tests. +--grad-seed Random upstream gradient seed for --grad-mode random. +--json Print the full structured report as JSON. +``` + +Example output: + +```text +suite=logp passed=True pass_rate=1.0000 +candidate=cuda-logp backend=cuda passed=True pass_rate=1.0000 + case=logp-torch.bfloat16-1x1x4096 output=0 shape=(1, 1) dtype=torch.bfloat16 max_abs=2.69813538e-02 mean_abs=2.69813538e-02 max_rel=3.03093810e-03 tol=(atol=5.000e-02, rtol=0.000e+00) passed=True +``` + +## Adding a New Operator + +To add a new operator, keep the shared checker flow unchanged. Add only operator-specific inputs, specs, and tests. + +### 1. Add Input Generation + +File: + +```text +rl_engine/kernels/gtest/operator_inputs.py +``` + +Update `make_operator_inputs()`: + +```python +builders = { + ... + "new_op": _make_new_op_inputs, +} +``` + +Update `operator_shape_name()`: + +```python +names = { + ... + "new_op": f"{batch}x{seq}x...", +} +``` + +Add the input builder: + +```python +def _make_new_op_inputs(args, dtype, device): + batch, seq = _batch_seq(args) + return { + "x": _floating_tensor((batch, seq, ...), args, dtype, device, offset=0), + } +``` + +Rules: + +- Inputs should represent the operator's standard semantic interface. +- Do not generate backend-specific flattened inputs here. +- Support deterministic random inputs and constant inputs where practical. + +### 2. Register Gold and Candidates + +File: + +```text +rl_engine/kernels/gtest/operator_specs.py +``` + +Add an `OperatorSpec` entry: + +```python +"new_op": OperatorSpec( + name="new_op", + op_class="elementwise", + gold_path="rl_engine.kernels.ops.pytorch....NativeNewOp", + candidate_paths={ + "pytorch": "rl_engine.kernels.ops.pytorch....NativeNewOp", + "cuda": "rl_engine.kernels.ops.cuda....CudaNewOp", + "triton": "rl_engine.kernels.ops.triton....TritonNewOp", + }, +) +``` + +Rules: + +- `gold_path` must come from `rl_engine.kernels.ops.pytorch`. +- Backend implementations are candidates only. +- `candidate=pytorch` is for checker smoke tests only. +- Do not compare operators with different math. + +### 3. Update Tolerances If Needed + +File: + +```text +rl_engine/kernels/gtest/tolerance_contract.json +``` + +Reuse an existing class when possible: + +```text +elementwise +reduction +logprob +``` + +If a new class is needed, add dtype tolerances and set `op_class` accordingly in `operator_specs.py`. + +### 4. Add Tests + +Files: + +```text +tests/test_operator_inputs.py +tests/test_op_checks.py +``` + +Minimum expected coverage: + +- Add the operator to the `test_operator_inputs_support_all_issue_108_ops` parametrized list. +- Add a PyTorch-vs-PyTorch smoke case if the operator adds new runner behavior. +- Add a bad-candidate case if the operator introduces new comparison behavior. + +### 5. Validate + +```bash +python -m pytest tests/test_tolerance_contract.py tests/test_op_checks.py tests/test_operator_inputs.py -q +``` + +Then run the CLI: + +```bash +python scripts/check_operator.py --op new_op --candidate pytorch --device cpu --dtype fp32 +``` + +For CUDA: + +```bash +python scripts/check_operator.py --op new_op --candidate cuda --device cuda --dtype bf16 --arch-key sm90 +``` + +## CUDA Validation Notes + +H100 environment observed during development: + +```text +GPU: NVIDIA H100 80GB HBM3 +Driver: 580.95.05 +CUDA driver capability: 13.0 +nvcc: 13.0 +torch: 2.12.0+cu130 +compute capability: (9, 0) +``` + +Generic CUDA `logp` passed on H100 for vocab sizes 256, 512, 1024, 2048, and 4096 with bf16 inputs under the current tolerance contract. + +SM90 fused logp is not marked as a passing path in this PR. It compiled and loaded in some experiments, but runtime failures and accuracy failures were observed separately. Treat SM90 fused logp as a separate CUDA kernel validation task unless `check_operator.py` reports `passed=True` for the target case. + +## Validation Performed + +```bash +python -m pytest tests/test_tolerance_contract.py tests/test_op_checks.py tests/test_operator_inputs.py -q +``` + +CPU CLI smoke test: + +```bash +python scripts/check_operator.py --op logp --candidate pytorch --device cpu --dtype fp32 --batch 1 --seq 2 --vocab 17 +``` + +Backward CLI smoke test: + +```bash +python scripts/check_operator.py --op logp --candidate pytorch --device cpu --dtype fp32 --batch 1 --seq 2 --vocab 17 --check-grad +``` + +## PR Review Updates + +### LogP Gradient Coverage + +Files: + +```text +tests/test_logp.py +docs/contributing/issue-108-session-log.md +``` + +Change: + +- Added a forward-gradient test for `NativeLogpOp.forward_fp32`. + +Reasoning: + +- The checker PR already validates forward output values, but review feedback called out that logprob coverage should also prove gradient propagation and batch invariance. +- The new gradient test compares the op gradient against a direct PyTorch `log_softmax + gather` reference under a non-unit upstream gradient. +- Batch invariance was already covered by `TestNativeLogpOpBatchInvariance` in `tests/test_logp.py`, so no duplicate batch-invariance test was added. + +### GTest Backward Check Support + +Files: + +```text +rl_engine/kernels/gtest/op_checks.py +rl_engine/kernels/gtest/operator_specs.py +scripts/check_operator.py +tests/test_op_checks.py +docs/contributing/issue-108-session-log.md +``` + +Change: + +- Added `OperatorCase.grad_input_names` to declare which inputs should be checked for gradients. +- Added `run_operator_suite(..., check_grad=True)`. +- Added `_run_case_backward()` to compare candidate forward outputs and selected input gradients against the PyTorch gold path. +- Added `OperatorSpec.grad_input_names`; `logp` declares `("logits",)`. +- Added `scripts/check_operator.py --check-grad`. +- Added `--grad-mode ones|random` and `--grad-seed` for backward checks. + +Reasoning: + +- Forward-only checks can miss incorrect or disconnected backward paths. +- Gradient inputs must be declared per operator because not every floating tensor should receive gradients. +- Input generation remains independent of autograd; the runner clones inputs and enables `requires_grad` only inside the backward check path. +- `grad_mode=random` is the default and catches backward bugs hidden by all-one upstream gradients. +- `grad_mode=ones` remains available for quick smoke tests and preserves the old `output.sum().backward()` behavior. + +Known backend limitation: + +- `cuda` `logp` currently calls the compiled `_C.fused_logp` forward path directly and does not produce an autograd-connected output. +- Running `--check-grad` against `candidate=cuda` fails with a missing `grad_fn`; this is a backend implementation gap, not a tolerance issue. +- To support `cuda logp` backward, the backend must add or wrap a real backward path, usually via `torch.autograd.Function` or an explicit CUDA backward kernel. + +### Linear LogP Triton GTest Smoke Support + +Files: + +```text +rl_engine/kernels/gtest/operator_inputs.py +rl_engine/kernels/gtest/operator_specs.py +tests/test_operator_inputs.py +docs/contributing/issue-108-session-log.md +``` + +Change: + +- Added `linear_logp` input construction with `hidden`, `lm_head_weight`, `target_ids`, and `bias=None`. +- Added a `linear_logp` operator spec using `NativeLinearLogpOp.apply` as the PyTorch gold path. +- Added `triton` as a `linear_logp` candidate backend. +- Declared `("hidden", "lm_head_weight")` as gradient inputs for backward checks. + +Reasoning: + +- `linear_logp` is the smallest real fused op in the repository with an implemented Triton backward path. +- The first gtest integration keeps bias disabled to avoid optional-gradient handling in the initial smoke path. +- The op reuses the `logprob` tolerance class because it produces selected-token log probabilities. +- It gives the checker a real non-PyTorch differentiable candidate for end-to-end forward/backward reporting. + +Example Triton smoke command: + +```bash +python scripts/check_operator.py --op linear_logp --candidate triton --device cuda --dtype bf16 --batch 1 --seq 2 --vocab 1024 --normalized-dim 4096 --check-grad --grad-mode random --grad-seed 123 +``` + +Observed bf16 result on H100: + +```text +suite=linear_logp passed=False pass_rate=0.0000 +candidate=triton-linear_logp backend=triton passed=False pass_rate=0.6667 + case=linear_logp-torch.bfloat16-1x2x4096x1024 output=0 shape=(1, 2) dtype=torch.float32 max_abs=4.76226807e-01 mean_abs=4.74334717e-01 max_rel=2.96434597e-03 tol=(atol=5.000e-02, rtol=0.000e+00) passed=False + case=linear_logp-torch.bfloat16-1x2x4096x1024 output=1 gradient:hidden shape=(1, 2, 4096) dtype=torch.bfloat16 max_abs=0.00000000e+00 mean_abs=0.00000000e+00 max_rel=0.00000000e+00 tol=(atol=5.000e-02, rtol=0.000e+00) passed=True + case=linear_logp-torch.bfloat16-1x2x4096x1024 output=2 gradient:lm_head_weight shape=(1024, 4096) dtype=torch.bfloat16 max_abs=0.00000000e+00 mean_abs=0.00000000e+00 max_rel=0.00000000e+00 tol=(atol=5.000e-02, rtol=0.000e+00) passed=True +``` + +Interpretation: + +- The checker flow is complete: CLI parsing, input construction, PyTorch gold loading, Triton candidate loading, forward execution, backward execution, gradient collection, comparison, and report formatting all ran successfully. +- The suite is intentionally not marked as passing because the bf16 forward output exceeded the current `logprob` absolute tolerance. +- The current tolerance uses `atol=5.0e-2` and `rtol=0.0`; the observed forward absolute error is about `4.76e-1`, while the relative error is about `2.96e-3`. +- `linear_logp` includes a large bf16 matrix multiply before selected-token logprob, so it likely needs an operator-specific gold policy or tolerance rather than reusing plain `logprob` tolerances unchanged. +- In this case, both checked gradients passed, so the failure is a forward accuracy/tolerance calibration issue, not a failure to execute the backward checker path. diff --git a/docs/operators/fused-logp.md b/docs/operators/fused-logp.md index d5008e5..ccc0f61 100644 --- a/docs/operators/fused-logp.md +++ b/docs/operators/fused-logp.md @@ -13,21 +13,36 @@ logp_op = kernel_registry.get_op("logp") output = logp_op(logits, token_ids) ``` +The PyTorch native reference also exposes the Issue #108 interface: + +```python +from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp + +logp_ref = NativeLogpOp() +output = logp_ref.forward(logits, token_ids) +reference = logp_ref.forward_fp32(logits, token_ids) +``` + +`apply(...)` and `apply_fp32(...)` remain available as backward-compatible aliases. + ## Backends | Backend | Wrapper | Native symbol | Notes | | --- | --- | --- | --- | | CUDA SM90 | `FusedLogpSM90Op` | `_C.fused_logp_sm90` | TMA-oriented path for Hopper-class GPUs. | | CUDA generic | `FusedLogpGenericOp` | `_C.fused_logp` | Generic compiled extension fallback. | -| PyTorch native | `NativeOp` | None | Baseline fallback path. | +| PyTorch native | `NativeLogpOp` | None | PyTorch baseline/reference path. | ## Tensor Contract | Argument | Shape | Dtype | Requirements | | --- | --- | --- | --- | -| `logits` | `[N, V]` | `bfloat16` for SM90 path | Contiguous, on the target device. | -| `token_ids` / `labels` | `[N]` | Converted to `int32` | Same logical device as `logits`. | -| Output | `[N]` | Backend-defined tensor dtype | One selected log probability per row. | +| `logits` | `[..., V]` | Floating point | Contiguous for fused CUDA paths; arbitrary leading dimensions. | +| `token_ids` / `labels` | `[...]` | Integer | Must match `logits.shape[:-1]`. | +| Output | `[...]` | See below | One selected log probability per row. | + +For `NativeLogpOp`, `forward(...)` returns the input dtype and `forward_fp32(...)` +returns `torch.float32`. ## Reference Semantics @@ -39,16 +54,20 @@ ref = torch.gather(ref, dim=-1, index=token_ids.unsqueeze(-1).long()).squeeze(-1 ## Tests ```bash -python tests/test_op_accuracy.py +python -m pytest tests/test_logp.py -q +python -m pytest tests/test_op_accuracy.py -q ``` -The current accuracy test compares the dispatched operator with a PyTorch reference and -uses a dtype-dependent threshold. +`tests/test_logp.py` covers the PyTorch reference contract, dtype behavior, +backward-compatible aliases, batch invariance, and registry dispatch. The existing +operator accuracy tests continue to validate native/CUDA fused API compatibility. ## Implementation Files - `rl_engine/kernels/registry.py` -- `rl_engine/kernels/ops/cuda.py` +- `rl_engine/kernels/ops/pytorch/loss/logp.py` +- `rl_engine/kernels/ops/cuda/loss/logp.py` - `csrc/ops.cpp` - `csrc/fused_logp_kernel.cu` - `csrc/cuda/fused_logp_sm90.cu` +- `tests/test_logp.py` diff --git a/rl_engine/kernels/gtest/__init__.py b/rl_engine/kernels/gtest/__init__.py new file mode 100644 index 0000000..c3fc366 --- /dev/null +++ b/rl_engine/kernels/gtest/__init__.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from .op_checks import CandidateSpec, OperatorCase, run_operator_suite + +__all__ = [ + "CandidateSpec", + "OperatorCase", + "run_operator_suite", +] diff --git a/rl_engine/kernels/gtest/op_checks.py b/rl_engine/kernels/gtest/op_checks.py new file mode 100644 index 0000000..cbd0fd6 --- /dev/null +++ b/rl_engine/kernels/gtest/op_checks.py @@ -0,0 +1,496 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +from collections.abc import Callable, Mapping, Sequence +from dataclasses import asdict, dataclass +from typing import Any + +import torch + +from rl_engine.kernels.gtest.tolerance import load_contract + + +@dataclass(frozen=True) +class OperatorCase: + """One deterministic test object for an operator candidate.""" + + name: str + op_class: str + dtype: torch.dtype + inputs: Mapping[str, Any] + gold_fn: Callable[..., Any] + grad_input_names: tuple[str, ...] = () + + +@dataclass(frozen=True) +class CandidateSpec: + """One implementation to validate against the gold path.""" + + name: str + fn: Callable[..., Any] | Any + backend: str = "unknown" + arch_key: str | None = None + + +@dataclass(frozen=True) +class OutputCheck: + """Per-output comparison result.""" + + output_index: int + shape: tuple[int, ...] + candidate_dtype: str + gold_dtype: str + atol: float + rtol: float + max_abs_error: float + mean_abs_error: float + max_rel_error: float + passed: bool + message: str = "" + + +@dataclass(frozen=True) +class CaseCheck: + """Per-case result for one candidate.""" + + case_name: str + dtype: str + op_class: str + passed: bool + outputs: list[OutputCheck] + + +@dataclass(frozen=True) +class CandidateReport: + """Aggregate report for one candidate implementation.""" + + candidate_name: str + backend: str + total_outputs: int + passed_outputs: int + pass_rate: float + passed: bool + cases: list[CaseCheck] + + +@dataclass(frozen=True) +class OperatorCheckReport: + """Suite-level report across candidates.""" + + suite_name: str + total_candidates: int + passed_candidates: int + pass_rate: float + passed: bool + candidates: list[CandidateReport] + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + +def run_operator_suite( + suite_name: str, + *, + candidates: Sequence[CandidateSpec], + cases: Sequence[OperatorCase], + contract: Mapping[str, Any] | None = None, + check_grad: bool = False, + grad_mode: str = "random", + grad_seed: int = 123, +) -> OperatorCheckReport: + """Run candidates against gold outputs and return a structured report.""" + + loaded_contract = dict(contract or load_contract()) + # run all test ops + # cases : test object + # camdidate : test instance + # loaded_contract : tolerance table + candidate_reports = [ + _run_candidate( + candidate, + cases, + loaded_contract, + check_grad=check_grad, + grad_mode=grad_mode, + grad_seed=grad_seed, + ) + for candidate in candidates + ] + passed_candidates = sum(1 for report in candidate_reports if report.passed) + total_candidates = len(candidate_reports) + pass_rate = float(passed_candidates / total_candidates) if total_candidates else 0.0 + return OperatorCheckReport( + suite_name=suite_name, + total_candidates=total_candidates, + passed_candidates=passed_candidates, + pass_rate=pass_rate, + passed=passed_candidates == total_candidates, + candidates=candidate_reports, + ) + + +def _run_candidate( + candidate: CandidateSpec, + cases: Sequence[OperatorCase], + contract: Mapping[str, Any], + *, + check_grad: bool, + grad_mode: str, + grad_seed: int, +) -> CandidateReport: + if check_grad: + case_checks = [ + _run_case_backward( + candidate, + case, + contract, + grad_mode=grad_mode, + grad_seed=grad_seed, + ) + for case in cases + ] + else: + case_checks = [_run_case(candidate, case, contract) for case in cases] + total_outputs = sum(len(case.outputs) for case in case_checks) + passed_outputs = sum( + 1 for case in case_checks for output in case.outputs if output.passed + ) + pass_rate = float(passed_outputs / total_outputs) if total_outputs else 0.0 + return CandidateReport( + candidate_name=candidate.name, + backend=candidate.backend, + total_outputs=total_outputs, + passed_outputs=passed_outputs, + pass_rate=pass_rate, + passed=passed_outputs == total_outputs, + cases=case_checks, + ) + + +def _run_case( + candidate: CandidateSpec, + case: OperatorCase, + contract: Mapping[str, Any], +) -> CaseCheck: + candidate_outputs = _flatten_tensors(_call_candidate(candidate.fn, case.inputs)) + gold_outputs = _flatten_tensors(case.gold_fn(**case.inputs)) + return _compare_case_outputs(candidate, case, contract, candidate_outputs, gold_outputs) + + +def _run_case_backward( + candidate: CandidateSpec, + case: OperatorCase, + contract: Mapping[str, Any], + *, + grad_mode: str, + grad_seed: int, +) -> CaseCheck: + if not case.grad_input_names: + raise ValueError(f"case {case.name!r} does not declare gradient inputs") + + candidate_inputs = _clone_inputs_for_backward(case.inputs, case.grad_input_names) + gold_inputs = _clone_inputs_for_backward(case.inputs, case.grad_input_names) + candidate_outputs = _flatten_tensors(_call_candidate(candidate.fn, candidate_inputs)) + gold_outputs = _flatten_tensors(case.gold_fn(**gold_inputs)) + # Candidate and gold must use the same upstream gradients; otherwise we + # would compare different vector-Jacobian products. + # grad_mode="ones" is the old output.sum().backward() smoke path. + # grad_mode="random" is closer to training, where dL/doutput is non-uniform. + grad_outputs = _make_grad_outputs(candidate_outputs, grad_mode=grad_mode, seed=grad_seed) + candidate_grads = _backward_grads( + candidate_outputs, + candidate_inputs, + case.grad_input_names, + grad_outputs=grad_outputs, + ) + gold_grads = _backward_grads( + gold_outputs, + gold_inputs, + case.grad_input_names, + grad_outputs=_match_grad_outputs(grad_outputs, gold_outputs), + ) + output_checks = _compare_case_outputs( + candidate, + case, + contract, + candidate_outputs, + gold_outputs, + ).outputs + # Reuse the same tolerance class for gradients as for values. This is a + # first conservative default; operator-specific gradient tolerances can be + # split out later if a real backend shows different numerical behavior. + atol, rtol = _resolve_tolerance( + contract, + op_class=case.op_class, + dtype=case.dtype, + arch_key=candidate.arch_key, + ) + grad_checks = [ + _compare_output( + candidate_grad, + gold_grad, + output_index=len(output_checks) + index, + atol=atol, + rtol=rtol, + message=f"gradient:{name}", + ) + for index, (name, candidate_grad, gold_grad) in enumerate( + zip(case.grad_input_names, candidate_grads, gold_grads, strict=True) + ) + ] + checks = [*output_checks, *grad_checks] + return CaseCheck( + case_name=case.name, + dtype=str(case.dtype), + op_class=case.op_class, + passed=all(output.passed for output in checks), + outputs=checks, + ) + + +def _compare_case_outputs( + candidate: CandidateSpec, + case: OperatorCase, + contract: Mapping[str, Any], + candidate_outputs: list[torch.Tensor], + gold_outputs: list[torch.Tensor], +) -> CaseCheck: + if len(candidate_outputs) != len(gold_outputs): + raise ValueError( + f"candidate {candidate.name!r} returned {len(candidate_outputs)} outputs, " + f"gold returned {len(gold_outputs)}" + ) + atol, rtol = _resolve_tolerance( + contract, + op_class=case.op_class, + dtype=case.dtype, + arch_key=candidate.arch_key, + ) + output_checks = [ + _compare_output( + candidate_output, + gold_output, + output_index=index, + atol=atol, + rtol=rtol, + ) + for index, (candidate_output, gold_output) in enumerate( + zip(candidate_outputs, gold_outputs, strict=True) + ) + ] + return CaseCheck( + case_name=case.name, + dtype=str(case.dtype), + op_class=case.op_class, + passed=all(output.passed for output in output_checks), + outputs=output_checks, + ) + + +# compatibility function or forward +def _call_candidate(candidate: Callable[..., Any] | Any, inputs: Mapping[str, Any]) -> Any: + if hasattr(candidate, "forward") and callable(candidate.forward): + return candidate.forward(**inputs) + return candidate(**inputs) + + +def _clone_inputs_for_backward( + inputs: Mapping[str, Any], + grad_input_names: tuple[str, ...], +) -> dict[str, Any]: + grad_names = set(grad_input_names) + cloned: dict[str, Any] = {} + for name, value in inputs.items(): + if isinstance(value, torch.Tensor): + tensor = value.detach().clone() + if name in grad_names: + if not tensor.is_floating_point(): + raise TypeError(f"gradient input {name!r} must be floating point") + tensor.requires_grad_(True) + cloned[name] = tensor + else: + cloned[name] = value + missing = grad_names.difference(cloned) + if missing: + raise ValueError(f"missing gradient inputs: {', '.join(sorted(missing))}") + return cloned + + +def _backward_grads( + outputs: list[torch.Tensor], + inputs: Mapping[str, Any], + grad_input_names: tuple[str, ...], + *, + grad_outputs: list[torch.Tensor], +) -> list[torch.Tensor]: + if len(outputs) != len(grad_outputs): + raise ValueError( + f"got {len(grad_outputs)} upstream gradients for {len(outputs)} outputs" + ) + # `ones` makes this equivalent to output.sum().backward(); `random` tests a + # stricter vector-Jacobian product. + loss = sum( + (output.float() * grad_output.to(device=output.device).float()).sum() + for output, grad_output in zip(outputs, grad_outputs, strict=True) + ) + loss.backward() + grads: list[torch.Tensor] = [] + for name in grad_input_names: + grad = inputs[name].grad + if grad is None: + raise ValueError(f"gradient for input {name!r} is None") + grads.append(grad) + return grads + + +def _make_grad_outputs( + outputs: list[torch.Tensor], + *, + grad_mode: str, + seed: int, +) -> list[torch.Tensor]: + if grad_mode == "ones": + # All-one upstream gradients make the scalar loss equal output.sum(). + return [torch.ones_like(output, dtype=torch.float32) for output in outputs] + if grad_mode != "random": + raise ValueError(f"unsupported grad_mode: {grad_mode}") + + grad_outputs: list[torch.Tensor] = [] + generators: dict[torch.device, torch.Generator] = {} + for output in outputs: + if output.device not in generators: + # Generators are device-local; a CUDA generator cannot draw CPU tensors. + generator = torch.Generator(device=output.device) + generator.manual_seed(seed) + generators[output.device] = generator + # Random upstream gradients test a non-uniform dL/doutput. The same + # tensors are later reused for gold so the comparison stays fair. + grad_outputs.append( + torch.randn( + output.shape, + generator=generators[output.device], + device=output.device, + dtype=torch.float32, + ) + ) + return grad_outputs + + +def _match_grad_outputs( + grad_outputs: list[torch.Tensor], + outputs: list[torch.Tensor], +) -> list[torch.Tensor]: + # Reuse upstream values for gold; only move device when needed. + return [ + grad_output.to(device=output.device) + for grad_output, output in zip(grad_outputs, outputs, strict=True) + ] + + +def _flatten_tensors(value: Any) -> list[torch.Tensor]: + if isinstance(value, torch.Tensor): + return [value] + if isinstance(value, (tuple, list)): + outputs: list[torch.Tensor] = [] + for item in value: + outputs.extend(_flatten_tensors(item)) + return outputs + raise TypeError(f"operator output must be Tensor or sequence, got {type(value)!r}") + + +def _resolve_tolerance( + contract: Mapping[str, Any], + *, + op_class: str, + dtype: torch.dtype, + arch_key: str | None = None, +) -> tuple[float, float]: + dtype_name = _dtype_name(dtype) + if arch_key is not None: + arch_values = ( + contract["accuracy"] + .get("arch_overrides", {}) + .get(arch_key, {}) + .get(op_class, {}) + .get(dtype_name) + ) + if arch_values is not None: + return float(arch_values["atol"]), float(arch_values.get("rtol", 0.0)) + + values = contract["accuracy"]["default"][op_class][dtype_name] + return float(values["atol"]), float(values.get("rtol", 0.0)) + + +def _dtype_name(dtype: torch.dtype) -> str: + if dtype is torch.float32: + return "float32" + if dtype is torch.bfloat16: + return "bfloat16" + if dtype is torch.float16: + return "float16" + raise ValueError(f"unsupported dtype: {dtype}") + + +def _compare_output( + candidate: torch.Tensor, + gold: torch.Tensor, + *, + output_index: int, + atol: float, + rtol: float, + message: str = "", +) -> OutputCheck: + if candidate.shape != gold.shape: + return OutputCheck( + output_index=output_index, + shape=tuple(candidate.shape), + candidate_dtype=str(candidate.dtype), + gold_dtype=str(gold.dtype), + atol=atol, + rtol=rtol, + max_abs_error=float("inf"), + mean_abs_error=float("inf"), + max_rel_error=float("inf"), + passed=False, + message=f"shape mismatch: candidate={tuple(candidate.shape)} gold={tuple(gold.shape)}", + ) + + candidate_fp32 = candidate.float() + gold_fp32 = gold.float() + abs_error = (candidate_fp32 - gold_fp32).abs() + if abs_error.numel() == 0: + max_abs_error = 0.0 + mean_abs_error = 0.0 + max_rel_error = 0.0 + else: + max_abs_error = float(abs_error.max().item()) + mean_abs_error = float(abs_error.mean().item()) + rel_error = abs_error / gold_fp32.abs().clamp_min(1e-12) + max_rel_error = float(rel_error.max().item()) + + return OutputCheck( + output_index=output_index, + shape=tuple(candidate.shape), + candidate_dtype=str(candidate.dtype), + gold_dtype=str(gold.dtype), + atol=atol, + rtol=rtol, + max_abs_error=max_abs_error, + mean_abs_error=mean_abs_error, + max_rel_error=max_rel_error, + passed=bool(torch.allclose(candidate_fp32, gold_fp32, atol=atol, rtol=rtol)), + message=message, + ) + + +__all__ = [ + "CandidateReport", + "CandidateSpec", + "CaseCheck", + "OperatorCase", + "OperatorCheckReport", + "OutputCheck", + "run_operator_suite", +] diff --git a/rl_engine/kernels/gtest/operator_inputs.py b/rl_engine/kernels/gtest/operator_inputs.py new file mode 100644 index 0000000..4d131e2 --- /dev/null +++ b/rl_engine/kernels/gtest/operator_inputs.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import argparse +from typing import Any + +import torch + + +DEFAULT_HIDDEN = 4096 +DEFAULT_N_HEADS = 32 +DEFAULT_N_KV_HEADS = 8 +DEFAULT_HEAD_DIM = 128 +DEFAULT_INTERMEDIATE = 12288 +DEFAULT_VOCAB = 151936 +DEFAULT_ROPE_THETA = 1.0e6 +DEFAULT_RMS_EPS = 1.0e-6 + + +def make_operator_inputs( + op_name: str, + args: argparse.Namespace, + dtype: torch.dtype, + device: torch.device, +) -> dict[str, Any]: + builders = { + "rms_norm": _make_rms_norm_inputs, + "matmul": _make_matmul_inputs, + "attention": _make_attention_inputs, + "logp": _make_logp_inputs, + "linear_logp": _make_linear_logp_inputs, + "rope": _make_rope_inputs, + "silu": _make_silu_inputs, + "swiglu": _make_swiglu_inputs, + "embedding": _make_embedding_inputs, + "lm_head": _make_lm_head_inputs, + "kv_cache_attention": _make_kv_cache_attention_inputs, + } + try: + return builders[op_name](args, dtype, device) + except KeyError as exc: + raise ValueError(f"unsupported operator inputs: {op_name}") from exc + + +def operator_shape_name(op_name: str, args: argparse.Namespace) -> str: + batch, seq = _batch_seq(args) + vocab = _arg_int(args, "vocab", DEFAULT_VOCAB) + names = { + "rms_norm": f"{batch}x{seq}x{_normalized_dim(args)}", + "matmul": f"{batch}x{seq}x{_matmul_k(args)}x{_matmul_n(args)}", + "attention": f"{batch}x{DEFAULT_N_HEADS}x{seq}x{DEFAULT_HEAD_DIM}", + "logp": f"{batch}x{seq}x{vocab}", + "linear_logp": f"{batch}x{seq}x{_normalized_dim(args)}x{vocab}", + "rope": f"{batch}x{DEFAULT_N_HEADS}x{seq}x{DEFAULT_HEAD_DIM}", + "silu": f"{batch}x{seq}x{DEFAULT_INTERMEDIATE}", + "swiglu": f"{batch}x{seq}x{DEFAULT_INTERMEDIATE}", + "embedding": f"{batch}x{seq}x{vocab}x{DEFAULT_HIDDEN}", + "lm_head": f"{batch}x{seq}x{vocab}", + "kv_cache_attention": f"{batch}x{DEFAULT_N_HEADS}x1x{seq + 1}x{DEFAULT_HEAD_DIM}", + } + try: + return names[op_name] + except KeyError as exc: + raise ValueError(f"unsupported operator shape: {op_name}") from exc + + +def _make_rms_norm_inputs( + args: argparse.Namespace, dtype: torch.dtype, device: torch.device +) -> dict[str, Any]: + batch, seq = _batch_seq(args) + normalized_dim = _normalized_dim(args) + return { + "x": _floating_tensor((batch, seq, normalized_dim), args, dtype, device, offset=0), + "weight": _floating_tensor((normalized_dim,), args, dtype, device, offset=1), + "eps": _arg_float(args, "eps", DEFAULT_RMS_EPS), + } + + +def _make_matmul_inputs( + args: argparse.Namespace, dtype: torch.dtype, device: torch.device +) -> dict[str, Any]: + batch, seq = _batch_seq(args) + k_dim = _matmul_k(args) + n_dim = _matmul_n(args) + return { + "a": _floating_tensor((batch, seq, k_dim), args, dtype, device, offset=0), + "b": _floating_tensor((k_dim, n_dim), args, dtype, device, offset=1), + } + + +def _make_attention_inputs( + args: argparse.Namespace, dtype: torch.dtype, device: torch.device +) -> dict[str, Any]: + batch, seq = _batch_seq(args) + return { + "q": _floating_tensor((batch, DEFAULT_N_HEADS, seq, DEFAULT_HEAD_DIM), args, dtype, device, 0), + "k": _floating_tensor((batch, DEFAULT_N_KV_HEADS, seq, DEFAULT_HEAD_DIM), args, dtype, device, 1), + "v": _floating_tensor((batch, DEFAULT_N_KV_HEADS, seq, DEFAULT_HEAD_DIM), args, dtype, device, 2), + "causal": True, + } + + +def _make_logp_inputs( + args: argparse.Namespace, dtype: torch.dtype, device: torch.device +) -> dict[str, Any]: + batch, seq = _batch_seq(args) + vocab = _arg_int(args, "vocab", DEFAULT_VOCAB) + return { + "logits": _floating_tensor((batch, seq, vocab), args, dtype, device, offset=0), + "token_ids": _token_ids((batch, seq), vocab, args, device), + } + + +def _make_linear_logp_inputs( + args: argparse.Namespace, dtype: torch.dtype, device: torch.device +) -> dict[str, Any]: + batch, seq = _batch_seq(args) + hidden_dim = _normalized_dim(args) + vocab = _arg_int(args, "vocab", DEFAULT_VOCAB) + return { + "hidden": _floating_tensor((batch, seq, hidden_dim), args, dtype, device, offset=0), + "lm_head_weight": _floating_tensor((vocab, hidden_dim), args, dtype, device, offset=1), + "target_ids": _token_ids((batch, seq), vocab, args, device), + "bias": None, + } + + +def _make_rope_inputs( + args: argparse.Namespace, dtype: torch.dtype, device: torch.device +) -> dict[str, Any]: + batch, seq = _batch_seq(args) + return { + "x": _floating_tensor((batch, DEFAULT_N_HEADS, seq, DEFAULT_HEAD_DIM), args, dtype, device, 0), + "positions": torch.arange(seq, device=device, dtype=torch.long), + "theta": _arg_float(args, "theta", DEFAULT_ROPE_THETA), + } + + +def _make_silu_inputs( + args: argparse.Namespace, dtype: torch.dtype, device: torch.device +) -> dict[str, Any]: + batch, seq = _batch_seq(args) + return { + "x": _floating_tensor((batch, seq, DEFAULT_INTERMEDIATE), args, dtype, device, 0), + } + + +def _make_swiglu_inputs( + args: argparse.Namespace, dtype: torch.dtype, device: torch.device +) -> dict[str, Any]: + batch, seq = _batch_seq(args) + return { + "gate": _floating_tensor((batch, seq, DEFAULT_INTERMEDIATE), args, dtype, device, 0), + "up": _floating_tensor((batch, seq, DEFAULT_INTERMEDIATE), args, dtype, device, 1), + } + + +def _make_embedding_inputs( + args: argparse.Namespace, dtype: torch.dtype, device: torch.device +) -> dict[str, Any]: + batch, seq = _batch_seq(args) + vocab = _arg_int(args, "vocab", DEFAULT_VOCAB) + return { + "token_ids": _token_ids((batch, seq), vocab, args, device), + "weight": _floating_tensor((vocab, DEFAULT_HIDDEN), args, dtype, device, 0), + } + + +def _make_lm_head_inputs( + args: argparse.Namespace, dtype: torch.dtype, device: torch.device +) -> dict[str, Any]: + batch, seq = _batch_seq(args) + vocab = _arg_int(args, "vocab", DEFAULT_VOCAB) + return { + "hidden": _floating_tensor((batch, seq, DEFAULT_HIDDEN), args, dtype, device, 0), + "weight": _floating_tensor((vocab, DEFAULT_HIDDEN), args, dtype, device, 1), + "bias": None, + } + + +def _make_kv_cache_attention_inputs( + args: argparse.Namespace, dtype: torch.dtype, device: torch.device +) -> dict[str, Any]: + batch, seq = _batch_seq(args) + return { + "q": _floating_tensor((batch, DEFAULT_N_HEADS, 1, DEFAULT_HEAD_DIM), args, dtype, device, 0), + "k_cache": _floating_tensor((batch, DEFAULT_N_KV_HEADS, seq, DEFAULT_HEAD_DIM), args, dtype, device, 1), + "v_cache": _floating_tensor((batch, DEFAULT_N_KV_HEADS, seq, DEFAULT_HEAD_DIM), args, dtype, device, 2), + "k_new": _floating_tensor((batch, DEFAULT_N_KV_HEADS, 1, DEFAULT_HEAD_DIM), args, dtype, device, 3), + "v_new": _floating_tensor((batch, DEFAULT_N_KV_HEADS, 1, DEFAULT_HEAD_DIM), args, dtype, device, 4), + "causal": True, + } + + +def _floating_tensor( + shape: tuple[int, ...], + args: argparse.Namespace, + dtype: torch.dtype, + device: torch.device, + offset: int, +) -> torch.Tensor: + # Example: torch.randn((B, S, V), device="cuda", dtype=torch.bfloat16) + mode = _arg_str(args, "input_mode", "random") + if mode == "constant": + value = _arg_float(args, "constant_value", 0.25) + float(offset) * 0.01 + return torch.full(shape, value, device=device, dtype=dtype) + if mode != "random": + raise ValueError(f"unsupported input_mode: {mode}") + generator = _generator(args, device, offset) + return torch.randn(shape, generator=generator, device=device, dtype=dtype) + + +def _token_ids( + shape: tuple[int, ...], + vocab: int, + args: argparse.Namespace, + device: torch.device, +) -> torch.Tensor: + mode = _arg_str(args, "input_mode", "random") + if mode == "constant": + value = _arg_int(args, "token_value", 0) % vocab + return torch.full(shape, value, device=device, dtype=torch.long) + generator = _generator(args, device, offset=13) + return torch.randint(0, vocab, shape, generator=generator, device=device, dtype=torch.long) + + +def _generator(args: argparse.Namespace, device: torch.device, offset: int) -> torch.Generator: + generator = torch.Generator(device=device) + generator.manual_seed(_arg_int(args, "seed", 123) + offset) + return generator + + +def _batch_seq(args: argparse.Namespace) -> tuple[int, int]: + return _arg_int(args, "batch", 2), _arg_int(args, "seq", 16) + + +def _normalized_dim(args: argparse.Namespace) -> int: + return _arg_int(args, "normalized_dim", DEFAULT_HIDDEN) + + +def _matmul_k(args: argparse.Namespace) -> int: + return _arg_int(args, "k_dim", DEFAULT_HIDDEN) + + +def _matmul_n(args: argparse.Namespace) -> int: + return _arg_int(args, "n_dim", DEFAULT_HIDDEN) + + +def _arg_float(args: argparse.Namespace, name: str, default: float) -> float: + return float(getattr(args, name, default)) + + +def _arg_int(args: argparse.Namespace, name: str, default: int) -> int: + return int(getattr(args, name, default)) + + +def _arg_str(args: argparse.Namespace, name: str, default: str) -> str: + return str(getattr(args, name, default)) diff --git a/rl_engine/kernels/gtest/operator_specs.py b/rl_engine/kernels/gtest/operator_specs.py new file mode 100644 index 0000000..99c45ec --- /dev/null +++ b/rl_engine/kernels/gtest/operator_specs.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import argparse +import importlib +from dataclasses import dataclass +from typing import Any + +import torch + +from rl_engine.kernels.gtest.operator_inputs import make_operator_inputs, operator_shape_name +from rl_engine.kernels.gtest.op_checks import CandidateSpec, OperatorCase + + +@dataclass(frozen=True) +class OperatorSpec: + name: str + op_class: str + gold_path: str + gold_method: str + candidate_paths: dict[str, str] + grad_input_names: tuple[str, ...] = () + + +def _load_object(path: str) -> Any: + module_path, object_name = path.rsplit(".", 1) + # dynamic loading ops + module = importlib.import_module(module_path) + return getattr(module, object_name) + + +OP_SPECS = { + "logp": OperatorSpec( + name="logp", + op_class="logprob", + gold_path="rl_engine.kernels.ops.pytorch.loss.logp.NativeLogpOp", + gold_method="forward_fp32", + candidate_paths={ + "pytorch": "rl_engine.kernels.ops.pytorch.loss.logp.NativeLogpOp", + "cuda": "rl_engine.kernels.ops.cuda.loss.logp.FusedLogpGenericOp", + "cuda-generic": "rl_engine.kernels.ops.cuda.loss.logp.FusedLogpGenericOp", + "cuda-sm90": "rl_engine.kernels.ops.cuda.loss.logp.FusedLogpSM90Op", + }, + grad_input_names=("logits",), + ), + "linear_logp": OperatorSpec( + name="linear_logp", + op_class="logprob", + gold_path="rl_engine.kernels.ops.pytorch.loss.linear_logp.NativeLinearLogpOp", + gold_method="apply", + candidate_paths={ + "pytorch": "rl_engine.kernels.ops.pytorch.loss.linear_logp.NativeLinearLogpOp", + "triton": "rl_engine.kernels.ops.triton.loss.linear_logp.TritonLinearLogpOp", + "cuda-sm90": "rl_engine.kernels.ops.cuda.loss.linear_logp.FusedLinearLogpSM90Op", + }, + grad_input_names=("hidden", "lm_head_weight"), + ), +} + + +class _LogpSM90CandidateAdapter: + def __init__(self, candidate: Any) -> None: + self._candidate = candidate + + def __call__(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + orig_shape = logits.shape[:-1] + logits_2d = logits.contiguous().view(-1, logits.size(-1)) + labels_1d = token_ids.contiguous().view(-1) + return self._candidate(logits_2d, labels_1d).view(orig_shape) + + +def operator_names() -> tuple[str, ...]: + return tuple(OP_SPECS) + + +def make_operator_case( + args: argparse.Namespace, dtype: torch.dtype, device: torch.device +) -> OperatorCase: + spec = OP_SPECS[args.op] + gold_op = _load_object(spec.gold_path)() + gold_fn = getattr(gold_op, spec.gold_method) + return OperatorCase( + name=f"{args.op}-{dtype}-{operator_shape_name(args.op, args)}", + op_class=spec.op_class, + dtype=dtype, + inputs=make_operator_inputs(args.op, args, dtype, device), + gold_fn=gold_fn, + grad_input_names=spec.grad_input_names, + ) + + +def make_candidate(args: argparse.Namespace) -> CandidateSpec: + spec = OP_SPECS[args.op] + candidate_name = "pytorch" if args.candidate == "native" else args.candidate + + if candidate_name in spec.candidate_paths: + candidate_op = _load_object(spec.candidate_paths[candidate_name])() + if args.op == "logp" and candidate_name == "cuda-sm90": + candidate_op = _LogpSM90CandidateAdapter(candidate_op) + return CandidateSpec( + name=f"{candidate_name}-{args.op}", + backend=candidate_name, + arch_key=args.arch_key, + fn=candidate_op, + ) + + supported = sorted([*spec.candidate_paths, "native"]) + raise ValueError( + f"unsupported candidate {args.candidate!r} for op {args.op!r}; " + f"supported candidates: {', '.join(supported)}" + ) diff --git a/rl_engine/kernels/gtest/tolerance.py b/rl_engine/kernels/gtest/tolerance.py new file mode 100644 index 0000000..3265a45 --- /dev/null +++ b/rl_engine/kernels/gtest/tolerance.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + + +_CONTRACT_PATH = Path(__file__).with_name("tolerance_contract.json") + + +def load_contract(path: str | Path = _CONTRACT_PATH) -> dict[str, Any]: + """Load the dtype/operator-class tolerance contract.""" + + with Path(path).open("r", encoding="utf-8") as handle: + return json.load(handle) + + +__all__ = ["load_contract"] diff --git a/rl_engine/kernels/gtest/tolerance_contract.json b/rl_engine/kernels/gtest/tolerance_contract.json new file mode 100644 index 0000000..7a2ce2b --- /dev/null +++ b/rl_engine/kernels/gtest/tolerance_contract.json @@ -0,0 +1,25 @@ +{ + "batch_invariance": {"atol": 0.0, "rtol": 0.0}, + "accuracy": { + "default": { + "elementwise": { + "float32": {"atol": 1.0e-5, "rtol": 1.0e-5}, + "bfloat16": {"atol": 2.0e-2, "rtol": 1.6e-2}, + "float16": {"atol": 1.0e-3, "rtol": 1.0e-3} + }, + "reduction": { + "float32": {"atol": 1.0e-4, "rtol": 1.0e-4}, + "bfloat16": {"atol": 5.0e-2, "rtol": 2.0e-2}, + "float16": {"atol": 1.0e-3, "rtol": 1.0e-3} + }, + "logprob": { + "float32": {"atol": 1.0e-5, "rtol": 0.0}, + "bfloat16": {"atol": 5.0e-2, "rtol": 0.0}, + "float16": {"atol": 5.0e-3, "rtol": 0.0} + } + }, + "arch_overrides": { + "sm90": {} + } + } +} diff --git a/rl_engine/kernels/ops/pytorch/loss/logp.py b/rl_engine/kernels/ops/pytorch/loss/logp.py index c791927..4ef1f24 100644 --- a/rl_engine/kernels/ops/pytorch/loss/logp.py +++ b/rl_engine/kernels/ops/pytorch/loss/logp.py @@ -1,17 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2026 RL-Kernel Contributors +from __future__ import annotations + import torch class NativeLogpOp: """Pure PyTorch native fallback for Fused LogP.""" - def __init__(self): + op_class = "logprob" + + def __init__(self) -> None: pass def __call__(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: - return self.apply(logits, token_ids) + return self.forward(logits, token_ids) def _selected_logps( self, @@ -45,14 +49,22 @@ def _validate_output_shape(self, output: torch.Tensor, logits: torch.Tensor) -> f"{tuple(logits.shape[:-1])}" ) - def apply(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + def forward(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: """Baseline selected-token log probability extraction using torch.gather.""" return self._selected_logps(logits, token_ids, output_dtype=logits.dtype) - def apply_fp32(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + def forward_fp32(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: """Same as apply but forces float32 output for numerical stability.""" return self._selected_logps(logits, token_ids, output_dtype=torch.float32) + def apply(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + """Backward-compatible alias for forward.""" + return self.forward(logits, token_ids) + + def apply_fp32(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + """Backward-compatible alias for forward_fp32.""" + return self.forward_fp32(logits, token_ids) + def indexed_out( self, logits: torch.Tensor, diff --git a/scripts/check_operator.py b/scripts/check_operator.py new file mode 100644 index 0000000..677f01b --- /dev/null +++ b/scripts/check_operator.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import argparse +import json +import pathlib +import sys +from typing import Any + +import torch + +REPO_ROOT = pathlib.Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from rl_engine.kernels.gtest import run_operator_suite # noqa: E402 +from rl_engine.kernels.gtest.operator_specs import ( # noqa: E402 + make_candidate, + make_operator_case, + operator_names, +) + + +def _parse_dtype(value: str) -> torch.dtype: + normalized = value.lower() + if normalized in {"fp32", "float32"}: + return torch.float32 + if normalized in {"bf16", "bfloat16"}: + return torch.bfloat16 + if normalized in {"fp16", "float16", "half"}: + return torch.float16 + raise ValueError(f"unsupported dtype: {value}") + + +def _select_device(value: str) -> torch.device: + if value == "auto": + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device(value) + if device.type == "cuda" and not torch.cuda.is_available(): + raise RuntimeError("--device cuda was requested, but CUDA is not available") + return device + + +def _summarize(report: Any) -> None: + print(f"suite={report.suite_name} passed={report.passed} pass_rate={report.pass_rate:.4f}") + for candidate in report.candidates: + print( + f"candidate={candidate.candidate_name} backend={candidate.backend} " + f"passed={candidate.passed} pass_rate={candidate.pass_rate:.4f}" + ) + for case in candidate.cases: + for output in case.outputs: + label = f" {output.message}" if output.message else "" + print( + f" case={case.case_name} output={output.output_index}{label} " + f"shape={output.shape} dtype={output.candidate_dtype} " + f"max_abs={output.max_abs_error:.8e} " + f"mean_abs={output.mean_abs_error:.8e} " + f"max_rel={output.max_rel_error:.8e} " + f"tol=(atol={output.atol:.3e}, rtol={output.rtol:.3e}) " + f"passed={output.passed}" + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Validate an operator candidate against a PyTorch gold path.") + parser.add_argument("--op", choices=operator_names(), default="logp") + parser.add_argument( + "--candidate", + default="pytorch", + help="Candidate backend to validate, for example pytorch, cuda, cuda-sm90, triton.", + ) + parser.add_argument("--dtype", choices=("fp32", "bf16", "fp16"), default="fp32") + parser.add_argument("--device", default="auto") + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--seq", type=int, default=16) + parser.add_argument("--vocab", type=int, default=257) + parser.add_argument("--input-mode", choices=("random", "constant"), default="random") + parser.add_argument("--constant-value", type=float, default=0.25) + parser.add_argument("--token-value", type=int, default=0) + parser.add_argument("--normalized-dim", type=int, default=4096) + parser.add_argument("--k-dim", type=int, default=4096) + parser.add_argument("--n-dim", type=int, default=4096) + parser.add_argument("--theta", type=float, default=1.0e6) + parser.add_argument("--eps", type=float, default=1.0e-6) + parser.add_argument("--seed", type=int, default=123) + parser.add_argument( + "--arch-key", + default=None, + help="Optional tolerance override key, for example sm90. Defaults to contract.default.", + ) + parser.add_argument("--check-grad", action="store_true", help="Also compare gradients for supported inputs.") + # Defaults to random because it catches bugs hidden by output.sum().backward(). + parser.add_argument( + "--grad-mode", + choices=("ones", "random"), + default="random", + help="Upstream gradient mode used with --check-grad.", + ) + parser.add_argument("--grad-seed", type=int, default=123, help="Seed for --grad-mode random.") + parser.add_argument("--json", action="store_true", help="Print the full structured report as JSON.") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + dtype = _parse_dtype(args.dtype) + device = _select_device(args.device) + candidate = make_candidate(args) + case = make_operator_case(args, dtype, device) + report = run_operator_suite( + args.op, + candidates=[candidate], + cases=[case], + check_grad=args.check_grad, + grad_mode=args.grad_mode, + grad_seed=args.grad_seed, + ) + + if args.json: + print(json.dumps(report.to_dict(), indent=2, default=str)) + else: + _summarize(report) + + if not report.passed: + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/test_logp.py b/tests/test_logp.py new file mode 100644 index 0000000..bcb0e3f --- /dev/null +++ b/tests/test_logp.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Tests for NativeLogpOp, the PyTorch selected-logprob reference.""" + +from __future__ import annotations + +import pytest +import torch + +from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp + + +def _make_inputs( + batch: int, + seq: int, + vocab: int, + *, + dtype: torch.dtype = torch.float32, + seed: int = 123, +) -> tuple[torch.Tensor, torch.Tensor]: + gen = torch.Generator().manual_seed(seed) + logits = torch.randn(batch, seq, vocab, generator=gen, dtype=dtype) + token_ids = torch.randint(0, vocab, (batch, seq), generator=gen, dtype=torch.long) + return logits, token_ids + + +def _reference_selected_logp(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + log_probs = torch.log_softmax(logits.float(), dim=-1) + return torch.gather(log_probs, dim=-1, index=token_ids.long().unsqueeze(-1)).squeeze(-1) + + +class TestNativeLogpOpCorrectness: + def test_output_shape_matches_token_ids(self): + op = NativeLogpOp() + logits, token_ids = _make_inputs(2, 16, 257) + out = op.forward_fp32(logits, token_ids) + assert out.shape == token_ids.shape + + def test_forward_fp32_returns_fp32(self): + op = NativeLogpOp() + logits, token_ids = _make_inputs(2, 16, 257, dtype=torch.bfloat16) + out = op.forward_fp32(logits, token_ids) + assert out.dtype == torch.float32 + + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) + def test_forward_returns_input_dtype(self, dtype): + op = NativeLogpOp() + logits, token_ids = _make_inputs(2, 16, 257, dtype=dtype) + out = op.forward(logits, token_ids) + assert out.dtype == dtype + + def test_call_and_apply_alias_forward(self): + op = NativeLogpOp() + logits, token_ids = _make_inputs(2, 16, 257) + forward = op.forward(logits, token_ids) + assert torch.equal(op(logits, token_ids), forward) + assert torch.equal(op.apply(logits, token_ids), forward) + + def test_apply_fp32_alias_forward_fp32(self): + op = NativeLogpOp() + logits, token_ids = _make_inputs(2, 16, 257) + assert torch.equal(op.apply_fp32(logits, token_ids), op.forward_fp32(logits, token_ids)) + + def test_matches_fp32_reference_bitwise(self): + op = NativeLogpOp() + logits, token_ids = _make_inputs(2, 16, 257) + out = op.forward_fp32(logits, token_ids) + ref = _reference_selected_logp(logits, token_ids) + assert torch.equal(out, ref) + + def test_pure_function_no_inplace(self): + op = NativeLogpOp() + logits, token_ids = _make_inputs(2, 16, 257) + logits_orig = logits.clone() + token_ids_orig = token_ids.clone() + _ = op.forward_fp32(logits, token_ids) + assert torch.equal(logits, logits_orig) + assert torch.equal(token_ids, token_ids_orig) + + def test_forward_fp32_gradient_matches_reference(self): + gen = torch.Generator().manual_seed(654) + logits = torch.randn(2, 4, 17, generator=gen, requires_grad=True) + ref_logits = logits.detach().clone().requires_grad_(True) + token_ids = torch.randint(0, logits.size(-1), (2, 4), generator=gen) + upstream = torch.randn(2, 4, generator=gen) + + (NativeLogpOp().forward_fp32(logits, token_ids) * upstream).sum().backward() + (_reference_selected_logp(ref_logits, token_ids) * upstream).sum().backward() + + assert logits.grad is not None + assert ref_logits.grad is not None + assert torch.allclose(logits.grad, ref_logits.grad, atol=1e-6, rtol=1e-6) + + def test_op_class_is_logprob(self): + assert NativeLogpOp.op_class == "logprob" + + def test_rejects_mismatched_shapes(self): + op = NativeLogpOp() + logits = torch.randn(2, 3, 5) + token_ids = torch.randint(0, 5, (2, 4)) + with pytest.raises(ValueError, match="must match"): + op.forward_fp32(logits, token_ids) + + +class TestNativeLogpOpBatchInvariance: + def test_batch1_vs_batchN_bitwise(self): + op = NativeLogpOp() + logits, token_ids = _make_inputs(4, 16, 257, seed=321) + full_out = op.forward_fp32(logits, token_ids) + for row in range(logits.shape[0]): + single_out = op.forward_fp32(logits[row : row + 1], token_ids[row : row + 1]) + assert torch.equal( + full_out[row], single_out[0] + ), f"Batch invariance broken at row {row}" + + def test_batch_invariance_with_padding(self): + op = NativeLogpOp() + logits_valid, token_ids_valid = _make_inputs(2, 16, 257, seed=456) + gen = torch.Generator().manual_seed(789) + logits_padding = torch.randn(3, 16, 257, generator=gen) + token_padding = torch.randint(0, 257, (3, 16), generator=gen) + logits_padded = torch.cat([logits_valid, logits_padding], dim=0) + token_ids_padded = torch.cat([token_ids_valid, token_padding], dim=0) + + out_valid = op.forward_fp32(logits_valid, token_ids_valid) + out_padded = op.forward_fp32(logits_padded, token_ids_padded) + assert torch.equal(out_valid[0], out_padded[0]) + assert torch.equal(out_valid[1], out_padded[1]) + + +class TestNativeLogpOpAccuracy: + @pytest.mark.parametrize( + "dtype, atol", + [ + (torch.float32, 1e-5), + (torch.bfloat16, 2e-2), + (torch.float16, 5e-3), + ], + ) + def test_forward_vs_fp32_within_tolerance(self, dtype, atol): + op = NativeLogpOp() + logits, token_ids = _make_inputs(2, 16, 17, dtype=dtype) + out_typed = op.forward(logits, token_ids).float() + out_fp32 = op.forward_fp32(logits, token_ids) + diff = (out_typed - out_fp32).abs().max().item() + assert torch.allclose( + out_typed, out_fp32, atol=atol, rtol=0.0 + ), f"dtype={dtype}, max_abs_error={diff:.3e} exceeds atol={atol}" + + +class TestNativeLogpOpRegistry: + @pytest.mark.skipif(torch.cuda.is_available(), reason="CUDA dispatch may select fused logp") + def test_registry_returns_logp_op(self): + from rl_engine.kernels.registry import kernel_registry + + op = kernel_registry.get_op("logp") + assert isinstance(op, NativeLogpOp) diff --git a/tests/test_op_checks.py b/tests/test_op_checks.py new file mode 100644 index 0000000..8f08d87 --- /dev/null +++ b/tests/test_op_checks.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import torch + +from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp +from rl_engine.kernels.gtest.op_checks import CandidateSpec, OperatorCase, run_operator_suite + + +def _logp_case(name: str, dtype: torch.dtype, *, seed: int = 0) -> OperatorCase: + generator = torch.Generator().manual_seed(seed) + logits = torch.randn(2, 8, 257, dtype=dtype, generator=generator) + token_ids = torch.randint(0, logits.size(-1), (2, 8), generator=generator) + return OperatorCase( + name=name, + op_class="logprob", + dtype=dtype, + inputs={"logits": logits, "token_ids": token_ids}, + gold_fn=NativeLogpOp().forward_fp32, + ) + + +def _logp_backward_case(name: str, *, seed: int = 0) -> OperatorCase: + case = _logp_case(name, torch.float32, seed=seed) + return OperatorCase( + name=case.name, + op_class=case.op_class, + dtype=case.dtype, + inputs=case.inputs, + gold_fn=case.gold_fn, + grad_input_names=("logits",), + ) + + +def test_logp_native_candidate_suite_passes(): + report = run_operator_suite( + "logp", + candidates=[CandidateSpec(name="native-logp", backend="pytorch", fn=NativeLogpOp())], + cases=[ + _logp_case("fp32", torch.float32, seed=1), + _logp_case("bf16", torch.bfloat16, seed=2), + _logp_case("fp16", torch.float16, seed=3), + ], + ) + + assert report.passed + assert report.pass_rate == 1.0 + assert report.candidates[0].passed_outputs == 3 + assert all(case.passed for case in report.candidates[0].cases) + + +def test_suite_reports_failure_for_bad_candidate(): + def bad_logp(logits, token_ids): + del token_ids + return torch.zeros(logits.shape[:-1], dtype=logits.dtype) + + report = run_operator_suite( + "logp", + candidates=[CandidateSpec(name="bad-logp", backend="test", fn=bad_logp)], + cases=[_logp_case("fp32", torch.float32, seed=5)], + ) + + output = report.candidates[0].cases[0].outputs[0] + assert not report.passed + assert report.pass_rate == 0.0 + assert output.max_abs_error > 0.0 + + +def test_suite_report_to_dict_contains_error_metrics(): + report = run_operator_suite( + "logp", + candidates=[CandidateSpec(name="native-logp", backend="pytorch", fn=NativeLogpOp())], + cases=[_logp_case("fp32", torch.float32, seed=6)], + ) + + data = report.to_dict() + output = data["candidates"][0]["cases"][0]["outputs"][0] + assert data["suite_name"] == "logp" + assert "max_abs_error" in output + assert "atol" in output + assert "passed" in output + + +def test_candidate_arch_key_uses_tolerance_override(): + def slightly_shifted_logp(logits, token_ids): + return NativeLogpOp().forward_fp32(logits, token_ids) + 0.02 + + contract = { + "accuracy": { + "default": { + "logprob": { + "float32": {"atol": 1.0e-5, "rtol": 0.0}, + } + }, + "arch_overrides": { + "testarch": { + "logprob": { + "float32": {"atol": 5.0e-2, "rtol": 0.0}, + } + } + }, + } + } + report = run_operator_suite( + "logp", + candidates=[ + CandidateSpec( + name="shifted-logp", + backend="test", + fn=slightly_shifted_logp, + arch_key="testarch", + ) + ], + cases=[_logp_case("fp32", torch.float32, seed=7)], + contract=contract, + ) + + output = report.candidates[0].cases[0].outputs[0] + assert report.passed + assert output.atol == 5.0e-2 + + +def test_logp_native_candidate_backward_suite_passes(): + report = run_operator_suite( + "logp", + candidates=[CandidateSpec(name="native-logp", backend="pytorch", fn=NativeLogpOp())], + cases=[_logp_backward_case("fp32", seed=8)], + check_grad=True, + ) + + assert report.passed + assert report.candidates[0].passed_outputs == 2 + assert report.candidates[0].cases[0].outputs[1].message == "gradient:logits" + + +def test_backward_suite_reports_failure_for_bad_gradient(): + def bad_grad_logp(logits, token_ids): + values = NativeLogpOp().forward_fp32(logits, token_ids) + return values.detach() + logits.sum(dim=-1) * 0.0 + + report = run_operator_suite( + "logp", + candidates=[CandidateSpec(name="bad-grad-logp", backend="test", fn=bad_grad_logp)], + cases=[_logp_backward_case("fp32", seed=9)], + check_grad=True, + ) + + gradient_output = report.candidates[0].cases[0].outputs[1] + assert not report.passed + assert gradient_output.message == "gradient:logits" + assert gradient_output.max_abs_error > 0.0 + + +def test_random_grad_mode_catches_nonuniform_upstream_gradient_bug(): + # Forward is identity, so only a non-uniform upstream gradient can expose + # the intentionally wrong backward below. + class MeanUpstreamIdentity(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.clone() + + @staticmethod + def backward(ctx, grad_output): + # Wrong for random upstream gradients, but correct when all values are 1. + return grad_output.mean().expand_as(grad_output) + + def bad_identity(x): + return MeanUpstreamIdentity.apply(x) + + case = OperatorCase( + name="identity", + op_class="elementwise", + dtype=torch.float32, + inputs={"x": torch.randn(8, dtype=torch.float32)}, + gold_fn=lambda x: x, + grad_input_names=("x",), + ) + + ones_report = run_operator_suite( + "identity", + candidates=[CandidateSpec(name="bad-identity", backend="test", fn=bad_identity)], + cases=[case], + check_grad=True, + grad_mode="ones", + ) + # ones passes by design; random must fail and prove the stricter path works. + random_report = run_operator_suite( + "identity", + candidates=[CandidateSpec(name="bad-identity", backend="test", fn=bad_identity)], + cases=[case], + check_grad=True, + grad_mode="random", + grad_seed=7, + ) + + assert ones_report.passed + gradient_output = random_report.candidates[0].cases[0].outputs[1] + assert not random_report.passed + assert gradient_output.message == "gradient:x" + assert gradient_output.max_abs_error > 0.0 diff --git a/tests/test_operator_inputs.py b/tests/test_operator_inputs.py new file mode 100644 index 0000000..bb1a222 --- /dev/null +++ b/tests/test_operator_inputs.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import argparse + +import pytest +import torch + +from rl_engine.kernels.gtest.operator_inputs import make_operator_inputs, operator_shape_name + + +def _args(**overrides): + values = { + "batch": 1, + "seq": 2, + "vocab": 17, + "seed": 123, + "input_mode": "constant", + "constant_value": 0.5, + "token_value": 3, + "normalized_dim": 128, + "k_dim": 16, + "n_dim": 32, + "theta": 1.0e6, + "eps": 1.0e-6, + } + values.update(overrides) + return argparse.Namespace(**values) + + +@pytest.mark.parametrize( + "op_name", + [ + "rms_norm", + "matmul", + "attention", + "logp", + "linear_logp", + "rope", + "silu", + "swiglu", + "embedding", + "lm_head", + "kv_cache_attention", + ], +) +def test_operator_inputs_support_all_issue_108_ops(op_name): + args = _args() + inputs = make_operator_inputs(op_name, args, torch.float32, torch.device("cpu")) + + assert inputs + assert operator_shape_name(op_name, args) + + +def test_constant_logp_inputs_are_deterministic(): + args = _args(input_mode="constant", constant_value=0.5, token_value=3) + inputs = make_operator_inputs("logp", args, torch.float32, torch.device("cpu")) + + assert torch.equal(inputs["logits"], torch.full((1, 2, 17), 0.5)) + assert torch.equal(inputs["token_ids"], torch.full((1, 2), 3, dtype=torch.long)) + + +def test_random_logp_inputs_are_seeded(): + args = _args(input_mode="random", seed=7) + first = make_operator_inputs("logp", args, torch.float32, torch.device("cpu")) + second = make_operator_inputs("logp", args, torch.float32, torch.device("cpu")) + + assert torch.equal(first["logits"], second["logits"]) + assert torch.equal(first["token_ids"], second["token_ids"]) + + +def test_constant_linear_logp_inputs_match_operator_contract(): + args = _args(input_mode="constant", constant_value=0.5, token_value=3) + inputs = make_operator_inputs("linear_logp", args, torch.float32, torch.device("cpu")) + + assert torch.equal(inputs["hidden"], torch.full((1, 2, 128), 0.5)) + assert torch.equal(inputs["lm_head_weight"], torch.full((17, 128), 0.51)) + assert torch.equal(inputs["target_ids"], torch.full((1, 2), 3, dtype=torch.long)) + assert inputs["bias"] is None diff --git a/tests/test_tolerance_contract.py b/tests/test_tolerance_contract.py new file mode 100644 index 0000000..fb429d8 --- /dev/null +++ b/tests/test_tolerance_contract.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +from rl_engine.kernels.gtest.tolerance import load_contract + + +def test_load_contract_contains_expected_operator_classes(): + contract = load_contract() + accuracy = contract["accuracy"]["default"] + assert set(accuracy) == {"elementwise", "reduction", "logprob"} + + +def test_load_contract_contains_expected_dtypes(): + contract = load_contract() + for op_class in ("elementwise", "reduction", "logprob"): + assert set(contract["accuracy"]["default"][op_class]) == { + "float32", + "bfloat16", + "float16", + } + + +def test_logprob_bfloat16_tolerance_covers_observed_reference_drift(): + contract = load_contract() + tolerance = contract["accuracy"]["default"]["logprob"]["bfloat16"] + assert tolerance["atol"] >= 5.0e-2 + assert tolerance["rtol"] == 0.0