Small refactor for compile_ops control flow for better performance and readability#2161

Open
SampoAMD wants to merge 1 commit into ROCm:main from SampoAMD:refactor_get_module_control_flow

Conversation

SampoAMD commented Mar 3, 2026

Small refactor for compile_ops control flow for better performance and readability. Please correct me if there are cases where this change does not apply!

Currently, every mha forward operation with a gen_func that passes through the compile_ops decorator executes the steps in the code snippet below:

if module is None:
    try:
        module = get_module(md_name)
    except Exception:
        md = custom_build_args.get("md_name", md_name)
        module = get_module(md)
  1. get_module is first called with module_mha_fwd, which raises an exception from:
     __mds[md_name] = importlib.import_module(f"{__package__}.{md_name}")
  2. The exception is then caught, and get_module is called again with the correct generated name.

Motivation

The problem with the current implementation is twofold:

  1. The exception is raised for every single flash_attn_func call, adding considerable overhead to each call that uses gen_func. => In my simple test, the launch overhead for a vanilla flash_attn_func call goes from about 200us to 100us with this fix.

  2. Relying on exceptions for control flow makes the code hard to reason about. For example, in this case the get_module result should be cached, but get_module keeps getting called every time with the argument module_mha_fwd.

Technical Details

Test Plan

Testing was done with the simple Python script below to verify that the exception path was taken on each call. Overhead was additionally verified by capturing a PyTorch trace of the same snippet. All aiter tests pass after the change.

import torch
import aiter    

if __name__ == "__main__":

    torch.manual_seed(1234)
    device = torch.device("cuda")
    dtype = torch.bfloat16
    batch_size, seqlen, num_heads, head_dim = 1, 128, 8, 64
    q = torch.randn(batch_size, seqlen, num_heads, head_dim, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    for _ in range(5):
        out = aiter.flash_attn_func(q, k, v)

Test Result

Tested on a machine with MI300X GPU.
docker image: rocm/pytorch:rocm7.1_ubuntu22.04_py3.10_pytorch_release_2.8.0

On this simple test the launch overhead goes from around 200us => 100us for each flash_attn_func call.
On a similar test, the backward call goes from 224us => 135us.

Submission Checklist

…nd less overhead

Signed-off-by: Sampo Immonen <sampo.immonen@amd.com>
@SampoAMD SampoAMD linked an issue Mar 3, 2026 that may be closed by this pull request
@SampoAMD SampoAMD marked this pull request as ready for review March 3, 2026 13:29
@SampoAMD SampoAMD requested a review from a team March 3, 2026 13:29


Development

Successfully merging this pull request may close these issues.

Large launch overhead for aiter.flash_attn_func
