Skip to content

add fused_qknorm hip kernel#2442

Merged
valarLip merged 9 commits intoROCm:mainfrom
XiaobingSuper:xiaobing/fuse_qk_norm
Mar 25, 2026
Merged

add fused_qknorm hip kernel#2442
valarLip merged 9 commits intoROCm:mainfrom
XiaobingSuper:xiaobing/fuse_qk_norm

Conversation

@XiaobingSuper
Copy link
Contributor

Motivation

Add a fused_qknorm hip kernel to improve kimi-2 model performance/

Technical Details

Fused two rmsnorm to a fused norm for a smaller number of tokens, which can improve decode performance/

Test Plan

python test_fused_qk_norm.py -d bf16 -m 2 4 8 16 32 64 128 256 512 1024 2048 4096 8192 16384 32768 65536 -n1 1536 -n2 512

Test Result

image

Submission Checklist

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds a new HIP fused Q/K RMSNorm kernel and wires it through the ROCm extension + Python API to improve decode performance for Kimi-2–like shapes (small token counts), along with a new benchmark script.

Changes:

  • Add fused_qk_rmsnorm HIP kernel (2D grid, Q and K computed in parallel).
  • Expose the kernel via pybind and a Python wrapper with a conditional fallback path.
  • Update JIT build config for the existing fused QK norm/RoPE/cache module and add a new op benchmark.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
op_tests/test_fused_qk_norm.py Adds a benchmark/accuracy script for fused vs split RMSNorm on Q/K.
csrc/kernels/fused_qk_norm.cu Implements the new fused QK RMSNorm HIP kernel + C++ entrypoint.
csrc/include/rocm_ops.hpp Exposes fused_qk_rmsnorm via pybind in the existing fused QK norm/RoPE/cache module macro.
csrc/include/fused_qk_norm_rope_cache_quant.h Declares the new fused_qk_rmsnorm API for pybind visibility.
aiter/ops/fused_qk_norm_rope_cache_quant.py Adds the Python wrapper and compile_ops binding for fused_qk_rmsnorm.
aiter/jit/optCompilerConfig.json Extends the module’s sources/includes/flags to compile and link the new kernel.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@XiaobingSuper XiaobingSuper force-pushed the xiaobing/fuse_qk_norm branch from ab05d72 to 93da489 Compare March 24, 2026 07:14
@gbyu-amd gbyu-amd requested review from ganyi1996ppo and xytpai March 24, 2026 07:23
@wuhuikx wuhuikx requested a review from valarLip March 24, 2026 07:55
@valarLip
Copy link
Collaborator

nice, could you please help finish this refactor ROCm/ATOM#342

@XiaobingSuper
Copy link
Contributor Author

nice, could you please help finish this refactor ROCm/ATOM#342

Yes, I will add a PR in ATOM side to add this fused_op based on [ROCm/ATOM#342].

wuhuikx
wuhuikx previously approved these changes Mar 25, 2026
int q_in_stride = q.stride(0);
int k_in_stride = k.stride(0);

auto q_out = torch::empty({m, q_n}, q.options());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe expose output to user end, so we can have stride tensor for some opt cases

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added output option.

@valarLip valarLip merged commit 149a8f6 into ROCm:main Mar 25, 2026
23 of 24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants