feat(ws1): NativeRMSNormOp pure-PyTorch ground-truth reference + numerical contract tests#160
Conversation
- NativeRMSNormOp with forward / forward_fp32 (fp32 ground-truth path) - covers both normalized dims: hidden=4096 and head_dim=128 (Qwen3 QK-Norm) - register PYTORCH_NATIVE_RMS_NORM in OpBackend + cpu/cuda/rocm priority map - tests/test_rms_norm.py: axis-A bitwise batch invariance + dtype tolerance, shape guard, purity, gradient flow, registry dispatch (16 tests) Refs RL-Align#108
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 WalkthroughWalkthroughAdds ChangesNativeRMSNormOp: Implementation, Registry Wiring, and Tests
Possibly Related Issues
Suggested Reviewers
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/kernels/ops/pytorch/norm/rms_norm.py`:
- Around line 5-6: The file rms_norm.py has linting and formatting violations
from black, isort, and trailing-whitespace checks that are blocking CI. Run the
project's pre-commit formatting hooks (typically via a command like 'pre-commit
run --all-files' or 'black' and 'isort' individually) to automatically reformat
the file and fix signature spacing issues around lines 25-31, expression
formatting issues around lines 38-44 and 57-67, and any trailing whitespace
violations. Commit the reformatted file after the hooks complete.
In `@tests/test_rms_norm.py`:
- Around line 10-11: The test_rms_norm.py file has formatting inconsistencies
detected by Black. Run the Black formatter on this file to automatically fix
alignment and spacing issues in inline comments (like those on the _HIDDEN and
_HEAD_DIM constant definitions) and long assertions (around the test assertion
blocks). Apply Black's output and commit the formatted result to resolve the CI
formatting check.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 34ccd0aa-23c0-4f09-8d2a-e7d333cea8a2
📒 Files selected for processing (3)
rl_engine/kernels/ops/pytorch/norm/rms_norm.pyrl_engine/kernels/registry.pytests/test_rms_norm.py
Resolve CodeRabbit formatting findings on RL-Align#160: black (line-length=100), isort (profile=black), trailing-whitespace and EOF fixes. No logic change; 16 tests still pass.
KJLdefeated
left a comment
There was a problem hiding this comment.
Implementation looks correct — standard RMSNorm, fp32 accumulation, eps inside the sqrt, plain weight scaling, clean shape guard. Registry wiring matches the existing attn pattern. Happy to approve once request changes are addressed.
…eck + backward batch-invariance - KJLdefeated: add test_forward_fp32_matches_torch_reference comparing against PyTorch's F.rms_norm via assert_close (tolerance, not torch.equal, since its reduction order may differ) so a shared formula bug can't pass green. Keep the hand-written _manual_rms_norm test as a secondary bitwise sanity check. - Flink-ddd: add test_backward_batch_invariance_slice proving input gradients are bitwise identical regardless of batch size (Axis-A for gradients, needed for RL-Align#153).
|
@KJLdefeated You're right — _manual_rms_norm uses the same expression and float order as the op, so torch.equal only proves they're copies of each other, not that the formula is correct; a shared bug (eps placement, wrong reduction dim) would pass green on both sides. Addressed in 64f6f56:
Thanks for the careful catch! |
@Flink-ddd Good point — forward batch-invariance isn't enough; WS1 (and #153) needs the backward to be batch-invariant too, and test_gradient_flows only checked isfinite. Addressed in 64f6f56: Added test_backward_batch_invariance_slice following your example — it runs the full-batch forward+backward, then a batch-of-1 recompute fed the matching slice of the upstream gradient, and asserts torch.equal(x_slice.grad, grad_x_full_sliced) so the input gradient is bitwise identical regardless of batch size. On your other note (var = x_f.pow(2).mean(...) materializing a full fp32 tensor): left as-is for this PR, since you flagged it as no-change-required — readability over VRAM for the reference op. I've noted it as a watch-item for the downstream Triton/CUDA kernels on long-context workloads. All 19 tests in tests/test_rms_norm.py pass and black --check is clean. Thanks for the review! |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/test_rms_norm.py`:
- Around line 144-159: The backward invariance test only covers the `_HIDDEN`
normalization width, so it should be expanded to validate both Qwen3 RMSNorm
widths. Update `test_backward_batch_invariance_slice` to run the same
forward/backward slice comparison for both `_HIDDEN` and `_HEAD_DIM` (for
example via parametrization), keeping the existing `NativeRMSNormOp`, `_rand`,
and gradient slice assertions unchanged.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5bc24ea0-96eb-4323-a0b7-59f424add2dc
📒 Files selected for processing (2)
rl_engine/kernels/registry.pytests/test_rms_norm.py
🚧 Files skipped from review as they are similar to previous changes (1)
- rl_engine/kernels/registry.py
Flink-ddd
left a comment
There was a problem hiding this comment.
LGTM now, Thank you for update.
|
the GPU CI error is not releate this PR, it's out of GPU resource. but please resolve CI-Pipeline error. also cc @KJLdefeated PTAL again. |
@Flink-ddd Resolved !! |
|
Thank you, merged. |
Summary
Adds the pure-PyTorch ground-truth reference op for RMSNorm (pre-norm / QK-Norm)
as the first WS1 batch-invariant operator built on top of the numerical contract
defined in #108. Ships the op, its registry wiring, and a 16-case test suite that
pins down both alignment axes (Axis-A bitwise batch invariance, Axis-B per-dtype
tolerance).
Refs #108
Terminology
This PR uses the WS1 alignment vocabulary from #108:
how many rows share the batch (batch size, slicing, padding). Asserted bitwise
(
torch.equal). This is what keeps train-time (large batch) and sample-time(small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
documented per-dtype tolerance of the fp32 ground-truth. Asserted with
torch.allcloseMotivation / Context
#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. RMSNorm is required on two normalized dims of the
target model (Qwen3-8B dense):
hidden = 4096— input / post-attention normhead_dim = 128— QK-Norm (per-head RMSNorm on Q and K)This PR provides the deterministic fp32 reference path those downstream kernels
(Triton / CUDA / ROCm RMSNorm) will be validated against.
Changes
rl_engine/kernels/ops/pytorch/norm/rms_norm.py—NativeRMSNormOpforward()— accumulate in fp32, cast result back tox.dtype(Axis-B candidate path)forward_fp32()— fp32 accumulation, forced fp32 output (ground-truth / backward golden source)out = x * rsqrt(mean(x^2, dim=-1) + eps) * weightepslives inside the sqrt; plain weight scaling (not the1 + weightvariant)weightmust be 1-D of sizex.shape[-1]rl_engine/kernels/registry.py— registerPYTORCH_NATIVE_RMS_NORMand add
rms_normdispatch to the cuda / rocm / cpu priority mapstests/test_rms_norm.py— 16 tests (details below)How this satisfies the #108 contract
forward_fp32()accumulates in fp32 alongdim=-1; tests use fixed-seedtorch.Generatorso outputs are reproducibletorch.equal); Axis-B asserted within documented per-dtype thresholds — bf16atol=2e-2, rtol=1.6e-2, fp16atol=1e-3, rtol=1e-3hidden=4096andhead_dim=128Test Environment
Testing
Run from the repo root with
python -m pytest(the-mform puts the repo on