Skip to content

Add kernel gtest operator checker and used logprob to test gtest#197

Open
a-kaa wants to merge 11 commits into
RL-Align:mainfrom
a-kaa:logp-gtest
Open

Add kernel gtest operator checker and used logprob to test gtest#197
a-kaa wants to merge 11 commits into
RL-Align:mainfrom
a-kaa:logp-gtest

Conversation

@a-kaa

@a-kaa a-kaa commented Jun 27, 2026

Copy link
Copy Markdown
Collaborator

Summary

Usage

Run a CPU smoke check against the PyTorch gold implementation:

  python scripts/check_operator.py \
    --op logp \
    --candidate pytorch \
    --device cpu \
    --dtype fp32 \
    --batch 1 \
    --seq 2 \
    --vocab 17

Run a CUDA candidate check against the PyTorch gold path:

  python scripts/check_operator.py \
    --op logp \
    --candidate cuda \
    --device cuda \
    --dtype bf16 \
    --arch-key sm90 \
    --batch 1 \
    --seq 1 \
    --vocab 4096

Print the full structured report as JSON:

  python scripts/check_operator.py \
    --op logp \
    --candidate pytorch \
    --device cpu \
    --dtype fp32 \
    --batch 1 \
    --seq 2 \
    --vocab 17 \
    --json

Available key options:

  • --op: operator name. Current minimal version supports logp.
  • --candidate: backend candidate, for example pytorch, cuda, cuda-generic, cuda-sm90, or registry.
  • --dtype: fp32, bf16, or fp16.
  • --device: auto, cpu, cuda, or any torch device string.
  • --arch-key: optional tolerance override key, for example sm90.
  • --batch, --seq, --vocab: shape controls.
  • --input-mode: random or constant.
  • --constant-value: floating-point tensor value used by constant mode.
  • --token-value: token id used by constant mode, modulo vocab.
  • --seed: random input seed.
  • --json: emit full JSON report.

Example output:

suite=logp passed=True pass_rate=1.0000
candidate=cuda-logp backend=cuda passed=True pass_rate=1.0000
case=logp-torch.bfloat16-1x1x4096 output=0 shape=(1, 1) dtype=torch.bfloat16 max_abs=2.69813538e-02 mean_abs=2.69813538e-02 max_rel=3.03093810e-03
tol=(atol=5.000e-02, rtol=0.000e+00) passed=True

Adding a New Operator

To add a new operator to the checker, keep the public test flow unchanged and only update the operator-specific registration/input files.

1. Add input generation

(Already added, need check the shapes)

  rl_engine/kernels/gtest/operator_inputs.py 


  Update make_operator_inputs():

  builders = {
      ...
      "new_op": _make_new_op_inputs,
  }

  Update operator_shape_name():

  names = {
      ...
      "new_op": f"{batch}x{seq}x...",
  }

  Add the input builder:

  def _make_new_op_inputs(args, dtype, device):
      batch, seq = _batch_seq(args)
      return {
          "x": _floating_tensor((batch, seq, ...), args, dtype, device, offset=0),
          ...
      }

2. Register gold and candidates

(NEED OPS OWNER ADD)


  rl_engine/kernels/gtest/operator_specs.py

  Add an OperatorSpec entry:

  "new_op": OperatorSpec(
      name="new_op",
      op_class="elementwise",
      gold_path="rl_engine.kernels.ops.pytorch....NativeNewOp",
      registry_name="new_op",
      candidate_paths={
          "pytorch": "rl_engine.kernels.ops.pytorch....NativeNewOp",
          "cuda": "rl_engine.kernels.ops.cuda....CudaNewOp",
          "triton": "rl_engine.kernels.ops.triton....TritonNewOp",
      },
  )

Rules:

  • gold_path must come from rl_engine.kernels.ops.pytorch.
  • CUDA/Triton/ROCm implementations are candidates only.
  • Do not compare two operators that implement different math.
  • candidate=pytorch is only for checker smoke tests.

Validation

  • python scripts/check_operator.py --op logp --candidate pytorch --device cpu --dtype fp32 --batch 1 --seq 2 --vocab 17
71837019648a53174ec8e566ce210027

In fact, the test did not pass; however, it proves that the workflow of our testing framework has achieved the minimum viable capability.

Notes

  • Gold paths are required to come from rl_engine.kernels.ops.pytorch.
  • CUDA/Triton/ROCm implementations are treated as candidates.
  • SM90 fused logp remains under separate validation and is not included as a passing path in this PR.

Summary by CodeRabbit

  • New Features
    • Added a command-line tool to compare operator results against a reference implementation, with support for dtype, device, input modes, seeds, and optional gradient checks.
    • Expanded operator coverage for logp and linear_logp, including standardized input generation and tolerance-based comparisons.
  • Bug Fixes
    • Improved accuracy reporting for edge cases and added support for batch-invariance and backward-pass validation.
  • Documentation
    • Updated operator guidance and contribution notes with usage examples, test commands, and validation steps.

@coderabbitai

coderabbitai Bot commented Jun 27, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 39881f14-23eb-438c-91c5-8305b256132a

📥 Commits

Reviewing files that changed from the base of the PR and between 4bd2ed8 and 347726b.

📒 Files selected for processing (6)
  • docs/contributing/issue-108-session-log.md
  • rl_engine/kernels/gtest/operator_specs.py
  • rl_engine/kernels/gtest/tolerance.py
  • rl_engine/kernels/gtest/tolerance_contract.json
  • scripts/check_operator.py
  • tests/test_op_checks.py
💤 Files with no reviewable changes (2)
  • rl_engine/kernels/gtest/tolerance_contract.json
  • tests/test_op_checks.py
✅ Files skipped from review due to trivial changes (1)
  • docs/contributing/issue-108-session-log.md
🚧 Files skipped from review as they are similar to previous changes (2)
  • rl_engine/kernels/gtest/tolerance.py
  • scripts/check_operator.py

📝 Walkthrough

Walkthrough

Adds a new gtest-style operator correctness framework under rl_engine/kernels/gtest/ (tolerance contract, operator input builders, operator specs, gradient-aware suite runner), a scripts/check_operator.py CLI, refactors NativeLogpOp into forward/forward_fp32 with backward-compatible aliases, plus documentation and tests.

Changes

Operator Checking Framework

Layer / File(s) Summary
Tolerance contract and loader
rl_engine/kernels/gtest/tolerance.py, rl_engine/kernels/gtest/tolerance_contract.json, tests/test_tolerance_contract.py
Adds load_contract(), a JSON tolerance contract with per-op_class/dtype accuracy and batch-invariance tolerances plus arch_overrides, and tests validating its structure and values.
Suite runner and gradient checks
rl_engine/kernels/gtest/op_checks.py, rl_engine/kernels/gtest/__init__.py
Adds grad_input_names to OperatorCase, extends run_operator_suite with check_grad/grad_mode/grad_seed, adds _run_case_backward and gradient utilities, refactors output comparison, and re-exports public symbols.
Operator inputs, specs, and NativeLogpOp forward refactor
rl_engine/kernels/gtest/operator_inputs.py, rl_engine/kernels/gtest/operator_specs.py, rl_engine/kernels/ops/pytorch/loss/logp.py
Adds deterministic input builders/shape naming per operator, an OperatorSpec registry (logp, linear_logp) with gold/candidate resolution and an SM90 adapter, and refactors NativeLogpOp to add forward/forward_fp32 with apply/apply_fp32 as aliases.
CLI runner
scripts/check_operator.py
Adds a CLI that builds a candidate/case, runs the suite with configurable dtype/device/gradient options, and prints JSON or text reports, exiting non-zero on failure.
Operator and harness tests
tests/test_logp.py, tests/test_operator_inputs.py, tests/test_op_checks.py, tests/test_tolerance_contract.py
Adds pytest coverage for NativeLogpOp correctness/accuracy/registry, operator input determinism, forward/gradient suite pass/fail behavior, and tolerance contract structure.
Session log and operator docs
docs/contributing/issue-108-session-log.md, docs/operators/fused-logp.md
Documents the framework's design decisions, CLI usage, adding-new-operator flow, CUDA validation notes, and revises the fused-logp guide's backend naming, tensor contract, and test/implementation file lists.

Estimated code review effort: 4 (Complex) | ~60 minutes

Sequence Diagram(s)

sequenceDiagram
  participant CLI as check_operator.py
  participant Runner as run_operator_suite
  participant Candidate as _run_candidate
  participant Backward as _run_case_backward
  participant Compare as _compare_output

  CLI->>Runner: candidate, cases, contract, check_grad
  Runner->>Candidate: execute cases
  alt check_grad enabled
    Candidate->>Backward: candidate, gold, grad inputs
    Backward->>Compare: compare outputs and gradients
  else value-only mode
    Candidate->>Compare: compare candidate vs gold outputs
  end
  Compare-->>Runner: case checks
  Runner-->>CLI: OperatorCheckReport
Loading

Possibly related PRs

  • RL-Align/RL-Kernel#97: NativeLogpOp's refactor to forward/forward_fp32 with apply/apply_fp32 aliases directly affects consumers like NativeRatioKLOp that rely on apply_fp32.
  • RL-Align/RL-Kernel#122: This PR's linear_logp input generation and operator-spec registration complement that PR's introduction of linear_logp backends and registry integration.

Suggested labels: needs-gpu-ci

Suggested reviewers: inaniloquentee, KJLdefeated

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 7.32% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title is related to the main change: it adds the kernel gtest operator checker and the logprob test path.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🧹 Nitpick comments (1)
rl_engine/kernels/gtest/op_checks.py (1)

184-187: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick win

Call the candidate's public callable path first.

For torch.nn.Module-like candidates, jumping straight to .forward() bypasses __call__ hooks and wrappers, so the checker may validate a different path than production. Prefer candidate(**inputs) when the object is callable, and only fall back to .forward() for non-callable adapters.

Proposed change
 def _call_candidate(candidate: Callable[..., Any] | Any, inputs: Mapping[str, Any]) -> Any:
-    if hasattr(candidate, "forward") and callable(candidate.forward):
-        return candidate.forward(**inputs)
-    return candidate(**inputs)
+    if callable(candidate):
+        return candidate(**inputs)
+    if hasattr(candidate, "forward") and callable(candidate.forward):
+        return candidate.forward(**inputs)
+    raise TypeError(f"candidate is not callable: {type(candidate)!r}")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/gtest/op_checks.py` around lines 184 - 187, The
_call_candidate helper is invoking .forward() first for Module-like objects,
which skips the public callable path. Update _call_candidate so it prefers
candidate(**inputs) whenever candidate is callable, and only falls back to
candidate.forward(**inputs) for non-callable adapters; keep the existing
callable checks around candidate and forward to preserve compatibility.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/kernels/gtest/op_checks.py`:
- Around line 149-150: The candidate and gold evaluations in the op check path
are sharing the same input tree, so in-place writes from the candidate can
affect the reference result. Update the logic around _call_candidate and
case.gold_fn in op_checks.py to build separate cloned copies of case.inputs for
each call, using distinct cloned input trees for the candidate and gold paths.
- Around line 151-155: The arity mismatch in the operator check path currently
raises a ValueError and aborts run_operator_suite() before it can build an
OperatorCheckReport. Update the mismatch handling in op_checks.py around the
candidate/gold output comparison so it records a failed CaseCheck/OutputCheck
for the relevant candidate rather than throwing, allowing
scripts/check_operator.py to continue and emit its structured report. Use the
existing run_operator_suite(), CaseCheck, and OutputCheck flow to surface the
mismatch as a normal test failure.

In `@rl_engine/kernels/gtest/operator_inputs.py`:
- Around line 199-210: Reject non-positive vocab values at the start of
_token_ids before any mode-specific logic runs. Add an upfront validation in
_token_ids that raises a clear error when vocab <= 0, so constant mode does not
hit token_value % vocab and random mode does not fall through to torch.randint
with an invalid range. Keep the fix localized to _token_ids and use its existing
parameters to enforce the check.

In `@rl_engine/kernels/gtest/operator_specs.py`:
- Around line 78-106: make_candidate() currently lets CUDA-backed candidates
through without validating the resolved device, so add an early guard in this
function to reject cuda/cuda-sm90 when the selected device is not CUDA-capable.
Use the existing args.device flow from
make_operator_case()/scripts/check_operator.py and check the resolved device
before loading the candidate, raising a clear ValueError for invalid
backend/device combinations. Keep the existing candidate selection logic and
unique symbols like make_candidate, CandidateSpec, and _LogpSM90CandidateAdapter
to locate the fix.

In `@rl_engine/kernels/gtest/tolerance.py`:
- Around line 11-18: The tolerance contract loader in load_contract is reading
tolerance_contract.yaml with json.load, which only supports strict JSON. Update
load_contract to parse the contract as YAML (using the existing _CONTRACT_PATH
target) or, if you intend to keep JSON parsing, rename the contract and path to
a .json file so the loader and file format match.

---

Nitpick comments:
In `@rl_engine/kernels/gtest/op_checks.py`:
- Around line 184-187: The _call_candidate helper is invoking .forward() first
for Module-like objects, which skips the public callable path. Update
_call_candidate so it prefers candidate(**inputs) whenever candidate is
callable, and only falls back to candidate.forward(**inputs) for non-callable
adapters; keep the existing callable checks around candidate and forward to
preserve compatibility.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e1add4ce-5e2e-4b44-943e-1d6e4abc2a8c

📥 Commits

Reviewing files that changed from the base of the PR and between ea196da and 7defbdd.

📒 Files selected for processing (14)
  • docs/contributing/issue-108-session-log.md
  • docs/operators/fused-logp.md
  • rl_engine/kernels/gtest/__init__.py
  • rl_engine/kernels/gtest/op_checks.py
  • rl_engine/kernels/gtest/operator_inputs.py
  • rl_engine/kernels/gtest/operator_specs.py
  • rl_engine/kernels/gtest/tolerance.py
  • rl_engine/kernels/gtest/tolerance_contract.yaml
  • rl_engine/kernels/ops/pytorch/loss/logp.py
  • scripts/check_operator.py
  • tests/test_logp.py
  • tests/test_op_checks.py
  • tests/test_operator_inputs.py
  • tests/test_tolerance_contract.py

Comment thread rl_engine/kernels/gtest/op_checks.py
Comment thread rl_engine/kernels/gtest/op_checks.py
Comment thread rl_engine/kernels/gtest/operator_inputs.py
Comment thread rl_engine/kernels/gtest/operator_specs.py
Comment thread rl_engine/kernels/gtest/tolerance.py Outdated

@Flink-ddd Flink-ddd left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an excellently designed testing harness. Building a unified gtest runner with a YAML-backed tolerance contract is exactly what Workstream 1 needs to scale across CUDA, Triton, and ROCm. The JSON output integration will make CI regression tracking much easier.

Since this is an initial MVP version, the foundation looks very solid. Below are some reviews that outline some architectural improvements:

@@ -0,0 +1,1044 @@
# ISSUE-108 Session Log

本文档记录本 session 中围绕 RL-Kernel 算子测试框架、CUDA 验证和 upstream 同步的所有关键修改。后续本 session 中每次代码修改都必须继续追加到本文档,记录目标、设计判断、修改文件、验证方式和结果。

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use english to replace

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please translate the session log from Chinese to English to maintain the repository's open-source linguistic consistency.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will fix it today

)


def _run_case(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, _run_case only evaluates the forward pass (_call_candidate). As we have established in the RMSNorm, SwiGLU, and Embedding PRs, backward-pass consistency is critical to preventing RL training drift.

While it doesn't need to be implemented in this exact PR, you must add a TODO or open a tracking issue to support gradient checking in the gtest framework. Future iterations will need to set requires_grad=True on floating inputs, call .backward() on the outputs, and compare the resulting .grad tensors against the gold path using this same tolerance contract.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank very much for pointing that. I think it is critical for WS1, i will complete the backward part of gtest.
In fact, I am doing it, iam using the linear log-prob triton op . It implemented forward and backward. I’ll polish my code tomorrow and submit an updated version.
image

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @Flink-ddd I have commit a new version for gtest. Add partial new code to support backward in gTest. E2E tests passed on both H20 and H100 GPUs. Note: tolerance tables may require further tuning. The figure shows an example of a backward test.
image


if candidate_name in spec.candidate_paths:
candidate_op = _load_object(spec.candidate_paths[candidate_name])()
if args.op == "logp" and candidate_name == "cuda-sm90":

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Injecting _LogpSM90CandidateAdapter using an if statement (if args.op == "logp" and candidate_name == "cuda-sm90":) inside make_candidate works for the MVP, but it will quickly become spaghetti code as you add more operators, ROCm backends, and Triton kernels that require shape-flattening or specific adapters.

Suggestion: Make the adapter mapping declarative. Add an optional candidate_adapters: dict[str, Callable] field to the OperatorSpec dataclass, so make_candidate can simply look up and wrap the candidate dynamically without needing to know operator-specific logic.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinks for your comments. Our intended direction is that candiate wrappers should expose the same interface as the Pytorch gold path. The current SM90 adaoter is a temproary wrapper for SM90 CUDA interface, and I think we don't need any if judgement, we will delete the if in the future.
In current SM90 CUDA interface, the input shape is [batch*seq, vocba]. However, in the gtest, we don't flatten the batch, the shape is [batch, seq, vocba]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, make sense.

Comment thread tests/test_logp.py
logits, token_ids = _make_inputs(2, 16, 257, dtype=dtype)
out = op.forward(logits, token_ids)
assert out.dtype == dtype

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed you included the unit tests for NativeLogpOp in this PR to prove the gtest framework works. Similar to the previous reference operators, test_logp.py is completely missing backward-pass tests (test_gradient_flows). Please ensure you add a slice-based Batch-Invariance (Axis-A) test for NativeLogpOp's backward pass before WS1 concludes.

}
},
"arch_overrides": {
"sm90": {}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Starting with sm90 as the only arch_overrides key is perfectly fine for this stage. The structure here is very clean, and it will be trivial to expand this to include rocm or cdna3 specific tolerances when you reach WS2 and WS3. No changes needed here.

@Flink-ddd

Copy link
Copy Markdown
Collaborator

Comment on lines +184 to +187
def _call_candidate(candidate: Callable[..., Any] | Any, inputs: Mapping[str, Any]) -> Any:
if hasattr(candidate, "forward") and callable(candidate.forward):
return candidate.forward(**inputs)
return candidate(**inputs)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The harness builds inputs as {logits, token_ids} (matching NativeLogpOp) and dispatches with candidate(**inputs) (op_checks.py:184-187). But FusedLogpSM90Op.call(self, logits, labels) uses labels, not token_ids → TypeError. Only the explicit cuda-sm90 candidate gets the positional _LogpSM90CandidateAdapter; the registry candidate does not, so when the registry resolves to the SM90 op the kwargs don't match.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing that out, It made me realize that the registry may no longer be necessary. I’ll try removing it, so users only need to choose the backend they want to test.

Comment on lines +17 to +18
with Path(path).open("r", encoding="utf-8") as handle:
return json.load(handle)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

json.load on yaml files, this should be resolve.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. The contract file is JSON-formatted but currently named .yaml, while the loader uses json.load. To avoid adding a YAML parser dependency and keep the file format explicit, I will rename it to tolerance_contract.json and update the loader/docs references accordingly.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/contributing/issue-108-session-log.md`:
- Around line 83-88: The session log wording around the runner’s comparison
behavior is inconsistent with the later `check_grad=True` support, so update the
affected prose to clearly describe the final behavior or explicitly mark the
forward-only text as historical context. Use the relevant session-log section
that mentions `op_checks.py`, gradient checks, and the runner’s output
comparison to keep the narrative internally consistent across the document.

In `@rl_engine/kernels/gtest/op_checks.py`:
- Around line 201-220: The backward comparison in op_checks.py currently runs
before the forward outputs are validated, so a shape/arity mismatch in
candidate_outputs can cause _backward_grads() to fail before
_compare_case_outputs() reports the real issue. Update the flow around
_make_grad_outputs(), _match_grad_outputs(), and _compare_case_outputs() so
forward output compatibility is checked first, and only run the candidate/gold
backward passes after confirming the outputs match or by deriving gold upstream
gradients from gold_outputs.
- Around line 334-343: The backward path in _run_case_backward currently lets
failures from loss.backward() and the inputs[name].grad is None check escape as
exceptions instead of being reported as structured gradient failures. Update
_run_case_backward to catch backward-time errors and missing-grad cases, then
return a failure result tagged as gradient:<input> for the affected
grad_input_names, using the existing output/grad_input handling in op_checks.py
to keep non-differentiable candidates from aborting the suite.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2ce3d0a8-df5d-4bb6-847e-fce0f01558d0

📥 Commits

Reviewing files that changed from the base of the PR and between 7defbdd and 4bd2ed8.

📒 Files selected for processing (8)
  • docs/contributing/issue-108-session-log.md
  • rl_engine/kernels/gtest/op_checks.py
  • rl_engine/kernels/gtest/operator_inputs.py
  • rl_engine/kernels/gtest/operator_specs.py
  • scripts/check_operator.py
  • tests/test_logp.py
  • tests/test_op_checks.py
  • tests/test_operator_inputs.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • tests/test_operator_inputs.py
  • scripts/check_operator.py
  • rl_engine/kernels/gtest/operator_inputs.py
  • tests/test_logp.py

Comment on lines +83 to +88
- The runner compares forward outputs only in this minimal version.

Review follow-up:

- `op_checks.py` includes a TODO for optional gradient checks on differentiable operators.
- Gradient checks require additional metadata and input cloning rules, so they are intentionally tracked as follow-up work instead of being silently implied by this PR.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Clarify whether gradient checks are historical or current behavior.

This section still reads as if the runner is forward-only and gradient checks are future work, but the later PR update says check_grad=True support has already landed. Please rewrite this as historical context or update it to the final behavior so the log stays internally consistent.

♻️ Suggested wording
- The runner compares forward outputs only in this minimal version.
- Gradient checks require additional metadata and input cloning rules, so they are intentionally tracked as follow-up work instead of being silently implied by this PR.
+ The runner initially compared forward outputs only; backward checks were added later in this PR.

Also applies to: 414-429

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/contributing/issue-108-session-log.md` around lines 83 - 88, The session
log wording around the runner’s comparison behavior is inconsistent with the
later `check_grad=True` support, so update the affected prose to clearly
describe the final behavior or explicitly mark the forward-only text as
historical context. Use the relevant session-log section that mentions
`op_checks.py`, gradient checks, and the runner’s output comparison to keep the
narrative internally consistent across the document.

Comment on lines +201 to +220
grad_outputs = _make_grad_outputs(candidate_outputs, grad_mode=grad_mode, seed=grad_seed)
candidate_grads = _backward_grads(
candidate_outputs,
candidate_inputs,
case.grad_input_names,
grad_outputs=grad_outputs,
)
gold_grads = _backward_grads(
gold_outputs,
gold_inputs,
case.grad_input_names,
grad_outputs=_match_grad_outputs(grad_outputs, gold_outputs),
)
output_checks = _compare_case_outputs(
candidate,
case,
contract,
candidate_outputs,
gold_outputs,
).outputs

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
import torch
gold = torch.randn(2, 3, requires_grad=True)
bad_upstream = torch.randn(2, 4)
try:
    (gold.float() * bad_upstream.float()).sum().backward()
except RuntimeError as exc:
    print(type(exc).__name__, str(exc).splitlines()[0])
PY

Repository: RL-Align/RL-Kernel

Length of output: 259


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Map the target file and locate the relevant functions.
ast-grep outline rl_engine/kernels/gtest/op_checks.py --view expanded

# Show the relevant section with line numbers.
sed -n '150,260p' rl_engine/kernels/gtest/op_checks.py

# Locate helper definitions used here.
rg -n "_compare_case_outputs|_match_grad_outputs|_make_grad_outputs|_backward_grads" rl_engine/kernels/gtest/op_checks.py

Repository: RL-Align/RL-Kernel

Length of output: 6357


🏁 Script executed:

#!/bin/bash
set -euo pipefail

sed -n '253,390p' rl_engine/kernels/gtest/op_checks.py

Repository: RL-Align/RL-Kernel

Length of output: 4824


Compare forward outputs before backward. _make_grad_outputs() is built from candidate_outputs, and _match_grad_outputs() only moves devices, so a same-arity/wrong-shape candidate can make the gold backward pass raise before _compare_case_outputs() reports the forward mismatch. Short-circuit on output shape mismatches first, or build the gold upstream gradients from gold_outputs.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/gtest/op_checks.py` around lines 201 - 220, The backward
comparison in op_checks.py currently runs before the forward outputs are
validated, so a shape/arity mismatch in candidate_outputs can cause
_backward_grads() to fail before _compare_case_outputs() reports the real issue.
Update the flow around _make_grad_outputs(), _match_grad_outputs(), and
_compare_case_outputs() so forward output compatibility is checked first, and
only run the candidate/gold backward passes after confirming the outputs match
or by deriving gold upstream gradients from gold_outputs.

Comment on lines +334 to +343
loss = sum(
(output.float() * grad_output.to(device=output.device).float()).sum()
for output, grad_output in zip(outputs, grad_outputs, strict=True)
)
loss.backward()
grads: list[torch.Tensor] = []
for name in grad_input_names:
grad = inputs[name].grad
if grad is None:
raise ValueError(f"gradient for input {name!r} is None")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
import torch
x = torch.randn(4, requires_grad=True)
detached_output = x.detach()
try:
    detached_output.sum().backward()
except RuntimeError as exc:
    print(type(exc).__name__, str(exc).splitlines()[0])
PY

Repository: RL-Align/RL-Kernel

Length of output: 240


🏁 Script executed:

#!/bin/bash
set -euo pipefail
python3 - <<'PY'
from pathlib import Path
p = Path('rl_engine/kernels/gtest/op_checks.py')
lines = p.read_text().splitlines()
for start in (280, 320, 340, 360):
    end = min(len(lines), start+30)
    print(f"\n--- {p}:{start+1}-{end} ---")
    for i in range(start, end):
        print(f"{i+1:4d}: {lines[i]}")
PY

Repository: RL-Align/RL-Kernel

Length of output: 5344


🏁 Script executed:

#!/bin/bash
set -euo pipefail
python3 - <<'PY'
from pathlib import Path
p = Path('rl_engine/kernels/gtest/op_checks.py')
lines = p.read_text().splitlines()
for start in (280, 320, 340, 360):
    end = min(len(lines), start+30)
    print(f"\n--- {p}:{start+1}-{end} ---")
    for i in range(start, end):
        print(f"{i+1:4d}: {lines[i]}")
PY
PY

Repository: RL-Align/RL-Kernel

Length of output: 5385


🏁 Script executed:

#!/bin/bash
set -euo pipefail
python3 - <<'PY'
from pathlib import Path
p = Path('rl_engine/kernels/gtest/op_checks.py')
text = p.read_text().splitlines()
for i, line in enumerate(text, 1):
    if '_run_case_backward' in line or 'structured' in line or 'gradient:' in line or 'ValueError' in line:
        print(f'{i:4d}: {line}')
PY

Repository: RL-Align/RL-Kernel

Length of output: 797


🏁 Script executed:

#!/bin/bash
set -euo pipefail
python3 - <<'PY'
from pathlib import Path
p = Path('rl_engine/kernels/gtest/op_checks.py')
text = p.read_text().splitlines()
for i, line in enumerate(text, 1):
    if '_run_case_backward' in line or 'gradient:' in line or 'ValueError' in line:
        print(f'{i:4d}: {line}')
PY

Repository: RL-Align/RL-Kernel

Length of output: 713


🏁 Script executed:

#!/bin/bash
set -euo pipefail
python3 - <<'PY'
from pathlib import Path
p = Path('rl_engine/kernels/gtest/op_checks.py')
lines = p.read_text().splitlines()
for start in (182, 210, 225, 240):
    end = min(len(lines), start + 35)
    print(f"\n--- {p}:{start}-{end} ---")
    for i in range(start - 1, end):
        print(f"{i+1:4d}: {lines[i]}")
PY

Repository: RL-Align/RL-Kernel

Length of output: 5949


🏁 Script executed:

#!/bin/bash
set -euo pipefail
python3 - <<'PY'
from pathlib import Path
p = Path('rl_engine/kernels/gtest/op_checks.py')
lines = p.read_text().splitlines()
for start in (90, 120, 140):
    end = min(len(lines), start + 35)
    print(f"\n--- {p}:{start}-{end} ---")
    for i in range(start - 1, end):
        print(f"{i+1:4d}: {lines[i]}")
PY

Repository: RL-Align/RL-Kernel

Length of output: 4246


Treat backward failures as structured gradient failures
In rl_engine/kernels/gtest/op_checks.py:334-343, loss.backward() and the grad is None path can escape _run_case_backward, so a non-differentiable candidate aborts the suite instead of reporting gradient:<input> for the failing case.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/gtest/op_checks.py` around lines 334 - 343, The backward
path in _run_case_backward currently lets failures from loss.backward() and the
inputs[name].grad is None check escape as exceptions instead of being reported
as structured gradient failures. Update _run_case_backward to catch
backward-time errors and missing-grad cases, then return a failure result tagged
as gradient:<input> for the affected grad_input_names, using the existing
output/grad_input handling in op_checks.py to keep non-differentiable candidates
from aborting the suite.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants