Small refactor for compile_ops control flow for better performance and readability#2161
Open
…nd less overhead Signed-off-by: Sampo Immonen <sampo.immonen@amd.com>
Small refactor for compile_ops control flow for better performance and readability. Please correct me if there are cases where this change does not apply!
Currently, every mha forward operation with a gen_func that goes through the compile_ops decorator follows the steps in the code snippet below:
Motivation
The problem with the current implementation is twofold:
The exception is raised on every single flash_attn_func call that uses gen_func, which adds considerable overhead to each call. In my simple test, the launch overhead of a vanilla flash_attn_func call goes from about 200us to 100us with this fix.
Relying on exceptions for control flow makes the code hard to reason about. For example, in this case the get_module result should be cached, but get_module keeps getting called every time with the argument module_mha_fwd.
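To make the two points above concrete, here is a minimal sketch of the general pattern, not the actual aiter code. `FakeLoader`, `make_exception_driven`, and `make_cached` are hypothetical names invented for illustration, and the `"_resolved"` suffix is an assumed stand-in for whatever name the real fallback path resolves to. The sketch only shows that an exception-driven lookup pays for the failed attempt on every call, while a cached lookup pays for it once.

```python
class FakeLoader:
    """Hypothetical loader: only names ending in '_resolved' can be loaded."""

    def __init__(self):
        self.failed_lookups = 0  # how many times the exception path was hit

    def get(self, name):
        if not name.endswith("_resolved"):
            self.failed_lookups += 1
            raise KeyError(name)
        return lambda x: x * 2  # stand-in for the compiled op


def make_exception_driven(loader, name):
    """Every call retries the failing lookup and recovers via except."""
    def call(x):
        try:
            op = loader.get(name)
        except KeyError:
            op = loader.get(name + "_resolved")
        return op(x)
    return call


def make_cached(loader, name):
    """The failing lookup happens at most once; later calls hit the cache."""
    cache = {}

    def call(x):
        op = cache.get(name)
        if op is None:
            try:
                op = loader.get(name)
            except KeyError:
                op = loader.get(name + "_resolved")
            cache[name] = op
        return op(x)
    return call


slow_loader, fast_loader = FakeLoader(), FakeLoader()
slow = make_exception_driven(slow_loader, "module_mha_fwd")
fast = make_cached(fast_loader, "module_mha_fwd")

slow_results = [slow(i) for i in range(5)]
fast_results = [fast(i) for i in range(5)]

print(slow_loader.failed_lookups, fast_loader.failed_lookups)
```

Both variants return the same results; the difference is only in how often the exception path runs, which is the per-call cost this PR aims to remove.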
Technical Details
Test Plan
Testing was done with the simple Python test script below to verify that the exception path is taken each time. Additionally, the overhead was verified by taking a PyTorch trace from the snippet below. All aiter tests pass after the change.
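As an illustration of how one might confirm that an exception path is hit on every call, here is a small standard-library sketch. It is not the test script referred to above: `count_exceptions`, `lookup`, and `call_once` are hypothetical helpers, and the `"_resolved"` suffix is an assumed placeholder for the fallback name.

```python
import functools


def count_exceptions(fn):
    """Wrap fn and record how many calls were made and how many raised."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        try:
            return fn(*args, **kwargs)
        except Exception:
            wrapper.raised += 1
            raise  # re-raise so the caller's behaviour is unchanged

    wrapper.calls = 0
    wrapper.raised = 0
    return wrapper


@count_exceptions
def lookup(name):
    # Hypothetical lookup: fails for any name without the '_resolved' suffix.
    if not name.endswith("_resolved"):
        raise KeyError(name)
    return name


def call_once():
    # Mimics exception-driven control flow: fail first, then fall back.
    try:
        return lookup("module_mha_fwd")
    except KeyError:
        return lookup("module_mha_fwd_resolved")


for _ in range(3):
    call_once()

print(f"calls={lookup.calls} raised={lookup.raised}")
```

If the raised count grows in step with the number of top-level calls, the exception path is being taken every time rather than once.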
Test Result
Tested on a machine with MI300X GPU.
docker image: rocm/pytorch:rocm7.1_ubuntu22.04_py3.10_pytorch_release_2.8.0
On this simple test, the launch overhead goes from around 200us to 100us for each flash_attn_func call.
On a similar test, the backward call goes from 224us to 135us.
Submission Checklist