[WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd)#180
[WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd)#180Flink-ddd wants to merge 4 commits into
Conversation
📝 WalkthroughWalkthroughThis PR adds a deterministic GEMM CUDA path with SM90 TMA support, Python/C++ bindings, Triton and native reference backends, registry wiring, build flags, docs, benchmark artifacts, and CUDA tests covering batch invariance and backward correctness. ChangesDeterministic GEMM rollout
Sequence Diagram(s)sequenceDiagram
participant DetGemmOp
participant _DetGemmFn
participant det_gemm_fwd
participant gemm_dispatch
participant det_gemm_da
participant det_gemm_db
DetGemmOp->>_DetGemmFn: apply(a, b)
_DetGemmFn->>det_gemm_fwd: forward(a, b)
det_gemm_fwd->>gemm_dispatch: choose SM90 or naive
gemm_dispatch-->>det_gemm_fwd: C
_DetGemmFn-->>DetGemmOp: C
_DetGemmFn->>det_gemm_da: grad_out, b
_DetGemmFn->>det_gemm_db: a, grad_out
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related issues
Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Signed-off-by: vensen <vensenmu@gmail.com>
e797bc9 to
90d120c
Compare
There was a problem hiding this comment.
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
rl_engine/kernels/registry.py (1)
99-112: 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
det_gemmfalls back to the logp backend on CPU.CUDA and ROCm get explicit
det_gemmentries, but CPU does not. Becauseget_op()defaults missing op types toOpBackend.PYTORCH_NATIVE,get_op("det_gemm")on CPU will try to instantiateNativeLogpOpinstead of failing fast.Suggested fix
"cpu": { "logp": [OpBackend.PYTORCH_NATIVE], "attn": [OpBackend.PYTORCH_ATTN], "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.PYTORCH_RATIO_KL], + "det_gemm": [], },🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@rl_engine/kernels/registry.py` around lines 99 - 112, The `KernelRegistry.get_op()` fallback is incorrectly treating missing `det_gemm` as `OpBackend.PYTORCH_NATIVE`, which routes CPU requests to `NativeLogpOp` instead of failing or using a valid `det_gemm` backend. Update the registry/lookup logic so `det_gemm` is explicitly handled for CPU in `KernelRegistry` (or excluded from the generic native fallback), and ensure `get_op("det_gemm")` selects only a real `det_gemm` implementation or raises a clear unsupported-op error.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@benchmarks/benchmark_det_gemm.py`:
- Around line 52-63: The benchmark currently starts timing in run() without
checking that the CUDA backend binding is actually available, so a None handle
can still fail later inside deterministic_gemm as NoneType.det_gemm_fwd. Add an
explicit preflight check before the SHAPES loop to verify the CUDA
wrapper/binding used by native_gemm and deterministic_gemm is initialized, and
if not, exit or skip with a clear message instead of letting _time() trigger the
opaque failure. If Triton is part of the path, gate deterministic_gemm_triton
the same way so the benchmark only runs when the required backend symbols are
present.
In `@csrc/cuda/gemm/det_gemm_kernel.cu`:
- Around line 230-238: The gemm dispatch path currently assumes the active CUDA
device matches the input tensors and does not reject mixed-GPU inputs. Update
check_in() and gemm_dispatch() to verify both tensors are on CUDA, are bf16, and
reside on the same device, then add an explicit device guard in gemm_dispatch()
so the allocation of c and the kernel launch via
at::cuda::getCurrentCUDAStream() are pinned to the inputs’ device. Use the
existing gemm_dispatch, check_in, and stream setup to place the guard and
same-device validation close to the launch site.
In `@docs/operators/det-gemm.md`:
- Around line 28-29: The CUDA backend description in the DetGemm documentation
is outdated and still describes a naive FP32 milestone with deferred tensor-core
work, while the merged implementation already uses the SM90 TMA plus mma.sync
path. Update the wording in the DetGemm operator docs table and any related CUDA
backend references so they accurately reflect the implemented DetGemmOp behavior
and current performance profile, keeping the TritonDetGemmOp description
consistent with the same section.
In `@rl_engine/kernels/ops/cuda/matmul/det_gemm.py`:
- Around line 16-20: The functional deterministic GEMM path is bypassing the
compiled-extension availability guard and can hit _C.det_gemm_fwd when _C is
None. Update _DetGemmFn.forward and deterministic_gemm to perform the same
availability check used by DetGemmOp.__call__ before touching _C, and raise the
intended RuntimeError instead of allowing an AttributeError to surface.
In `@rl_engine/kernels/ops/triton/matmul/det_gemm.py`:
- Around line 101-105: The backward path in det_gemm.backward still computes dB
from Aᵀ @ grad_out, so the reduction depends on the live batch dimension and can
vary with chunking/padding. Update the db calculation to use the WS1 invariant
reduction contract instead of directly reducing over the current batch layout,
and ensure the Triton backward path produces the same dW regardless of batch
size or layout changes. Keep the fix localized around backward and the
_triton_gemm calls in det_gemm.py.
- Around line 68-90: The _triton_gemm helper currently derives K from a.shape
and launches _det_gemm_kernel without verifying that b.shape[0] matches, so add
an explicit inner-dimension check before the kernel call and raise the standard
matmul shape error on mismatch. Keep the validation in _triton_gemm near the
existing M, K, N shape extraction so invalid inputs are rejected before any
Triton launch.
In `@setup.py`:
- Around line 64-71: The SM90 build path still emits a conflicting plain SM90
gencode when KERNEL_ALIGN_DET_GEMM_SM90 is enabled, so update the nvcc flag
assembly in setup.py to avoid adding the compute_90/sm_90 pair on SM90 machines.
Make the logic in the startup flag block and the later SM90 append path
consistent so that enabling KERNEL_ALIGN_DET_GEMM_SM90 alone does not require
also setting KERNEL_ALIGN_FORCE_SM90. Use the existing nvcc_flags construction
and the SM90-specific branch to ensure only the intended 90a gencode is present.
In `@tests/test_det_gemm.py`:
- Around line 25-33: The CUDA backend is being added to _BACKENDS even when
deterministic_gemm is not actually available, which causes crashes before the
assertions run. Update the backend selection near pytestmark/_BACKENDS to check
the wrapper’s real availability signal from deterministic_gemm before appending
the CUDA case, or convert it into a clear module-level skip when the extension
binding is None. Keep the Triton entry gated separately with _HAS_TRITON.
- Around line 78-85: The correctness assertions in test_forward_correctness and
the related tests still use temporary loose thresholds instead of the shared
`#108` contract. Update the checks in the gemm test cases to use the `#108`
correctness harness/thresholds rather than hardcoded placeholder max_abs limits,
and make sure the affected test functions (including test_forward_correctness
and the nearby cases in the same range) validate against the contract
consistently.
---
Outside diff comments:
In `@rl_engine/kernels/registry.py`:
- Around line 99-112: The `KernelRegistry.get_op()` fallback is incorrectly
treating missing `det_gemm` as `OpBackend.PYTORCH_NATIVE`, which routes CPU
requests to `NativeLogpOp` instead of failing or using a valid `det_gemm`
backend. Update the registry/lookup logic so `det_gemm` is explicitly handled
for CPU in `KernelRegistry` (or excluded from the generic native fallback), and
ensure `get_op("det_gemm")` selects only a real `det_gemm` implementation or
raises a clear unsupported-op error.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 0de95808-0f0e-49ba-ba17-f216e838e757
📒 Files selected for processing (17)
benchmarks/benchmark_det_gemm.pybenchmarks/results/det_gemm_h100_tma.mdcsrc/cuda/gemm/det_gemm_kernel.cucsrc/cuda/gemm/det_gemm_tma.cuhcsrc/ops.cppdocs/.nav.ymldocs/operators/det-gemm.mdrl_engine/kernels/ops/cuda/__init__.pyrl_engine/kernels/ops/cuda/matmul/__init__.pyrl_engine/kernels/ops/cuda/matmul/det_gemm.pyrl_engine/kernels/ops/pytorch/matmul/__init__.pyrl_engine/kernels/ops/pytorch/matmul/det_gemm.pyrl_engine/kernels/ops/triton/matmul/__init__.pyrl_engine/kernels/ops/triton/matmul/det_gemm.pyrl_engine/kernels/registry.pysetup.pytests/test_det_gemm.py
Fixes #146
Purpose
In RL train–inference alignment, a token's logprob must not drift between rollout (vLLM-side, varying batch shapes) and training (Megatron-side). A major source of this drift is the GEMM: cuBLAS selects kernels by problem shape and may use split-K, so the K-reduction order — and therefore the low bits of the output changes with batch size, chunked-prefill splitting, and padding layout. The same row produces different results depending on the batch it rides in.
This PR implements a batch-invariant deterministic GEMM (forward + backward), one op in the WS1 forward chain (#146). A row's output is bitwise invariant to batch size, chunked-prefill splitting, and padding layout. Invariance is achieved by pinning the tile shape, fixing the K-accumulation order, and forbidding split-K / shape-based kernel selection. BF16 inputs, FP32 accumulation, no TF32.
This is PR2 of the planned series (design note → kernel → tests → LM-head wiring → benchmark). Scope here: the kernel(s), op wiring, and invariance tests.
Invariance contract. The same (M, N, K) row produces the same output regardless of the surrounding batch, chunked-prefill split, or padding. Two design points protect this:
Backends. CUDA (SM90 TMA + mma.sync tensor cores; naive FP32 scalar fallback below SM90), Triton (portable / ROCm fallback + cross-backend invariance eference), and a PyTorch reference (torch.matmul, intentionally non-deterministic — reference/benchmark only, deliberately excluded from dispatch, since a non-deterministic fallback would silently defeat the op's purpose).
Out of scope (per #146): tensor-parallel GEMM (WS2), FP8, and occupancy/throughput tuning of the tensor-core path — correctness and invariance first; a slower-than-cuBLAS deterministic baseline is the accepted first milestone. Throughput tuning will be a separate perf PR (see Follow-ups).
Test Plan
on both the CUDA tensor-core kernel and the Triton path.
Test Result
Build: compiled and linked successfully with the SM90 TMA + mma.sync path enabled (CUDA 12.4, H100). det_gemm uses its own csrc/cuda/gemm/det_gemm_tma.cuh (shared::cluster); the shared csrc/utils/tma_utils.cuh emits shared::cta, which CUDA 12.4 ptxas rejects for cp.async.bulk.tensor — scoped the fix to this PR rather than touching the shared logp helpers (tracked separately in #). The
KERNEL_ALIGN_DET_GEMM_SM90build flag is independent of the fused_logp SM90 sources, which are left untouched.Tests: pytest tests/test_det_gemm.py → 22/22 passed on H100 SXM (CUDA 12.4) — forward/backward batch-invariance, chunked-prefill,
padding-invariance (all bitwise), correctness vs FP32, and all 5 target
projection shapes, on both the CUDA tensor-core kernel and Triton.
Benchmark (NVIDIA H100 80GB HBM3, SM90): overhead = det CUDA vs cuBLAS (TF32 disabled). Both deterministic paths trade speed for bitwise invariance by design (no split-K, fixed accumulation, FP32, no TF32).
(The det CUDA path uses SM90 TMA +
mma.synctensor cores, 128×128 tile, single-CTA-per-tile, no split-K. Occupancy/throughput tuning is deferred per #146 and will be an ncu-driven perf PR; this milestone is correctness- and invariance-first.)Files
Follow-ups
Summary by CodeRabbit
New Features
Bug Fixes