Implement device grouped gemm fixed nk multi abd for rdna4 #3619
base: develop
Conversation
Pull request overview
This PR adds support for grouped GEMM with multiple ABD tensors and fixed NK on the RDNA4 architecture, specifically for WMMA implementations. The feature was previously available only for XDL implementations.
Changes:
- Added WMMA device operator implementations for grouped GEMM multi ABD with fixed NK
- Unit tests for both new WMMA and existing XDL implementations
- Reference implementation class for verification
- Example code demonstrating WMMA usage patterns
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp | Unit test framework for validating grouped GEMM multi ABD fixed NK implementations |
| test/grouped_gemm/CMakeLists.txt | Build configuration for new unit test |
| profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp | Generic profiler interface for calling unit tests and benchmarking |
| profiler/include/profiler/profile_gemm_multi_abd_impl.hpp | Refactored to use new reference implementation |
| library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp | Commented out failing XDL instances |
| library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp | Commented out failing XDL instances |
| library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp | WMMA instances for MK-NK-MN layout with bias/gelu operations |
| library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp | WMMA instances for MK-KN-MN layout with bias/gelu operations |
| library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp | WMMA instances for KM-KN-MN layout with bias/gelu operations |
| library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt | Build configuration for new WMMA instances |
| library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp | Factory functions for WMMA instances |
| library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp | Reference implementation for grouped GEMM multi ABD verification |
| include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp | Added EDataType_ alias for type access |
| include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | Added hardware support checks and main K block loop validation |
| include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | New WMMA device operator implementation |
| example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp | Example using WMMA with FP16 and bias addition |
| example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp | Example using WMMA with BF16/I8 and bias+GELU |
| example/59_grouped_gemm_multi_ABD/CMakeLists.txt | Build configuration for WMMA examples |
```cpp
int main(int argc, char** argv)
{
    testing::InitGoogleTest(&argc, argv);
    if(argc == 1) {}
```
Copilot AI, Jan 20, 2026
Empty conditional block serves no purpose. Remove this branch or add a comment explaining why argc == 1 is explicitly handled (e.g., "use default parameters").
Suggested change:
```diff
-    if(argc == 1) {}
+    if(argc == 1)
+    {
+        // use default parameters when no extra arguments are provided
+    }
```
```cpp
auto ref_invoker = ref_gemm.MakeInvoker();

auto ref_argument =
    ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
```
Copilot AI, Jan 20, 2026
The reference GEMM arguments use variables a_m_k, b_k_n, and c_m_n that are no longer defined in this scope after the refactoring. These variables were computed within the removed reference computation code and need to be generated by the new ReferenceGemmMultiABD class.
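For context, the `MakeArgument` / `MakeInvoker` calling convention seen in the snippet can be illustrated with a self-contained stand-in. `ReferenceGemmSketch` below is a hypothetical simplification for illustration, not the actual `ReferenceGemmMultiABD` class; it only mirrors the Argument/Invoker pattern (the real class additionally handles the multiple A/B/D tensors and elementwise operators).

```cpp
#include <cassert>
#include <vector>

// Simplified Argument/Invoker pattern: MakeArgument captures the tensors,
// MakeInvoker returns an object whose Run executes the computation.
struct ReferenceGemmSketch
{
    struct Argument
    {
        const std::vector<float>& a; // M x K, row-major
        const std::vector<float>& b; // K x N, row-major
        std::vector<float>& c;       // M x N, row-major (output)
        int M, N, K;
    };

    struct Invoker
    {
        void Run(const Argument& arg) const
        {
            // Plain triple-loop C = A * B on the host.
            for(int m = 0; m < arg.M; ++m)
                for(int n = 0; n < arg.N; ++n)
                {
                    float acc = 0.f;
                    for(int k = 0; k < arg.K; ++k)
                        acc += arg.a[m * arg.K + k] * arg.b[k * arg.N + n];
                    arg.c[m * arg.N + n] = acc;
                }
        }
    };

    static Argument MakeArgument(const std::vector<float>& a,
                                 const std::vector<float>& b,
                                 std::vector<float>& c,
                                 int M,
                                 int N,
                                 int K)
    {
        return Argument{a, b, c, M, N, K};
    }

    static Invoker MakeInvoker() { return Invoker{}; }
};
```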
```cpp
// this function is also defined in CK but because of the way we use it in
// profile_gemm_multi_impl, it requires the arguments to not be const
```
Copilot AI, Jan 20, 2026
Add documentation explaining why this duplicate function definition exists and how it differs from the CK version. The comment on line 20-21 is insufficient - it should explain the const-correctness difference and the implications for usage in profile_gemm_multi_impl.
Suggested change:
```diff
-// this function is also defined in CK but because of the way we use it in
-// profile_gemm_multi_impl, it requires the arguments to not be const
+// NOTE:
+// This helper intentionally duplicates `concat_tuple_of_refs` from the core CK utilities,
+// but with a different const-correctness contract on its arguments:
+//
+// - The CK version is defined to operate on (typically) const-qualified tuples of
+//   references; its parameters are more permissive and can accept `const Tuple<...>&`.
+// - This host-side overload is deliberately restricted to *non-const* tuples of
+//   references: `ck::Tuple<X&...>&` and `ck::Tuple<Y&...>&`.
+//
+// In `profile_gemm_multi_impl`, we need to concatenate tuples that contain non-const
+// references to tensors/buffers so that:
+//   * The resulting concatenated tuple preserves non-const reference semantics, allowing
+//     the profiled kernels and host-side utilities to modify the referenced objects, and
+//   * Overload resolution / SFINAE continues to select APIs that require non-const
+//     references (these would reject a const-qualified tuple produced by the CK version).
+//
+// If this function were replaced by the CK version, the arguments in
+// `profile_gemm_multi_impl` could become (or be treated as) const, which would either:
+//   - Prevent intended mutation of the underlying tensors, or
+//   - Cause subtle compilation or behavior differences due to const propagation.
+//
+// For that reason, this duplicate, non-const overload must remain local to the host-side
+// GEMM multi reference implementation and should not be "simplified" by switching to the
+// CK variant without carefully revisiting `profile_gemm_multi_impl` and its call sites.
```
```cpp
{
    if(arg.grouped_gemm_kernel_args_dev == nullptr)
    {
        throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
```
Copilot AI, Jan 20, 2026
Corrected spelling of 'nullpr' to 'nullptr'.
Suggested change:
```diff
-        throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
+        throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr");
```
```cpp
{
    printf("arg1: verification (0=no, 1=yes)\n");
    printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
    printf("arg3: time kernel (0=n0, 1=yes)\n");
```
Copilot AI, Jan 20, 2026
Corrected spelling of 'n0' to 'no'.
Suggested change:
```diff
-    printf("arg3: time kernel (0=n0, 1=yes)\n");
+    printf("arg3: time kernel (0=no, 1=yes)\n");
```
```cpp
{
    printf("arg1: verification (0=no, 1=yes)\n");
    printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
    printf("arg3: time kernel (0=n0, 1=yes)\n");
```
Copilot AI, Jan 20, 2026
Corrected spelling of 'n0' to 'no'.
Suggested change:
```diff
-    printf("arg3: time kernel (0=n0, 1=yes)\n");
+    printf("arg3: time kernel (0=no, 1=yes)\n");
```
Proposed changes
Add support for grouped GEMM multi ABD fixed NK. MR contains:
- New WMMA device operator (`DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK`)
- Examples in `59_grouped_gemm_multi_ABD`

Note: Some XDL instances were commented out because of unit test failures. As mentioned, this feature was apparently missing tests for XDL, so our assumption is that either there is an implementation bug or these instances were not set up correctly. Has the potential for a follow-up issue.
Checklist

Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] Ran `clang-format` on all changed files

Discussion