Conversation
There was a problem hiding this comment.
Pull request overview
This PR adds a new HIP fused Q/K RMSNorm kernel and wires it through the ROCm extension + Python API to improve decode performance for Kimi-2–like shapes (small token counts), along with a new benchmark script.
Changes:
- Add
fused_qk_rmsnormHIP kernel (2D grid, Q and K computed in parallel). - Expose the kernel via pybind and a Python wrapper with a conditional fallback path.
- Update JIT build config for the existing fused QK norm/RoPE/cache module and add a new op benchmark.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| op_tests/test_fused_qk_norm.py | Adds a benchmark/accuracy script for fused vs split RMSNorm on Q/K. |
| csrc/kernels/fused_qk_norm.cu | Implements the new fused QK RMSNorm HIP kernel + C++ entrypoint. |
| csrc/include/rocm_ops.hpp | Exposes fused_qk_rmsnorm via pybind in the existing fused QK norm/RoPE/cache module macro. |
| csrc/include/fused_qk_norm_rope_cache_quant.h | Declares the new fused_qk_rmsnorm API for pybind visibility. |
| aiter/ops/fused_qk_norm_rope_cache_quant.py | Adds the Python wrapper and compile_ops binding for fused_qk_rmsnorm. |
| aiter/jit/optCompilerConfig.json | Extends the module’s sources/includes/flags to compile and link the new kernel. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
ab05d72 to
93da489
Compare
|
nice, could you please help finish this refactor ROCm/ATOM#342 |
Yes, I will add a PR in ATOM side to add this fused_op based on [ROCm/ATOM#342]. |
csrc/kernels/fused_qk_norm.cu
Outdated
| int q_in_stride = q.stride(0); | ||
| int k_in_stride = k.stride(0); | ||
|
|
||
| auto q_out = torch::empty({m, q_n}, q.options()); |
There was a problem hiding this comment.
maybe expose output to user end, so we can have stride tensor for some opt cases
There was a problem hiding this comment.
Added output option.
91be957 to
19ac005
Compare
Motivation
Add a fused_qknorm hip kernel to improve kimi-2 model performance/
Technical Details
Fused two rmsnorm to a fused norm for a smaller number of tokens, which can improve decode performance/
Test Plan
Test Result
Submission Checklist