diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index 528513bb9f..fdcd8e099d 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -3,6 +3,7 @@ #include "asm_fmha_v3_bwd_configs.hpp" #include #include +#include namespace aiter { std::tuple get_padded_hdim(int hdim_q, int hdim_v, std::string arch_id) @@ -369,7 +370,9 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) AiterAsmKernel* impl_ptr_pre = nullptr; AiterAsmKernel* impl_ptr_dqdkdv = nullptr; AiterAsmKernel* impl_ptr_post = nullptr; + static std::mutex impl_ptr_mutex; static std::unordered_map> impl_ptr_map; +#define LOCK_IMPL_PTR_MAP std::lock_guard lock(impl_ptr_mutex) auto it_pre = pre_cfgs->find(pre_kernel); if(it_pre != pre_cfgs->end()) @@ -379,6 +382,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* co_name = cfg.co_name.c_str(); ts_odo = cfg.ts; + LOCK_IMPL_PTR_MAP; auto result = impl_ptr_map.emplace(name, nullptr); if(result.second) { @@ -400,6 +404,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* co_name = cfg.co_name.c_str(); ts_kv = cfg.ts; + LOCK_IMPL_PTR_MAP; auto result = impl_ptr_map.emplace(name, nullptr); if(result.second) { @@ -423,6 +428,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* co_name = cfg.co_name.c_str(); ts_dq = cfg.ts; + LOCK_IMPL_PTR_MAP; auto result = impl_ptr_map.emplace(name, nullptr); if(result.second) { @@ -436,6 +442,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) return -1; } } +#undef LOCK_IMPL_PTR_MAP if(a.v3_api_check) return 1;