-
Notifications
You must be signed in to change notification settings - Fork 42
feat(ws1): NativeRMSNormOp pure-PyTorch ground-truth reference + numerical contract tests #160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Flink-ddd
merged 9 commits into
RL-Align:main
from
maxiaosong1124:feat/ws1-rms_norm-pytorch-op
Jun 30, 2026
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
5396d27
feat(ws1): add NativeRMSNormOp pure-PyTorch reference + tests
maxiaosong1124 6c50a87
style(ws1): apply black/isort formatting to RMSNorm op + tests
maxiaosong1124 f61667e
Merge branch 'RL-Align:main' into feat/ws1-rms_norm-pytorch-op
maxiaosong1124 5078762
Merge branch 'RL-Align:main' into feat/ws1-rms_norm-pytorch-op
maxiaosong1124 64e4594
Merge branch 'RL-Align:main' into feat/ws1-rms_norm-pytorch-op
maxiaosong1124 64f6f56
test(ws1): address PR #160 review — independent F.rms_norm check + ba…
maxiaosong1124 fb6cfc5
Merge branch 'main' into feat/ws1-rms_norm-pytorch-op
maxiaosong1124 63ce44c
Merge branch 'main' into feat/ws1-rms_norm-pytorch-op
maxiaosong1124 e825c31
fix(lint): add space after # in registry.py comment (E265)
maxiaosong1124 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2026 RL-Kernel Contributors | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| class NativeRMSNormOp: | ||
| """ | ||
| Pure Pytorch native RMSNorm reference | ||
| out = x * rsqrt(mean(x^2, dim=-1) + eps) * weight | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| pass | ||
|
|
||
| def __call__( | ||
| self, | ||
| x: torch.Tensor, | ||
| weight: torch.Tensor, | ||
| *, | ||
| eps: float = 1e-6, | ||
| ) -> torch.Tensor: | ||
| return self.forward(x, weight, eps=eps) | ||
|
|
||
| def forward( | ||
| self, | ||
| x: torch.Tensor, | ||
| weight: torch.Tensor, | ||
| *, | ||
| eps: float = 1e-6, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Canonical entry: accumulate in fp32, cast the result back to x.dtype. | ||
| This is the dtype-behavior path used as the Axis-B accuracy candidate. | ||
| """ | ||
| return self._rms_norm(x, weight, eps=eps, output_dtype=x.dtype) | ||
|
|
||
| def forward_fp32( | ||
| self, | ||
| x: torch.Tensor, | ||
| weight: torch.Tensor, | ||
| *, | ||
| eps: float = 1e-6, | ||
| ) -> torch.Tensor: | ||
| """Ground-truth: accumulate in fp32 and force fp32 output.""" | ||
| return self._rms_norm(x, weight, eps=eps, output_dtype=torch.float32) | ||
|
|
||
| # ------------------------------------------------------------------ # | ||
| # Helpers | ||
| # ------------------------------------------------------------------ # | ||
| @staticmethod | ||
| def _rms_norm( | ||
| x: torch.Tensor, | ||
| weight: torch.Tensor, | ||
| *, | ||
| eps: float, | ||
| output_dtype: torch.dtype, | ||
| ) -> torch.Tensor: | ||
| if weight.dim() != 1 or weight.shape[0] != x.shape[-1]: | ||
| raise ValueError( | ||
| f"weight must be 1-D of size x.shape[-1]={x.shape[-1]}, " | ||
| f"got tuple(weight.shape)={tuple(weight.shape)}" | ||
| ) | ||
| x_f = x.float() | ||
| var = x_f.pow(2).mean(dim=-1, keepdim=True) | ||
|
Flink-ddd marked this conversation as resolved.
|
||
| normed = x_f * torch.rsqrt(var + eps) | ||
| out = normed * weight.float() | ||
| return out.to(output_dtype) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,168 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2026 RL-Kernel Contributors | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| from rl_engine.kernels.ops.pytorch.norm.rms_norm import NativeRMSNormOp | ||
|
|
||
| # Qwen3-8B normalized dims this op must cover. | ||
| _HIDDEN = 4096 # input / post-attention norm | ||
| _HEAD_DIM = 128 # QK-Norm (per-head RMSNorm on Q and K) | ||
| _EPS = 1e-6 | ||
|
|
||
|
|
||
| # Shared helpers | ||
| def _rand(shape, *, seed, dtype=torch.float32): | ||
| gen = torch.Generator().manual_seed(seed) | ||
| return torch.randn(*shape, generator=gen, dtype=dtype) | ||
|
|
||
|
|
||
| def _manual_rms_norm(x, weight, *, eps=_EPS): | ||
| """Independent hand-written fp32 reference (NOT the op under test).""" | ||
| x_f = x.float() | ||
| var = x_f.pow(2).mean(dim=-1, keepdim=True) | ||
| return x_f * torch.rsqrt(var + eps) * weight.float() | ||
|
|
||
|
|
||
| # 1. Primary correctness check vs PyTorch's own F.rms_norm. This is a *truly* | ||
| # independent implementation -- it may reduce in a different float order than | ||
| # our op, so a shared formula bug (eps placement, wrong reduction dim) cannot | ||
| # hide here. Tolerance-based (assert_close), NOT torch.equal, precisely because | ||
| # the reduction order is allowed to differ. | ||
| @pytest.mark.parametrize("N", [_HIDDEN, _HEAD_DIM]) | ||
| def test_forward_fp32_matches_torch_reference(N): | ||
| op = NativeRMSNormOp() | ||
| x, w = _rand((2, 16, N), seed=0), _rand((N,), seed=1) | ||
| ref = F.rms_norm(x.float(), (N,), weight=w.float(), eps=_EPS) | ||
| torch.testing.assert_close(op.forward_fp32(x, w), ref, rtol=1e-6, atol=1e-6) | ||
|
|
||
|
|
||
| # 1b. Secondary sanity check vs a hand-written fp32 formula in the same float | ||
| # order -> bitwise equal. Pins the exact reference semantics; the F.rms_norm | ||
| # test above is the independent guard against a formula bug. | ||
| @pytest.mark.parametrize("N", [_HIDDEN, _HEAD_DIM]) | ||
| def test_forward_fp32_matches_manual_reference(N): | ||
| op = NativeRMSNormOp() | ||
| x, w = _rand((2, 16, N), seed=0), _rand((N,), seed=1) | ||
| assert torch.equal(op.forward_fp32(x, w), _manual_rms_norm(x, w)) | ||
|
Flink-ddd marked this conversation as resolved.
|
||
|
|
||
|
|
||
| # 2. Axis A -- batch invariance, bitwise (the WS1 "aligned" property) | ||
| @pytest.mark.parametrize("N", [_HIDDEN, _HEAD_DIM]) | ||
| def test_batch_invariance_slice(N): | ||
| """A row's output must not depend on how many rows share the batch.""" | ||
| op = NativeRMSNormOp() | ||
| w, x = _rand((N,), seed=1), _rand((8, 32, N), seed=2) | ||
| full = op.forward_fp32(x, w) # compute on full batch... | ||
| assert torch.equal(op.forward_fp32(x[:1], w), full[:1]) # ...then slice | ||
| assert torch.equal(op.forward_fp32(x[3:5], w), full[3:5]) | ||
|
|
||
|
|
||
| def test_batch_invariance_with_padding(): | ||
| """Padding extra rows must not perturb the real rows (bitwise).""" | ||
| op = NativeRMSNormOp() | ||
| w = _rand((_HIDDEN,), seed=1) | ||
| x = _rand((4, _HIDDEN), seed=3) | ||
| padded = torch.cat([x, _rand((6, _HIDDEN), seed=99)], dim=0) | ||
| assert torch.equal(op.forward_fp32(padded, w)[:4], op.forward_fp32(x, w)) | ||
|
|
||
|
|
||
| # 3. dtype behavior -- forward follows input, forward_fp32 forces fp32 | ||
| @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) | ||
| def test_dtype_paths(dtype): | ||
| op = NativeRMSNormOp() | ||
| x = _rand((2, 16, _HIDDEN), seed=4).to(dtype) | ||
| w = _rand((_HIDDEN,), seed=5).to(dtype) | ||
| assert op.forward(x, w).dtype == dtype | ||
| assert op.forward_fp32(x, w).dtype == torch.float32 | ||
|
|
||
|
|
||
| # 4. Axis B -- low-precision forward stays within tolerance of fp32 reference | ||
| @pytest.mark.parametrize( | ||
| "dtype, atol, rtol", | ||
| [(torch.bfloat16, 2e-2, 1.6e-2), (torch.float16, 1e-3, 1e-3)], | ||
| ) | ||
| def test_low_precision_within_tolerance(dtype, atol, rtol): | ||
| op = NativeRMSNormOp() | ||
| x, w = _rand((4, 64, _HIDDEN), seed=6), _rand((_HIDDEN,), seed=7) | ||
| ref = op.forward_fp32(x, w) | ||
| got = op.forward(x.to(dtype), w.to(dtype)).float() | ||
| assert torch.allclose(got, ref, atol=atol, rtol=rtol) | ||
|
|
||
|
|
||
| # 5. eps lives INSIDE the sqrt: zero input -> finite (zero) output | ||
| def test_eps_inside_sqrt(): | ||
| op = NativeRMSNormOp() | ||
| out = op.forward_fp32(torch.zeros(1, _HIDDEN), torch.ones(_HIDDEN)) | ||
| assert torch.isfinite(out).all() and torch.equal(out, torch.zeros(1, _HIDDEN)) | ||
|
|
||
|
|
||
| # 6. Plain weight scaling, NOT the (1 + weight) variant | ||
| def test_weight_scaling_no_plus_one(): | ||
| op = NativeRMSNormOp() | ||
| x = _rand((2, _HEAD_DIM), seed=8) | ||
| base = op.forward_fp32(x, torch.ones(_HEAD_DIM)) | ||
| doubled = op.forward_fp32(x, torch.full((_HEAD_DIM,), 2.0)) | ||
| assert torch.allclose(doubled, 2.0 * base, atol=1e-5) | ||
|
|
||
|
|
||
| # 7. Shape guard fires | ||
| def test_bad_weight_shape_raises(): | ||
| op = NativeRMSNormOp() | ||
| x = _rand((2, _HIDDEN), seed=9) | ||
| with pytest.raises(ValueError): | ||
| op.forward_fp32(x, _rand((_HEAD_DIM,), seed=10)) # 128 != 4096 | ||
| with pytest.raises(ValueError): | ||
| op.forward_fp32(x, _rand((1, _HIDDEN), seed=10)) # not 1-D | ||
|
|
||
|
|
||
| # 8. Purity -- inputs not mutated in-place | ||
| def test_inputs_not_mutated(): | ||
| op = NativeRMSNormOp() | ||
| x, w = _rand((2, _HIDDEN), seed=11), _rand((_HIDDEN,), seed=12) | ||
| xc, wc = x.clone(), w.clone() | ||
| op.forward(x, w) | ||
| op.forward_fp32(x, w) | ||
| assert torch.equal(x, xc) and torch.equal(w, wc) | ||
|
|
||
|
|
||
| # 9. Gradient flows (fp32 autograd = backward golden source) | ||
| def test_gradient_flows(): | ||
| op = NativeRMSNormOp() | ||
| x = _rand((2, _HIDDEN), seed=13).requires_grad_(True) | ||
| w = _rand((_HIDDEN,), seed=14).requires_grad_(True) | ||
| op.forward_fp32(x, w).sum().backward() | ||
| assert torch.isfinite(x.grad).all() and torch.isfinite(w.grad).all() | ||
|
|
||
|
Flink-ddd marked this conversation as resolved.
|
||
|
|
||
| # 9b. Axis A for gradients -- backward must be batch-invariant too (needed for | ||
| # #153). Slicing the batch must yield bitwise-identical input gradients to the | ||
| # full-batch backward. Compute on the full batch, then compare against a | ||
| # batch-of-1 recompute fed the matching slice of the upstream gradient. | ||
| def test_backward_batch_invariance_slice(): | ||
| op = NativeRMSNormOp() | ||
|
|
||
| w_full = _rand((_HIDDEN,), seed=1).requires_grad_(True) | ||
| x_full = _rand((8, 32, _HIDDEN), seed=2).requires_grad_(True) | ||
| out_full = op.forward_fp32(x_full, w_full) | ||
| dy_full = _rand(out_full.shape, seed=3) | ||
| out_full.backward(dy_full) | ||
| grad_x_full_sliced = x_full.grad[:1].clone() | ||
|
|
||
| w_slice = _rand((_HIDDEN,), seed=1).requires_grad_(True) | ||
| x_slice = _rand((8, 32, _HIDDEN), seed=2)[:1].detach().requires_grad_(True) | ||
| out_slice = op.forward_fp32(x_slice, w_slice) | ||
| out_slice.backward(dy_full[:1]) # matching slice of the upstream gradient | ||
|
|
||
| assert torch.equal(x_slice.grad, grad_x_full_sliced) | ||
|
Flink-ddd marked this conversation as resolved.
|
||
|
|
||
|
|
||
| # 10. Registry dispatch resolves to the native op | ||
| def test_registry_dispatches_rms_norm(): | ||
| from rl_engine.kernels.registry import kernel_registry | ||
|
|
||
| op = kernel_registry.get_op("rms_norm") | ||
| assert isinstance(op, NativeRMSNormOp) | ||
| assert hasattr(op, "forward") and hasattr(op, "forward_fp32") | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.