feat(ws1): NativeRMSNormOp pure-PyTorch ground-truth reference + numerical contract tests#160

Merged
Flink-ddd merged 9 commits into RL-Align:main from maxiaosong1124:feat/ws1-rms_norm-pytorch-op
Jun 30, 2026
@maxiaosong1124 maxiaosong1124 commented Jun 20, 2026

Collaborator

Summary

Adds the pure-PyTorch ground-truth reference op for RMSNorm (pre-norm / QK-Norm)
as the first WS1 batch-invariant operator built on top of the numerical contract
defined in #108. Ships the op, its registry wiring, and a 16-case test suite that
pins down both alignment axes (Axis-A bitwise batch invariance, Axis-B per-dtype
tolerance).

Refs #108

Terminology

This PR uses the WS1 alignment vocabulary from #108:

  • Axis-A — batch invariance (reproducibility). A row's output must not depend on
    how many rows share the batch (batch size, slicing, padding). Asserted bitwise
    (torch.equal). This is what keeps train-time (large batch) and sample-time
    (small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
  • Axis-B — accuracy. The low-precision (bf16 / fp16) forward must stay within a
    documented per-dtype tolerance of the fp32 ground-truth. Asserted with torch.allclose
    plus per-dtype thresholds.
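
To make Axis-A concrete, here is a minimal sketch of the slice and padding checks. The helper `rms_norm_ref` is a stand-in written for this illustration (it is not the PR's `NativeRMSNormOp`), and the sizes, seed, and `eps=1e-6` default are assumptions.

```python
import torch


def rms_norm_ref(x, weight, eps=1e-6):
    # Stand-in reference (assumed eps): fp32 accumulation, cast back to x.dtype.
    x_f = x.float()
    inv_rms = torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_f * inv_rms * weight.float()).to(x.dtype)


gen = torch.Generator().manual_seed(0)  # fixed seed, arbitrary value
x = torch.randn(8, 128, generator=gen)
w = torch.randn(128, generator=gen)

full = rms_norm_ref(x, w)

# Slice variant: recompute row 3 alone as a batch of one.
slice_ok = torch.equal(rms_norm_ref(x[3:4], w), full[3:4])

# Padding variant: append zero rows; the original rows must not change.
padded = torch.cat([x, torch.zeros(4, 128)], dim=0)
pad_ok = torch.equal(rms_norm_ref(padded, w)[:8], full)

print(slice_ok, pad_ok)
```

Because the reduction runs only along the last dimension, each row's result should depend on that row alone, which is why `torch.equal` (not a tolerance) is the natural assertion here.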

Motivation / Context

#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. RMSNorm is required on two normalized dims of the
target model (Qwen3-8B dense):

  • hidden = 4096 — input / post-attention norm
  • head_dim = 128 — QK-Norm (per-head RMSNorm on Q and K)

This PR provides the deterministic fp32 reference path those downstream kernels
(Triton / CUDA / ROCm RMSNorm) will be validated against.

Changes

  • rl_engine/kernels/ops/pytorch/norm/rms_norm.py: NativeRMSNormOp
    • forward() — accumulate in fp32, cast result back to x.dtype (Axis-B candidate path)
    • forward_fp32() — fp32 accumulation, forced fp32 output (ground-truth / backward golden source)
    • Formula: out = x * rsqrt(mean(x^2, dim=-1) + eps) * weight
    • eps lives inside the sqrt; plain weight scaling (not the 1 + weight variant)
    • Shape guard: weight must be 1-D of size x.shape[-1]
  • rl_engine/kernels/registry.py — register PYTORCH_NATIVE_RMS_NORM
    and add rms_norm dispatch to the cuda / rocm / cpu priority maps
  • tests/test_rms_norm.py — 16 tests (details below)
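
A minimal sketch of an op with the behaviour listed above may help when reading the bullets. It is reconstructed from the description only and is not the merged source: the class name `SketchRMSNormOp`, the `eps=1e-6` default, and the error message are assumptions.

```python
import torch


class SketchRMSNormOp:
    """Illustrative stand-in for the op described above (not the PR's code)."""

    def __init__(self, eps: float = 1e-6):  # eps default is an assumption
        self.eps = eps

    def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # Candidate path: fp32 accumulation, output cast back to x.dtype.
        return self._rms_norm(x, weight, self.eps, output_dtype=x.dtype)

    def forward_fp32(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # Ground-truth path: fp32 accumulation, output forced to fp32.
        return self._rms_norm(x, weight, self.eps, output_dtype=torch.float32)

    @staticmethod
    def _rms_norm(x, weight, eps, output_dtype):
        # Shape guard: weight must be 1-D with size x.shape[-1].
        if weight.ndim != 1 or weight.shape[0] != x.shape[-1]:
            raise ValueError(
                f"weight must be 1-D of size {x.shape[-1]}, got {tuple(weight.shape)}"
            )
        x_f = x.float()
        # eps sits inside the square root: rsqrt(mean(x^2) + eps).
        inv_rms = torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
        # Plain weight scaling, not the (1 + weight) variant.
        return (x_f * inv_rms * weight.float()).to(output_dtype)


op = SketchRMSNormOp()
x = torch.randn(2, 128).to(torch.bfloat16)
w = torch.ones(128, dtype=torch.bfloat16)
print(op.forward(x, w).dtype, op.forward_fp32(x, w).dtype)
```

Both entry points share one static helper, so the two paths can differ only in the final cast.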

How this satisfies the #108 contract

| #108 requirement | How it's met here |
| --- | --- |
| Deterministic reference path, fixed reduction order | forward_fp32() accumulates in fp32 along dim=-1; tests use fixed-seed torch.Generator so outputs are reproducible |
| Per-dtype tolerance policy (bitwise vs tight-tolerance) | Axis-A asserted bitwise (torch.equal); Axis-B asserted within documented per-dtype thresholds: bf16 atol=2e-2, rtol=1.6e-2; fp16 atol=1e-3, rtol=1e-3 |
| Batch-config sweep / validation helper | Batch-invariance checks compute on the full batch, then assert sliced/padded rows are bitwise identical to their full-batch counterparts |
| Both normalized dims covered | Every correctness/invariance test is parametrized over hidden=4096 and head_dim=128 |
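
As an illustration of the Axis-B row, the sketch below applies the quoted thresholds to a stand-in reference. The helper `rms_norm`, the `eps=1e-6` default, and the input sizes are assumptions for this example; only the tolerance values come from the PR text.

```python
import torch

# Thresholds quoted in the PR description.
TOLERANCES = {
    torch.bfloat16: dict(atol=2e-2, rtol=1.6e-2),
    torch.float16: dict(atol=1e-3, rtol=1e-3),
}


def rms_norm(x, weight, out_dtype, eps=1e-6):
    # Stand-in reference with fp32 accumulation; the eps value is an assumption.
    x_f = x.float()
    inv_rms = torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_f * inv_rms * weight.float()).to(out_dtype)


gen = torch.Generator().manual_seed(0)
results = {}
for dtype, tol in TOLERANCES.items():
    x = torch.randn(4, 128, generator=gen).to(dtype)
    w = torch.randn(128, generator=gen).to(dtype)
    low = rms_norm(x, w, out_dtype=dtype)          # candidate path
    ref = rms_norm(x, w, out_dtype=torch.float32)  # fp32 ground truth
    results[dtype] = torch.allclose(low.float(), ref, **tol)

print(results)
```

In this sketch the two paths differ only by the final rounding to the low-precision dtype, so the check mostly measures that cast; a real low-precision kernel would also be exercised on its internal arithmetic.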

Test Environment

OS Ubuntu 22.04.5 LTS (kernel 5.15.0-122-generic)
Python 3.12.3
PyTorch 2.8.0+cu128
CUDA / cuDNN 12.8 / 9.10.02 (driver 580.65.06)
pytest 9.0.3
GPU NVIDIA H20

Testing

Run from the repo root with python -m pytest (the -m form puts the repo root on sys.path):

python -m pytest tests/test_rms_norm.py

→ 16 passed, covering:

- Correctness vs an independent hand-written fp32 formula (bitwise, both dims)
- Axis-A batch invariance: row output is independent of batch size — slice and
padding variants, asserted bitwise
- dtype paths: forward follows input dtype; forward_fp32 forces fp32
- Axis-B low-precision (bf16 / fp16) within tolerance of the fp32 reference
- eps inside sqrt (zero input → finite zero output)
- plain weight scaling (rules out the 1 + weight variant)
- shape guard fires on wrong-size / non-1-D weight
- purity (inputs not mutated in place)
- gradient flow (fp32 autograd = backward golden source)
- registry dispatch resolves rms_norm → NativeRMSNormOp
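
Three of the checks in this list are easy to show in a few lines. The sketch below uses a stand-in `rms_norm_fp32` helper with an assumed `eps=1e-6`; it mirrors the idea of the tests, not their actual code.

```python
import torch


def rms_norm_fp32(x, weight, eps=1e-6):
    # Stand-in fp32 reference; the eps value is an assumption.
    x_f = x.float()
    inv_rms = torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x_f * inv_rms * weight.float()


dim = 128

# eps inside the sqrt: an all-zero row gives 0 * rsqrt(eps), finite and zero.
zero_out = rms_norm_fp32(torch.zeros(2, dim), torch.ones(dim))
eps_ok = bool(torch.isfinite(zero_out).all()) and bool((zero_out == 0).all())

# Plain weight scaling: a zero weight must zero the output.
# Under the (1 + weight) variant the output would be the normalized input.
x = torch.randn(2, dim, generator=torch.Generator().manual_seed(0))
plain_ok = bool((rms_norm_fp32(x, torch.zeros(dim)) == 0).all())

# Purity: the input tensor is not modified in place.
x_before = x.clone()
rms_norm_fp32(x, torch.ones(dim))
pure_ok = torch.equal(x, x_before)

print(eps_ok, plain_ok, pure_ok)
```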

Rebased onto latest upstream/main; registry dispatch for the neighboring
ratio_kl / grpo_loss ops verified intact after conflict resolution.

Checklist

- [x] Pure-PyTorch reference, no custom extension required
- [x] Both Qwen3-8B normalized dims (4096, 128) covered
- [x] Axis-A bitwise batch invariance enforced
- [x] Axis-B per-dtype tolerance documented and tested
- [x] Registered in OpBackend + cuda/rocm/cpu priority maps
- [x] All 16 tests pass locally

---

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **New Features**
  * Added a native pure-PyTorch RMSNorm operator backend and enabled `"rms_norm"` dispatch on CUDA, ROCm, and CPU.
  * `forward` now preserves input dtype (fp16/bf16/fp32), while `forward_fp32` forces fp32 outputs.
  * Enforces weight shape validation and correct `eps` behavior; output respects the requested dtype and scales linearly with weight.

* **Tests**
  * Added extensive pytest coverage for correctness vs fp32 references, dtype routing, error handling, immutability, gradient sanity, batch-slice invariance (forward and backward), and registry dispatch.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

- NativeRMSNormOp with forward / forward_fp32 (fp32 ground-truth path)
- covers both normalized dims: hidden=4096 and head_dim=128 (Qwen3 QK-Norm)
- register PYTORCH_NATIVE_RMS_NORM in OpBackend + cpu/cuda/rocm priority map
- tests/test_rms_norm.py: axis-A bitwise batch invariance + dtype tolerance,
  shape guard, purity, gradient flow, registry dispatch (16 tests)

Refs RL-Align#108
coderabbitai Bot commented Jun 20, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c03119f0-b31b-42aa-8928-3bfbc7c546dc

📥 Commits

Reviewing files that changed from the base of the PR and between 63ce44c and e825c31.

📒 Files selected for processing (1)
  • rl_engine/kernels/registry.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • rl_engine/kernels/registry.py

Walkthrough

Adds NativeRMSNormOp, a pure-PyTorch RMSNorm reference implementation, to rl_engine/kernels/ops/pytorch/norm/rms_norm.py. Registers it as OpBackend.PYTORCH_NATIVE_RMS_NORM in KernelRegistry with dispatch entries for cuda, rocm, and cpu. A new pytest module validates correctness, dtype routing, batch invariance, shape guards, purity, gradients, and registry dispatch.

Changes

NativeRMSNormOp: Implementation, Registry Wiring, and Tests

| Layer / File(s) | Summary |
| --- | --- |
| NativeRMSNormOp class and core `_rms_norm` math (rl_engine/kernels/ops/pytorch/norm/rms_norm.py) | Defines NativeRMSNormOp with `__call__`/forward (fp32 accumulation, casts output to x.dtype), forward_fp32 (forces float32 output), and static `_rms_norm` that validates weight shape, computes `rsqrt(mean(x²) + eps) * weight`, and casts to output_dtype. |
| OpBackend enum and KernelRegistry dispatch wiring (rl_engine/kernels/registry.py) | Adds PYTORCH_NATIVE_RMS_NORM to OpBackend with the NativeRMSNormOp import path, and extends `KernelRegistry._priority_map` with rms_norm entries for cuda, rocm, and cpu. |
| Test suite: correctness, dtype, guards, purity, gradients, registry (tests/test_rms_norm.py) | Validates NativeRMSNormOp against a manual fp32 reference for two normalized dimensions, batch/padding invariance (bitwise equality), dtype routing, bf16/fp16 tolerances, eps/zero-input finiteness, linear weight scaling, ValueError on bad weight shapes, input non-mutation, gradient finiteness, and kernel_registry dispatch. |
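
For readers unfamiliar with this kind of dispatch wiring, the following is a hypothetical sketch of an enum-plus-priority-map registry. Nothing here is taken from rl_engine/kernels/registry.py: the map layout (device, then op name, then an ordered backend list), the `resolve` helper, and the dotted import path used as the enum value are all assumptions made for illustration.

```python
from enum import Enum


class OpBackend(Enum):
    # Value is the import path of the implementing class (layout assumed).
    PYTORCH_NATIVE_RMS_NORM = "rl_engine.kernels.ops.pytorch.norm.rms_norm.NativeRMSNormOp"


# Hypothetical priority map: device -> op name -> backends in preference order.
PRIORITY_MAP = {
    device: {"rms_norm": [OpBackend.PYTORCH_NATIVE_RMS_NORM]}
    for device in ("cuda", "rocm", "cpu")
}


def resolve(op_name: str, device: str) -> OpBackend:
    """Return the highest-priority backend registered for (device, op_name)."""
    try:
        return PRIORITY_MAP[device][op_name][0]
    except (KeyError, IndexError):
        raise LookupError(f"no backend for op={op_name!r} on device={device!r}") from None


print(resolve("rms_norm", "cpu").name)  # prints PYTORCH_NATIVE_RMS_NORM
```

With a single pure-PyTorch backend registered, every device resolves to the same entry; later Triton, CUDA, or ROCm kernels would presumably be placed ahead of it in the per-device lists.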

Possibly Related Issues

Suggested Reviewers

  • inaniloquentee
  • KJLdefeated
  • Flink-ddd

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🐇 I twitched my nose and found the norm,
In PyTorch wool, a cozy form.
Mean squares hum soft, then rsqrt sings,
Through CPU, CUDA, ROCm springs.
Hop hop hooray, the tests all pass! 🌿

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 23.81% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly summarizes the main change: a pure-PyTorch NativeRMSNormOp plus numerical contract tests. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/kernels/ops/pytorch/norm/rms_norm.py`:
- Around line 5-6: The file rms_norm.py has linting and formatting violations
from black, isort, and trailing-whitespace checks that are blocking CI. Run the
project's pre-commit formatting hooks (typically via a command like 'pre-commit
run --all-files' or 'black' and 'isort' individually) to automatically reformat
the file and fix signature spacing issues around lines 25-31, expression
formatting issues around lines 38-44 and 57-67, and any trailing whitespace
violations. Commit the reformatted file after the hooks complete.

In `@tests/test_rms_norm.py`:
- Around line 10-11: The test_rms_norm.py file has formatting inconsistencies
detected by Black. Run the Black formatter on this file to automatically fix
alignment and spacing issues in inline comments (like those on the _HIDDEN and
_HEAD_DIM constant definitions) and long assertions (around the test assertion
blocks). Apply Black's output and commit the formatted result to resolve the CI
formatting check.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 34ccd0aa-23c0-4f09-8d2a-e7d333cea8a2

📥 Commits

Reviewing files that changed from the base of the PR and between 9dfcdbc and 5396d27.

📒 Files selected for processing (3)
  • rl_engine/kernels/ops/pytorch/norm/rms_norm.py
  • rl_engine/kernels/registry.py
  • tests/test_rms_norm.py

Comment thread rl_engine/kernels/ops/pytorch/norm/rms_norm.py
Comment thread tests/test_rms_norm.py Outdated
Resolve CodeRabbit formatting findings on RL-Align#160: black (line-length=100),
isort (profile=black), trailing-whitespace and EOF fixes. No logic change;
16 tests still pass.
@Flink-ddd Flink-ddd added the labels platform: cuda (Specific optimizations or bugs in NVIDIA graphics cards, such as FlashInfer, TMA optimizations), priority: high (Severe congestion issues require the highest priority for resolution.), and sprint-0615 on Jun 21, 2026
@Flink-ddd Flink-ddd requested a review from EthanZero2Hero June 21, 2026 13:10

@Flink-ddd Flink-ddd left a comment

Collaborator

This is a solid, production-ready PR that strictly adheres to the #108 numerical contract. Here are a few professional suggestions for refinement before merging:

Comment thread tests/test_rms_norm.py
Comment thread rl_engine/kernels/ops/pytorch/norm/rms_norm.py

@KJLdefeated KJLdefeated left a comment

Collaborator

Implementation looks correct: standard RMSNorm, fp32 accumulation, eps inside the sqrt, plain weight scaling, clean shape guard. Registry wiring matches the existing attn pattern. Happy to approve once the requested changes are addressed.

Comment thread tests/test_rms_norm.py
…eck + backward batch-invariance

- KJLdefeated: add test_forward_fp32_matches_torch_reference comparing against
  PyTorch's F.rms_norm via assert_close (tolerance, not torch.equal, since its
  reduction order may differ) so a shared formula bug can't pass green. Keep the
  hand-written _manual_rms_norm test as a secondary bitwise sanity check.
- Flink-ddd: add test_backward_batch_invariance_slice proving input gradients are
  bitwise identical regardless of batch size (Axis-A for gradients, needed for RL-Align#153).
@maxiaosong1124

Collaborator Author

@KJLdefeated You're right — _manual_rms_norm uses the same expression and float order as the op, so torch.equal only proves they're copies of each other, not that the formula is correct; a shared bug (eps placement, wrong reduction dim) would pass green on both sides. Addressed in 64f6f56:

  • Added test_forward_fp32_matches_torch_reference as the primary correctness check, comparing forward_fp32 against PyTorch's own F.rms_norm. Used torch.testing.assert_close(rtol=1e-6, atol=1e-6) rather than torch.equal exactly because F.rms_norm is free to reduce in a different float order — that non-identical reduction is what makes it a real independent check.
  • Kept the hand-written _manual_rms_norm test as a secondary bitwise sanity check that pins the exact reference semantics.

Thanks for the careful catch!

@maxiaosong1124

Collaborator Author

> This is a solid, production-ready PR that strictly adheres to the #108 numerical contract. Here are a few professional suggestions for refinement before merging:

@Flink-ddd Good point — forward batch-invariance isn't enough; WS1 (and #153) needs the backward to be batch-invariant too, and test_gradient_flows only checked isfinite. Addressed in 64f6f56:

Added test_backward_batch_invariance_slice following your example — it runs the full-batch forward+backward, then a batch-of-1 recompute fed the matching slice of the upstream gradient, and asserts torch.equal(x_slice.grad, grad_x_full_sliced) so the input gradient is bitwise identical regardless of batch size.

On your other note (var = x_f.pow(2).mean(...) materializing a full fp32 tensor): left as-is for this PR, since you flagged it as no-change-required — readability over VRAM for the reference op. I've noted it as a watch-item for the downstream Triton/CUDA kernels on long-context workloads.

All 19 tests in tests/test_rms_norm.py pass and black --check is clean. Thanks for the review!
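
A minimal sketch of the backward batch-invariance check described in this reply, assuming a stand-in `rms_norm_fp32` helper, `eps=1e-6`, and small illustrative sizes (the real test lives in tests/test_rms_norm.py and may differ in detail):

```python
import torch


def rms_norm_fp32(x, weight, eps=1e-6):
    # Stand-in fp32 reference; the eps value is an assumption.
    x_f = x.float()
    inv_rms = torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x_f * inv_rms * weight.float()


gen = torch.Generator().manual_seed(0)
batch, dim, row = 8, 128, 3  # illustrative sizes
w = torch.randn(dim, generator=gen)
grad_out = torch.randn(batch, dim, generator=gen)  # upstream gradient

# Full-batch forward + backward.
x_full = torch.randn(batch, dim, generator=gen).requires_grad_(True)
rms_norm_fp32(x_full, w).backward(grad_out)

# Batch-of-1 recompute, fed the matching slice of the upstream gradient.
x_slice = x_full.detach()[row : row + 1].clone().requires_grad_(True)
rms_norm_fp32(x_slice, w).backward(grad_out[row : row + 1])

grads_match = torch.equal(x_slice.grad, x_full.grad[row : row + 1])
print(grads_match)
```

Only the input gradient is compared. The weight gradient sums over the batch, so it is batch-dependent by construction and is not part of this invariance.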

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/test_rms_norm.py`:
- Around line 144-159: The backward invariance test only covers the `_HIDDEN`
normalization width, so it should be expanded to validate both Qwen3 RMSNorm
widths. Update `test_backward_batch_invariance_slice` to run the same
forward/backward slice comparison for both `_HIDDEN` and `_HEAD_DIM` (for
example via parametrization), keeping the existing `NativeRMSNormOp`, `_rand`,
and gradient slice assertions unchanged.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5bc24ea0-96eb-4323-a0b7-59f424add2dc

📥 Commits

Reviewing files that changed from the base of the PR and between 6c50a87 and fb6cfc5.

📒 Files selected for processing (2)
  • rl_engine/kernels/registry.py
  • tests/test_rms_norm.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • rl_engine/kernels/registry.py

Comment thread tests/test_rms_norm.py

@Flink-ddd Flink-ddd left a comment

Collaborator

LGTM now. Thank you for the update.

@Flink-ddd

Collaborator

The GPU CI error is not related to this PR; the runner is out of GPU resources. But please resolve the CI-Pipeline error. Also cc @KJLdefeated, PTAL again.

@KJLdefeated KJLdefeated left a comment

Collaborator

LGTM. Approved.

@maxiaosong1124

Collaborator Author

> The GPU CI error is not related to this PR; the runner is out of GPU resources. But please resolve the CI-Pipeline error. Also cc @KJLdefeated, PTAL again.

@Flink-ddd Resolved !!

@Flink-ddd Flink-ddd merged commit 27e3695 into RL-Align:main Jun 30, 2026
4 of 5 checks passed
@Flink-ddd

Collaborator

Thank you, merged.

Labels

needs-gpu-ci, platform: cuda, priority: high, sprint-0615

3 participants