
mha: make impl_ptr_map cache in fmha_v3_bwd and fmha_fwd_v3 thread-safe. #2201

Open
xinyazhang wants to merge 5 commits into ROCm:main from xinyazhang:xinyazhang/make_mha_cache_thread_safe

Conversation

@xinyazhang

@xinyazhang xinyazhang commented Mar 6, 2026

Motivation

fmha_v3_bwd is not thread-safe due to missing locks around its internal cache object.

Technical Details

The cache object impl_ptr_map in fmha_v3_bwd and fmha_fwd_v3 is updated without any locking.

fmha_fwd_v3 uses thread_local to ensure correctness, but this defeats the purpose of caching in a multi-threaded environment.

This PR makes both cache objects shared across all threads and properly protected by locks.

Test Plan

No testing is planned at the moment because it is difficult to validate with unit tests whether a function is thread-safe.

Test Result

N/A

Submission Checklist

Contributor

Copilot AI left a comment


Pull request overview

This PR aims to make the internal asm-kernel cache (impl_ptr_map) used by fmha_fwd_v3 and fmha_v3_bwd safe under multi-threaded use by adding locking around cache updates.

Changes:

  • Add <mutex> includes and introduce a static mutex to guard impl_ptr_map updates.
  • Wrap impl_ptr_map.emplace(...) + potential AiterAsmKernel construction in a lock in both forward and backward paths.
  • Introduce a LOCK_IMPL_PTR_MAP macro to standardize the lock acquisition.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
csrc/cpp_itfs/mha_fwd.cu Adds a mutex and locks around a thread_local kernel cache in fmha_fwd_v3.
csrc/cpp_itfs/mha_bwd.cu Adds a mutex and locks around a shared static kernel cache in fmha_v3_bwd.


Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.



Comment on lines +244 to +262
static std::mutex impl_ptr_mutex;
static std::unordered_map<std::string, std::unique_ptr<AiterAsmKernel>>
    impl_ptr_map;
#define LOCK_IMPL_PTR_MAP std::lock_guard<std::mutex> lock(impl_ptr_mutex)

const auto& cfg  = it->second;
const char* name = cfg.knl_name.c_str();
std::string co_name = get_kernel_co_name(cfg.co_name, arch_id);

{
    LOCK_IMPL_PTR_MAP;
    auto result = impl_ptr_map.emplace(name, nullptr);
    if(result.second)
    {
        result.first->second = std::make_unique<AiterAsmKernel>(name, co_name.c_str());
    }
    impl_ptr = result.first->second.get();
}
#undef LOCK_IMPL_PTR_MAP

Copilot AI Mar 6, 2026


Avoid introducing a preprocessor macro for locking (LOCK_IMPL_PTR_MAP). Macros inside a function make the code harder to read/debug and can accidentally conflict with identifiers in future edits. Prefer a normal RAII lock (e.g., a std::lock_guard / std::scoped_lock declared in the intended scope) or extract a small helper that returns the cached kernel pointer while handling locking internally.

@draganmladjenovic
Contributor

This breaks XLA. That thread_local, as bad as it is, allows us to load the kernel per GPU.

@xinyazhang
Author

xinyazhang commented Mar 6, 2026

This breaks XLA. That thread_local, as bad as it is, allows us to load the kernel per GPU.

Okay, then I'll keep the fwd unchanged and only guard the bwd.

Btw, I think AiterAsmKernel::launch_kernel should ensure the kernel is loaded on the current stream's device. If this is a concern, I can file another PR to improve AiterAsmKernel::launch_kernel.

@draganmladjenovic
Contributor

draganmladjenovic commented Mar 6, 2026

I mean, we use both. But you are correct: launch should ensure the kernel is loaded on the current GPU. There is an example of how to do so here: https://github.com/ROCm/aiter/pull/1900/changes#diff-b6342ca94276a64b904784f48916379b6150a7e82233b9e99ce3f2402f5b037cR238

@wangye805

This breaks XLA. That thread_local, as bad as it is, allows us to load the kernel per GPU.

Okay, then I'll keep the fwd unchanged and only guard the bwd.

Btw, I think AiterAsmKernel::launch_kernel should ensure the kernel is loaded on the current stream's device. If this is a concern, I can file another PR to improve AiterAsmKernel::launch_kernel.

I got a little bit confused. If a thread_local unordered_map is bad, then why not change the fwd thread_local to the same mutex-protected one? Or in other words, how does the mutex-protected one fail XLA? @draganmladjenovic

@draganmladjenovic
Contributor

The thread_local by accident allowed us to have the kernel loaded on each GPU where it is used, because we use one thread per GPU in XLA. It is bad because, in the general sense, as a library mha would load the kernel as many times as there are distinct threads that tried to launch it. My PR fixed this along with other issues aiter has. But I have little hope it will ever get merged. I'll try to rework it next week into something smaller.



4 participants