[CK_TILE][FMHA] Enable wholek_prefetch #3026
base: develop
Conversation
Signed-off-by: JL-underdog <Jun.Lin@amd.com>
@LJ-underdog are you still working on this?
Yes, this PR is being prepared for mate.
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Pull request overview
This PR enables a new whole-K prefetch pipeline variant for FMHA (fused multi-head attention) operations. It introduces a new pipeline type intended to improve performance through whole-K prefetching, and it is restricted to the hdim=128, hdim_v=128 configuration.
Changes:
- Added new pipeline enum QRKSVS_WHOLEK_PREFETCH with corresponding string mappings
- Updated C++ pipeline implementation name field to match the new pipeline identifier
- Extended Python codegen to generate kernels for the new pipeline with specific constraints
- Added symbol mappings for the new pipeline in codegen infrastructure
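The symbol mappings listed above can be pictured with a minimal sketch. It assumes the codegen keeps two tag-to-symbol dictionaries; the dictionary names, the `cpp_symbols` helper, and the C++ class name for the new pipeline are hypothetical, while the tag `qr_wholek_prefetch` and the enum value `QRKSVS_WHOLEK_PREFETCH` come from this PR.

```python
# Hypothetical tag -> C++ symbol tables; the real cpp_symbol_map.py may be laid out differently.
PIPELINE_MAP = {
    "qr": "ck_tile::BlockFmhaPipelineQRKSVS",
    # Assumed class name for the whole-K prefetch pipeline (not confirmed by the PR text).
    "qr_wholek_prefetch": "ck_tile::BlockFmhaPipelineQRKSVSWholeKPrefetch",
}

PIPELINE_ENUM_MAP = {
    "qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    "qr_wholek_prefetch": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_WHOLEK_PREFETCH",
}


def cpp_symbols(tag):
    """Return (C++ class name, C++ enum value) for a pipeline tag.

    Raises KeyError if the tag is missing from either table, so a pipeline
    that was added to only one map is caught at codegen time.
    """
    if tag not in PIPELINE_MAP or tag not in PIPELINE_ENUM_MAP:
        raise KeyError(f"pipeline tag {tag!r} is missing from a symbol map")
    return PIPELINE_MAP[tag], PIPELINE_ENUM_MAP[tag]
```

Looking up both tables through one helper is a design choice of this sketch: it keeps the class mapping and the enum mapping from drifting apart when a new pipeline tag is introduced.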
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | Updated static name field to identify the pipeline implementation |
| block_fmha_pipeline_enum.hpp | Added new enum value and string mapping template specialization |
| fmha_fwd.py | Added pipeline tag to validation checks, tile size configuration, and generation logic with constraints |
| cpp_symbol_map.py | Added C++ class and enum mappings for the new pipeline |
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…_whole_k_prefetch.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
    for tile, pipeline in itertools.product(
        tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
    ):
        if pipeline.tag == "qr_wholek_prefetch" and (
Please move the check to the CompatibilityRuleFactoryGfx9
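One way to read this request is sketched below: the constraint becomes a standalone rule that a factory returns, instead of an inline `if` in the generation loop. Everything here is illustrative. `KernelCandidate`, `check_wholek_prefetch`, `get_rules`, and `is_compatible` are hypothetical names, and the actual interface of CompatibilityRuleFactoryGfx9 is not shown in this thread. The hdim=128, hdim_v=128 restriction is taken from the pull request overview.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass(frozen=True)
class KernelCandidate:
    """Hypothetical stand-in for the (tile, pipeline) pair the codegen iterates over."""
    pipeline_tag: str
    hdim: int
    hdim_v: int


# A rule returns True when the candidate may be generated.
Rule = Callable[[KernelCandidate], bool]


def check_wholek_prefetch(candidate: KernelCandidate) -> bool:
    """Allow qr_wholek_prefetch only for hdim=128, hdim_v=128; other pipelines pass through."""
    if candidate.pipeline_tag != "qr_wholek_prefetch":
        return True
    return candidate.hdim == 128 and candidate.hdim_v == 128


def get_rules() -> List[Rule]:
    """Assumed factory hook: the list of rules every candidate must satisfy."""
    return [check_wholek_prefetch]


def is_compatible(candidate: KernelCandidate, rules: List[Rule]) -> bool:
    """A candidate is kept only if all rules accept it."""
    return all(rule(candidate) for rule in rules)
```

With this shape the generation loop only needs a single `is_compatible(...)` call, and pipeline-specific constraints live next to the other compatibility rules.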
Refactor seqtune method for better readability.
    FmhaFwdPipeline(
        "qr_wholek_prefetch",
        "row",
        "f",
Why not add padding for both the seqlen_q and seqlen_k dimensions?
Hi @LJ-underdog, please make sure this thing compiles and passes tests!
Proposed changes
Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- clang-format on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered