From 461d730a2114db34f3a7ffef12683c028778fcc7 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 6 Mar 2026 16:36:50 +0800 Subject: [PATCH 1/5] mha: make impl_ptr_map caceh in fmha_v3_bwd and fmha_fwd_v3 thread-safe. --- csrc/cpp_itfs/mha_bwd.cu | 6 ++++++ csrc/cpp_itfs/mha_fwd.cu | 12 +++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index 528513bb9f..8668f43f75 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -3,6 +3,7 @@ #include "asm_fmha_v3_bwd_configs.hpp" #include #include +#include namespace aiter { std::tuple get_padded_hdim(int hdim_q, int hdim_v, std::string arch_id) @@ -369,7 +370,9 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) AiterAsmKernel* impl_ptr_pre = nullptr; AiterAsmKernel* impl_ptr_dqdkdv = nullptr; AiterAsmKernel* impl_ptr_post = nullptr; + static std::mutex impl_ptr_mutex; static std::unordered_map> impl_ptr_map; +#define LOCK_IMPL_PTR_MAP std::lock_guard lock(impl_ptr_mutex) auto it_pre = pre_cfgs->find(pre_kernel); if(it_pre != pre_cfgs->end()) @@ -379,6 +382,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* co_name = cfg.co_name.c_str(); ts_odo = cfg.ts; + LOCK_IMPL_PTR_MAP; auto result = impl_ptr_map.emplace(name, nullptr); if(result.second) { @@ -400,6 +404,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* co_name = cfg.co_name.c_str(); ts_kv = cfg.ts; + LOCK_IMPL_PTR_MAP; auto result = impl_ptr_map.emplace(name, nullptr); if(result.second) { @@ -423,6 +428,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* co_name = cfg.co_name.c_str(); ts_dq = cfg.ts; + LOCK_IMPL_PTR_MAP; auto result = impl_ptr_map.emplace(name, nullptr); if(result.second) { diff --git a/csrc/cpp_itfs/mha_fwd.cu b/csrc/cpp_itfs/mha_fwd.cu index 25b54f7544..dd45733884 100644 --- a/csrc/cpp_itfs/mha_fwd.cu +++ b/csrc/cpp_itfs/mha_fwd.cu @@ -5,6 +5,7 @@ #endif #include #include +#include namespace aiter { #if FAV3_ON @@ -240,19 +241,24 @@ float fmha_fwd_v3(mha_fwd_args a, const ck_tile::stream_config& s) }; AiterAsmKernel* impl_ptr = nullptr; + static std::mutex impl_ptr_mutex; static thread_local std::unordered_map> impl_ptr_map; +#define LOCK_IMPL_PTR_MAP std::lock_guard lock(impl_ptr_mutex) const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); std::string co_name = get_kernel_co_name(cfg.co_name, arch_id); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) { + LOCK_IMPL_PTR_MAP; + auto result = impl_ptr_map.emplace(name, nullptr); + if(result.second) + { result.first->second = std::make_unique(name, co_name.c_str()); + } + impl_ptr = result.first->second.get(); } - impl_ptr = result.first->second.get(); fmha_fwd_v3_args args; size_t arg_size = sizeof(args); From a1dea07533b86006bbd23cecf36f06d312854afe Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 6 Mar 2026 17:14:35 +0800 Subject: [PATCH 2/5] fmha_fwd_v3: share impl_ptr_map cache object to avoid duplicated loading --- csrc/cpp_itfs/mha_fwd.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpp_itfs/mha_fwd.cu b/csrc/cpp_itfs/mha_fwd.cu index dd45733884..8bfa3cb2dc 100644 --- a/csrc/cpp_itfs/mha_fwd.cu +++ b/csrc/cpp_itfs/mha_fwd.cu @@ -242,7 +242,7 @@ float fmha_fwd_v3(mha_fwd_args a, const ck_tile::stream_config& s) AiterAsmKernel* impl_ptr = nullptr; static std::mutex impl_ptr_mutex; - static thread_local std::unordered_map> + static std::unordered_map> impl_ptr_map; #define LOCK_IMPL_PTR_MAP std::lock_guard lock(impl_ptr_mutex) From 87ec181dea2520722926305a992b7de3b019a16b Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 6 Mar 2026 17:16:15 +0800 Subject: [PATCH 3/5] mha: undef unused macros --- csrc/cpp_itfs/mha_bwd.cu | 1 + csrc/cpp_itfs/mha_fwd.cu | 1 + 2 files changed, 2 insertions(+) diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index 8668f43f75..fdcd8e099d 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -442,6 +442,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) return -1; } } +#undef LOCK_IMPL_PTR_MAP if(a.v3_api_check) return 1; diff --git a/csrc/cpp_itfs/mha_fwd.cu b/csrc/cpp_itfs/mha_fwd.cu index 8bfa3cb2dc..bd49533177 100644 --- a/csrc/cpp_itfs/mha_fwd.cu +++ b/csrc/cpp_itfs/mha_fwd.cu @@ -259,6 +259,7 @@ float fmha_fwd_v3(mha_fwd_args a, const ck_tile::stream_config& s) } impl_ptr = result.first->second.get(); } +#undef LOCK_IMPL_PTR_MAP fmha_fwd_v3_args args; size_t arg_size = sizeof(args); From c99e71888b371458dc1626ca7e2bce1007efb4ea Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 6 Mar 2026 18:01:32 +0800 Subject: [PATCH 4/5] copilot doesn't like single macro usage so the coding practice is diverged b/w fwd and bwd --- csrc/cpp_itfs/mha_fwd.cu | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/cpp_itfs/mha_fwd.cu b/csrc/cpp_itfs/mha_fwd.cu index bd49533177..0eb697401a 100644 --- a/csrc/cpp_itfs/mha_fwd.cu +++ b/csrc/cpp_itfs/mha_fwd.cu @@ -244,14 +244,13 @@ float fmha_fwd_v3(mha_fwd_args a, const ck_tile::stream_config& s) static std::mutex impl_ptr_mutex; static std::unordered_map> impl_ptr_map; -#define LOCK_IMPL_PTR_MAP std::lock_guard lock(impl_ptr_mutex) const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); std::string co_name = get_kernel_co_name(cfg.co_name, arch_id); { - LOCK_IMPL_PTR_MAP; + std::lock_guard lock(impl_ptr_mutex); auto result = impl_ptr_map.emplace(name, nullptr); if(result.second) { @@ -259,7 +258,6 @@ float fmha_fwd_v3(mha_fwd_args a, const ck_tile::stream_config& s) } impl_ptr = result.first->second.get(); } -#undef LOCK_IMPL_PTR_MAP fmha_fwd_v3_args args; size_t arg_size = sizeof(args); From a52d98bad74478202fb19e4c8cb065be0f1ec8c6 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 6 Mar 2026 18:41:36 +0800 Subject: [PATCH 5/5] restores csrc/cpp_itfs/mha_fwd.cu because it breaks XLA --- csrc/cpp_itfs/mha_fwd.cu | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/csrc/cpp_itfs/mha_fwd.cu b/csrc/cpp_itfs/mha_fwd.cu index 0eb697401a..25b54f7544 100644 --- a/csrc/cpp_itfs/mha_fwd.cu +++ b/csrc/cpp_itfs/mha_fwd.cu @@ -5,7 +5,6 @@ #endif #include #include -#include namespace aiter { #if FAV3_ON @@ -241,23 +240,19 @@ float fmha_fwd_v3(mha_fwd_args a, const ck_tile::stream_config& s) }; AiterAsmKernel* impl_ptr = nullptr; - static std::mutex impl_ptr_mutex; - static std::unordered_map> + static thread_local std::unordered_map> impl_ptr_map; const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); std::string co_name = get_kernel_co_name(cfg.co_name, arch_id); + auto result = impl_ptr_map.emplace(name, nullptr); + if(result.second) { - std::lock_guard lock(impl_ptr_mutex); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { result.first->second = std::make_unique(name, co_name.c_str()); - } - impl_ptr = result.first->second.get(); } + impl_ptr = result.first->second.get(); fmha_fwd_v3_args args; size_t arg_size = sizeof(args);