mha: make impl_ptr_map cache in fmha_v3_bwd and fmha_fwd_v3 thread-safe. #2201
xinyazhang wants to merge 5 commits into ROCm:main from
Conversation
Pull request overview
This PR aims to make the internal asm-kernel cache (impl_ptr_map) used by fmha_fwd_v3 and fmha_v3_bwd safe under multi-threaded use by adding locking around cache updates.
Changes:
- Add `<mutex>` includes and introduce a static mutex to guard `impl_ptr_map` updates.
- Wrap `impl_ptr_map.emplace(...)` + potential `AiterAsmKernel` construction in a lock in both forward and backward paths.
- Introduce a `LOCK_IMPL_PTR_MAP` macro to standardize the lock acquisition.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| csrc/cpp_itfs/mha_fwd.cu | Adds a mutex and locks around a thread_local kernel cache in fmha_fwd_v3. |
| csrc/cpp_itfs/mha_bwd.cu | Adds a mutex and locks around a shared static kernel cache in fmha_v3_bwd. |
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
csrc/cpp_itfs/mha_fwd.cu
Outdated
```cpp
static std::mutex impl_ptr_mutex;
static std::unordered_map<std::string, std::unique_ptr<AiterAsmKernel>>
    impl_ptr_map;
#define LOCK_IMPL_PTR_MAP std::lock_guard<std::mutex> lock(impl_ptr_mutex)

const auto& cfg   = it->second;
const char* name  = cfg.knl_name.c_str();
std::string co_name = get_kernel_co_name(cfg.co_name, arch_id);

auto result = impl_ptr_map.emplace(name, nullptr);
if(result.second)
{
    LOCK_IMPL_PTR_MAP;
    auto result = impl_ptr_map.emplace(name, nullptr);
    if(result.second)
    {
        result.first->second = std::make_unique<AiterAsmKernel>(name, co_name.c_str());
    }
    impl_ptr = result.first->second.get();
}
impl_ptr = result.first->second.get();
#undef LOCK_IMPL_PTR_MAP
```
Avoid introducing a preprocessor macro for locking (LOCK_IMPL_PTR_MAP). Macros inside a function make the code harder to read/debug and can accidentally conflict with identifiers in future edits. Prefer a normal RAII lock (e.g., a std::lock_guard / std::scoped_lock declared in the intended scope) or extract a small helper that returns the cached kernel pointer while handling locking internally.
…erged b/w fwd and bwd
This breaks XLA. That thread_local, as bad as it is, allows us to load the kernel per GPU.
Okay, then I'll keep the fwd unchanged and only guard bwd. Btw I think
I mean we use both. But you are correct, launch should ensure the kernel is loaded on the current GPU. You can find an example of how to do so here: https://github.com/ROCm/aiter/pull/1900/changes#diff-b6342ca94276a64b904784f48916379b6150a7e82233b9e99ce3f2402f5b037cR238
I got a little bit confused. If a thread_local unordered_map is bad, then why not change the fwd thread_local to the same mutex-protected one as here? Or in other words, how does the mutex-protected one fail XLA? @draganmladjenovic
The thread_local by accident allowed us to have the kernel loaded on each GPU where it's used, because we use one thread per GPU in XLA. It is bad because, in the general sense, as a library mha would load the kernel as many times as there are distinct threads trying to launch it. My PR fixed this along with other issues aiter has. But I have little hope it will ever get merged. I'll try to rework it next week into something smaller.
Motivation
`fmha_v3_bwd` is not thread-safe due to missing locks in its internal cache object.
Technical Details
The cache object `impl_ptr_map` in `fmha_v3_bwd` and `fmha_fwd_v3` is updated without any locking. `fmha_fwd_v3` uses `thread_local` to ensure its correctness, but that defeats the purpose of caching in a multi-threaded environment. This PR makes both cache objects shared by all threads and protects them with locks.
Test Plan
There is no testing planned at the moment because it is difficult to validate that a function is thread-safe with unit tests.
Test Result
N/A
Submission Checklist