From 54ca046372a2478cf87425fcea84dc5a82f2d27b Mon Sep 17 00:00:00 2001 From: Dragan Mladjenovic Date: Mon, 9 Mar 2026 08:09:15 -0500 Subject: [PATCH 1/2] Make AiterAsmKernel load hsaco on each GPU it is used on --- csrc/cpp_itfs/moe/asm_moe.cpp.jinja | 11 +- csrc/include/aiter_hip_common.h | 174 +++++++++++++++++----------- csrc/py_itfs_cu/asm_fmoe.cu | 33 +++--- 3 files changed, 134 insertions(+), 84 deletions(-) diff --git a/csrc/cpp_itfs/moe/asm_moe.cpp.jinja b/csrc/cpp_itfs/moe/asm_moe.cpp.jinja index 05e0b5d33d..2a3944a43d 100644 --- a/csrc/cpp_itfs/moe/asm_moe.cpp.jinja +++ b/csrc/cpp_itfs/moe/asm_moe.cpp.jinja @@ -63,19 +63,18 @@ struct __attribute__((packed)) KernelArgs }; -unsigned char hsaco[{{bin_size}}] = { {{bin_data}} }; +const unsigned char hsaco[{{bin_size}}] = { {{bin_data}} }; class FMoeKernel { private: - std::unique_ptr asm_kernel=nullptr; + AiterAsmKernelFast asm_kernel; uint32_t sub_GU = 512; bool is_int4 = false; public: - FMoeKernel() + FMoeKernel() : asm_kernel("{{kernel_name}}", hsaco) { - asm_kernel=std::make_unique("{{kernel_name}}", hsaco); this->sub_GU = {{selected_tile}}; }; @@ -181,11 +180,11 @@ public: if constexpr (switchGxy) { - asm_kernel->launch_kernel({&args, &arg_size, gdy, gdx, gdz, bdx, 1, 1, stream}); + asm_kernel.launch_kernel({&args, &arg_size, gdy, gdx, gdz, bdx, 1, 1, stream}); } else { - asm_kernel->launch_kernel({&args, &arg_size, gdx, gdy, gdz, bdx, 1, 1, stream}); + asm_kernel.launch_kernel({&args, &arg_size, gdx, gdy, gdz, bdx, 1, 1, stream}); } }; }; diff --git a/csrc/include/aiter_hip_common.h b/csrc/include/aiter_hip_common.h index bbe5eb89d1..6334132610 100644 --- a/csrc/include/aiter_hip_common.h +++ b/csrc/include/aiter_hip_common.h @@ -6,6 +6,9 @@ #include #include #include +#include +#include +#include #ifdef AITER_EMBEDDED_HSA_HEADER #include AITER_EMBEDDED_HSA_HEADER #endif @@ -72,55 +75,82 @@ struct AiterAsmKernelArgs static const std::string get_gpu_arch(); -inline void load_asm_kernel(const char* name, 
- const char* hsaco, - hipModule_t& module, - hipFunction_t& kernel_func) +namespace detail { +struct FatBinaryWrapper { - const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR"); - std::string arch_name = get_gpu_arch(); - if(AITER_ASM_DIR != nullptr) - { - std::string hsa_path = std::string(AITER_ASM_DIR) + "/" + arch_name + "/" + hsaco; - AITER_LOG_INFO("hipModuleLoad: " << hsa_path << " GetFunction: " << name); - HIP_CALL(hipModuleLoad(&module, hsa_path.c_str())); - } - else - { -#if defined(AITER_EMBEDDED_HSA_HEADER) && defined(AITER_EMBEDDED_HSA_MAP) - std::string fname = "hsa/" + arch_name + "/" + hsaco; - auto hasco_obj = AITER_EMBEDDED_HSA_MAP.find(fname); - CHECK_COND(hasco_obj != AITER_EMBEDDED_HSA_MAP.end()); - CHECK_COND(hasco_obj->second.data() != nullptr); - AITER_LOG_INFO("hipModuleLoad: " << fname << " GetFunction: " << name); - HIP_CALL(hipModuleLoadData(&module, hasco_obj->second.data())); -#endif - } - HIP_CALL(hipModuleGetFunction(&kernel_func, module, name)); - AITER_LOG_INFO("hipModuleGetFunction: " << name << " Success"); -} + uint32_t magic = 0x48495046; // "HIPF"; + uint32_t version = 1; + const void* binary = nullptr; + intptr_t __pad = 0; +}; + +extern "C" void* __hipRegisterFatBinary(const FatBinaryWrapper* data) noexcept; +extern "C" void __hipUnregisterFatBinary(void* module) noexcept; +extern "C" void __hipRegisterFunction(void* module, + const void* hostFunction, + const char* deviceFunction, + const char* deviceName, + int threadLimit, + void* tid, + void* bid, + void* blockDim, + void* gridDim, + void* wSize) noexcept; +} // namespace detail + -class AiterAsmKernel +namespace { + +class AiterAsmKernelFast { private: - hipModule_t module; - hipFunction_t kernel_func; + void* module = nullptr; + + protected: + AiterAsmKernelFast() = default; + void init(const char* kernel_name, const void* hsaco) + { + detail::FatBinaryWrapper fat_bin{}; + fat_bin.binary = hsaco; + module = detail::__hipRegisterFatBinary(&fat_bin); + 
CHECK_COND(module != nullptr); + detail::__hipRegisterFunction(module, + static_cast(this), + kernel_name, + kernel_name, + -1, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr); + } public: - AiterAsmKernel(const char* name, const char* hsaco) + AiterAsmKernelFast(const char* kernel_name, const void* hsaco) { - load_asm_kernel(name, hsaco, module, kernel_func); + init(kernel_name, hsaco); }; - ~AiterAsmKernel() { HIP_CALL(hipModuleUnload(module)); } + ~AiterAsmKernelFast() { detail::__hipUnregisterFatBinary(module); } + + AiterAsmKernelFast(AiterAsmKernelFast&) = delete; + AiterAsmKernelFast(AiterAsmKernelFast&&) = delete; + AiterAsmKernelFast& operator=(AiterAsmKernelFast&) = delete; + AiterAsmKernelFast& operator=(AiterAsmKernelFast&&) = delete; void launch_kernel(const AiterAsmKernelArgs& kargs) { - void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, - kargs.args_ptr, - HIP_LAUNCH_PARAM_BUFFER_SIZE, - kargs.arg_size_ptr, - HIP_LAUNCH_PARAM_END}; + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, + kargs.args_ptr, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + kargs.arg_size_ptr, + HIP_LAUNCH_PARAM_END}; + hipFunction_t kernel_func = nullptr; + // TODO Ask runtime folks to provide an API for hipLaunchKernel with extra arg + // Don't error check here. + // Failure to load the func would cause hipModuleLaunchKernel to fail anyways. 
+ (void)hipGetFuncBySymbol(&kernel_func, reinterpret_cast(this)); HIP_CALL(hipModuleLaunchKernel(kernel_func, kargs.gdx, @@ -136,44 +166,58 @@ class AiterAsmKernel }; }; -class AiterAsmKernelFast + +class AiterAsmKernel: private AiterAsmKernelFast { private: - hipModule_t module; - hipFunction_t kernel_func; + std::unique_ptr hsaco_data; - public: - AiterAsmKernelFast(const char* name, void* hsaco) + const void* load_hsaco_file(const char* hsaco_path) { - HIP_CALL(hipModuleLoadData(&module, hsaco)); - HIP_CALL(hipModuleGetFunction(&kernel_func, module, name)); - AITER_LOG_INFO("hipModuleGetFunction: " << name << " Success"); - }; + const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR"); + std::string arch_name = get_gpu_arch(); + if(AITER_ASM_DIR != nullptr) + { + std::string full_path = std::string(AITER_ASM_DIR) + "/" + arch_name + "/" + hsaco_path; - ~AiterAsmKernelFast() { HIP_CALL(hipModuleUnload(module)); } + std::ifstream file(full_path, std::ios::binary | std::ios::ate); - void launch_kernel(const AiterAsmKernelArgs& kargs) - { - void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, - kargs.args_ptr, - HIP_LAUNCH_PARAM_BUFFER_SIZE, - kargs.arg_size_ptr, - HIP_LAUNCH_PARAM_END}; + CHECK_COND(file.is_open()); - HIP_CALL(hipModuleLaunchKernel(kernel_func, - kargs.gdx, - kargs.gdy, - kargs.gdz, - kargs.bdx, - kargs.bdy, - kargs.bdz, - 0, - kargs.stream, - nullptr, - (void**)&config)); + size_t file_size = file.tellg(); + hsaco_data.reset(new char[file_size]); + + file.seekg(0, std::ios::beg); + CHECK_COND(file.read(hsaco_data.get(), file_size)); + return hsaco_data.get(); + } + else + { +#if defined(AITER_EMBEDDED_HSA_HEADER) && defined(AITER_EMBEDDED_HSA_MAP) + std::string fname = "hsa/" + arch_name + "/" + hsaco; + auto hasco_obj = AITER_EMBEDDED_HSA_MAP.find(fname); + CHECK_COND(hasco_obj != AITER_EMBEDDED_HSA_MAP.end()); + CHECK_COND(hasco_obj->second.data() != nullptr); + return hasco_obj->second.data(); +#else + CHECK_COND(AITER_ASM_DIR != nullptr); + 
return nullptr; +#endif + } + } + + public: + AiterAsmKernel(const char* kernel_name, const char* hsaco_path) + { + init(kernel_name, load_hsaco_file(hsaco_path)); }; + + using AiterAsmKernelFast::launch_kernel; }; + +} // namespace + static const std::string get_gpu_arch() { int device_count; diff --git a/csrc/py_itfs_cu/asm_fmoe.cu b/csrc/py_itfs_cu/asm_fmoe.cu index 9b4c35ce75..fda1206e07 100755 --- a/csrc/py_itfs_cu/asm_fmoe.cu +++ b/csrc/py_itfs_cu/asm_fmoe.cu @@ -74,8 +74,7 @@ struct __attribute__((packed)) KernelArgs class FMoeKernel { private: - hipModule_t module; - hipFunction_t kernel_func; + AiterAsmKernel kernel; uint32_t sub_GU = 512; bool is_int4 = false; uint32_t num_persistent_tgs = 0; @@ -85,9 +84,8 @@ class FMoeKernel FMoeKernel(const char* name, const char* hsaco, uint32_t sub_GU = 512, - uint32_t num_persistent_tgs = 0) + uint32_t num_persistent_tgs = 0) : kernel(name, hsaco) { - load_asm_kernel(name, hsaco, module, kernel_func); this->sub_GU = sub_GU; this->num_persistent_tgs = num_persistent_tgs; this->name = name; @@ -186,11 +184,6 @@ class FMoeKernel args.ps_deno = ((inter_dim + sub_GU - 1) / sub_GU); args.total_tgs = this->num_persistent_tgs / args.ps_deno * args.ps_deno; - void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, - &args, - HIP_LAUNCH_PARAM_BUFFER_SIZE, - &arg_size, - HIP_LAUNCH_PARAM_END}; int bdx; int gdx; int gdy; @@ -235,13 +228,27 @@ class FMoeKernel const hipStream_t stream = at::hip::getCurrentHIPStream(); if constexpr(switchGxy) { - HIP_CALL(hipModuleLaunchKernel( - kernel_func, gdy, gdx, gdz, bdx, 1, 1, 0, stream, nullptr, (void**)&config)); + kernel.launch_kernel({&args, + &arg_size, + gdy, // gdx + gdx, // gdy + gdz, // gdz + bdx, // bdx + 1, // bdy + 1, // bdz + stream}); } else { - HIP_CALL(hipModuleLaunchKernel( - kernel_func, gdx, gdy, gdz, bdx, 1, 1, 0, stream, nullptr, (void**)&config)); + kernel.launch_kernel({&args, + &arg_size, + gdx, // gdx + gdy, // gdy + gdz, // gdz + bdx, // bdx + 1, // bdy + 1, // bdz + 
stream}); } }; }; From e4e9f6c3f1ebf8e646ba89f80788bd1a24a176b7 Mon Sep 17 00:00:00 2001 From: Dragan Mladjenovic Date: Mon, 9 Mar 2026 08:10:02 -0500 Subject: [PATCH 2/2] Replace unsafe uses of std::unordered_map with SynchronizedCache --- csrc/cpp_itfs/mha_bwd.cu | 29 ++++------------ csrc/cpp_itfs/mha_fwd.cu | 11 ++---- csrc/include/aiter_hip_common.h | 24 +++++++++++++ .../asm_a8w8_blockscale_bpreshuffle.cu | 28 ++++++--------- csrc/py_itfs_cu/asm_fmoe.cu | 23 +++++-------- csrc/py_itfs_cu/asm_gemm_a16w16.cu | 8 ++--- csrc/py_itfs_cu/asm_gemm_a4w4.cu | 34 +++++-------------- csrc/py_itfs_cu/asm_gemm_a8w8.cu | 34 ++++++------------- csrc/py_itfs_cu/asm_mla.cu | 34 +++++++------------ csrc/py_itfs_cu/asm_moe_2stage.cu | 10 ++---- csrc/py_itfs_cu/asm_pa.cu | 22 +++++------- csrc/py_itfs_cu/asm_topksoftmax.cu | 11 +++--- 12 files changed, 101 insertions(+), 167 deletions(-) diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index 528513bb9f..33955034d7 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -369,7 +369,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) AiterAsmKernel* impl_ptr_pre = nullptr; AiterAsmKernel* impl_ptr_dqdkdv = nullptr; AiterAsmKernel* impl_ptr_post = nullptr; - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; auto it_pre = pre_cfgs->find(pre_kernel); if(it_pre != pre_cfgs->end()) @@ -379,13 +379,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* co_name = cfg.co_name.c_str(); ts_odo = cfg.ts; - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - - impl_ptr_pre = result.first->second.get(); + impl_ptr_pre = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else { @@ -400,13 +395,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* co_name = cfg.co_name.c_str(); ts_kv 
= cfg.ts; - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - - impl_ptr_dqdkdv = result.first->second.get(); + impl_ptr_dqdkdv = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else { @@ -423,13 +413,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* co_name = cfg.co_name.c_str(); ts_dq = cfg.ts; - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - - impl_ptr_post = result.first->second.get(); + impl_ptr_post = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else { diff --git a/csrc/cpp_itfs/mha_fwd.cu b/csrc/cpp_itfs/mha_fwd.cu index 25b54f7544..ec3fa60a34 100644 --- a/csrc/cpp_itfs/mha_fwd.cu +++ b/csrc/cpp_itfs/mha_fwd.cu @@ -240,19 +240,14 @@ float fmha_fwd_v3(mha_fwd_args a, const ck_tile::stream_config& s) }; AiterAsmKernel* impl_ptr = nullptr; - static thread_local std::unordered_map> - impl_ptr_map; + static SynchronizedCache impl_ptr_map; const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); std::string co_name = get_kernel_co_name(cfg.co_name, arch_id); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name.c_str()); - } - impl_ptr = result.first->second.get(); + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name.c_str()); }); fmha_fwd_v3_args args; size_t arg_size = sizeof(args); diff --git a/csrc/include/aiter_hip_common.h b/csrc/include/aiter_hip_common.h index 6334132610..4c62517d2b 100644 --- a/csrc/include/aiter_hip_common.h +++ b/csrc/include/aiter_hip_common.h @@ -256,3 +256,27 @@ static uint32_t get_num_cu_func() static const uint32_t num_cu = get_num_cu_local(); return num_cu; } + +template , class KeyEqual = std::equal_to> 
+struct SynchronizedCache
+{
+ template
+ inline T& get_or_create(K&& k, F&& factory)
+ {
+ std::lock_guard map_mu_guard(map_mu);
+
+ struct Wrapper
+ {
+ F& f;
+ // Makes sure we only invoke lambda on insert
+ operator T() && { return f(); }
+ };
+
+ auto [it, _] = map.try_emplace(std::forward(k), Wrapper{factory});
+ return it->second;
+ }
+
+ private:
+ std::mutex map_mu;
+ std::unordered_map map;
+}; \ No newline at end of file diff --git a/csrc/py_itfs_cu/asm_a8w8_blockscale_bpreshuffle.cu b/csrc/py_itfs_cu/asm_a8w8_blockscale_bpreshuffle.cu index bcf76397a2..2395b28eea 100644 --- a/csrc/py_itfs_cu/asm_a8w8_blockscale_bpreshuffle.cu +++ b/csrc/py_itfs_cu/asm_a8w8_blockscale_bpreshuffle.cu @@ -145,8 +145,8 @@ struct KernelSelector { } }; - static std::unordered_map, SimpleHash> heuristic_cache; - static std::unordered_map> kernel_cache; + static SynchronizedCache, SimpleHash> heuristic_cache; + static SynchronizedCache kernel_cache; static std::tuple select_kernel(int M, int N, int K, const std::string& arch_id, std::optional splitK, std::optional bpreshuffle, @@ -156,28 +156,22 @@ struct KernelSelector { } DictKey key(M, N, K, splitK, bpreshuffle); - auto it = heuristic_cache.find(key); - if (it != heuristic_cache.end()) { - return it->second; // find it and return - } - auto result = get_heuristic_fp8_kernel(M, N, K, arch_id, splitK, bpreshuffle, config_map); - heuristic_cache[key] = result; - return result; + + return heuristic_cache.get_or_create(key, [&]() { + return get_heuristic_fp8_kernel(M, N, K, arch_id, splitK, bpreshuffle, config_map); + }); } - + static AiterAsmKernel* get_kernel(const std::string& kernel_name, const std::string& co_name) { - auto result = kernel_cache.emplace(kernel_name, nullptr); - if (result.second) { - result.first->second = std::make_unique(kernel_name.c_str(), co_name.c_str()); - } - return result.first->second.get(); + return &kernel_cache.get_or_create( + kernel_name, [&]() { return AiterAsmKernel(kernel_name.c_str(), 
co_name.c_str()); }); } }; -std::unordered_map, KernelSelector::SimpleHash> +SynchronizedCache, KernelSelector::SimpleHash> KernelSelector::heuristic_cache; -std::unordered_map> KernelSelector::kernel_cache; +SynchronizedCache KernelSelector::kernel_cache; static KernelArgs setup_kernel_args(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& out, const torch::Tensor& A_scale, const torch::Tensor& B_scale, diff --git a/csrc/py_itfs_cu/asm_fmoe.cu b/csrc/py_itfs_cu/asm_fmoe.cu index fda1206e07..337d497ad1 100755 --- a/csrc/py_itfs_cu/asm_fmoe.cu +++ b/csrc/py_itfs_cu/asm_fmoe.cu @@ -265,7 +265,7 @@ FMoeKernel* get_heuristic_kernel( std::string arch_id = get_gpu_arch(); std::string selectedKl = kernel_name.empty() ? "" : arch_id + kernel_name; int vskip = 1; - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; const char* vs_env_value = std::getenv("AITER_ENABLE_VSKIP"); if(vs_env_value != nullptr && std::string(vs_env_value) == "0") @@ -319,15 +319,13 @@ FMoeKernel* get_heuristic_kernel( const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); if(cfg.ps == 1) num_persistent_tgs = cfg.tg_num_perCU * num_cu; else num_persistent_tgs = 0; - if(result.second) - result.first->second = - std::make_unique(name, co_name, cfg.subGU_n, num_persistent_tgs); - impl_ptr = result.first->second.get(); + + impl_ptr = &impl_ptr_map.get_or_create( + name, [&]() { return FMoeKernel(name, co_name, cfg.subGU_n, num_persistent_tgs); }); } else TORCH_CHECK(false, __func__, " not find kernel " + selectedKl); @@ -421,7 +419,7 @@ void fmoe_int8_g1u0(torch::Tensor& out, // [token_cnt, dim] { FMoeKernel* impl_ptr = nullptr; int inter_dim = down.size(2); - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; struct FMoeKernelConfig { @@ -497,13 +495,8 @@ void fmoe_int8_g1u0(torch::Tensor& out, // [token_cnt, 
dim] const char* name = config.name.c_str(); const char* co_name = config.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = - std::make_unique(name, co_name, config.tile_size); - } - impl_ptr = result.first->second.get(); + impl_ptr = &impl_ptr_map.get_or_create( + name, [&]() { return FMoeKernel(name, co_name, config.tile_size); }); } } impl_ptr->launch_kernel(out, @@ -551,7 +544,7 @@ void fmoe_g1u1(torch::Tensor& out, // [token_cnt, dim] int inter_dim = down.size(2); inter_dim *= model_dim / gate.size(2); int sub_X_cnt = sorted_expert_ids.size(0); - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; if(gate.dtype() == at::ScalarType::UInt32 || gate.dtype() == at::ScalarType::Int) // int4 { int selectedTile = get_heuristic_tile( diff --git a/csrc/py_itfs_cu/asm_gemm_a16w16.cu b/csrc/py_itfs_cu/asm_gemm_a16w16.cu index 4c3cc0088e..edbb08a338 100644 --- a/csrc/py_itfs_cu/asm_gemm_a16w16.cu +++ b/csrc/py_itfs_cu/asm_gemm_a16w16.cu @@ -176,7 +176,7 @@ AiterAsmKernel* get_or_load_kernel(const std::string& selectedKernelName, unsigned int& SUBM, unsigned int& SUBN) { - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; auto it_kl = config_map->find(selectedKernelName); TORCH_CHECK(it_kl != config_map->end(), __func__, " not find kernel~ " + selectedKernelName); @@ -187,11 +187,7 @@ AiterAsmKernel* get_or_load_kernel(const std::string& selectedKernelName, SUBM = cfg.tileM; SUBN = cfg.tileN; - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - result.first->second = std::make_unique(name, co_name); - - return result.first->second.get(); + return &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } torch::Tensor gemm_a16w16_asm(torch::Tensor& A, diff --git a/csrc/py_itfs_cu/asm_gemm_a4w4.cu b/csrc/py_itfs_cu/asm_gemm_a4w4.cu index c76d9107d7..1fbcf878d3 100644 --- 
a/csrc/py_itfs_cu/asm_gemm_a4w4.cu +++ b/csrc/py_itfs_cu/asm_gemm_a4w4.cu @@ -213,7 +213,7 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2 std::hash()(log2_key) ^ std::hash()(shuffle_key); } }; - static std::unordered_map, SimpleHash> + static SynchronizedCache, SimpleHash> heuristic_kernel_dict; if(config_map->empty()) @@ -221,7 +221,7 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2 TORCH_CHECK(false, __func__, " no kernel support a4w4 for this gpu arch"); } - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; std::string arch_id = get_gpu_arch(); kernelName = kernelName.empty() ? "" : arch_id + kernelName; @@ -229,23 +229,11 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2 int selectedksplit = log2_k_split.has_value() ? log2_k_split.value() : 0; if(kernelName.empty()) { - auto it = heuristic_kernel_dict.find(DictKey(Mdim, Ndim, Kdim, log2_k_split, bpreshuffle)); - if(it != heuristic_kernel_dict.end()) - { - auto res = it->second; - kernelName = std::get<0>(res); - selectedksplit = std::get<1>(res); - } - else - { - auto it = get_heuristic_kernel( - Mdim, Ndim, Kdim, arch_id, log2_k_split, bpreshuffle, config_map); - - kernelName = std::get<0>(it); - selectedksplit = std::get<1>(it); - heuristic_kernel_dict[{Mdim, Ndim, Kdim, log2_k_split, bpreshuffle}] = - std::make_tuple(kernelName, selectedksplit); - } + std::tie(kernelName, selectedksplit) = heuristic_kernel_dict.get_or_create( + DictKey(Mdim, Ndim, Kdim, log2_k_split, bpreshuffle), [&]() { + return get_heuristic_kernel( + Mdim, Ndim, Kdim, arch_id, log2_k_split, bpreshuffle, config_map); + }); } AiterAsmKernel* impl_ptr = nullptr; @@ -274,12 +262,8 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2 gdz = (Kdim + k_per_tg - 1) / k_per_tg; } - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = 
result.first->second.get(); + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else TORCH_CHECK(false, __func__, " not find kernel " + kernelName); diff --git a/csrc/py_itfs_cu/asm_gemm_a8w8.cu b/csrc/py_itfs_cu/asm_gemm_a8w8.cu index 44f8e40288..927b50b4c3 100644 --- a/csrc/py_itfs_cu/asm_gemm_a8w8.cu +++ b/csrc/py_itfs_cu/asm_gemm_a8w8.cu @@ -180,35 +180,24 @@ torch::Tensor gemm_a8w8_asm(torch::Tensor& A, // A:[M, K] i8 std::hash()(splitk_key) ^ std::hash()(shuffle_key); } }; - static std::unordered_map, SimpleHash> + static SynchronizedCache, SimpleHash> heuristic_kernel_dict; if(config_map->empty()) { TORCH_CHECK(false, __func__, " no kernel support a8w8 for this gpu arch"); } - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; std::string arch_id = get_gpu_arch(); kernelName = kernelName.empty() ? "" : arch_id + kernelName; int selectedksplit = splitK.value_or(0) ?: 1; if(kernelName.empty()) { - auto it = heuristic_kernel_dict.find(DictKey(Mdim, Ndim, Kdim, splitK, bpreshuffle)); - if(it != heuristic_kernel_dict.end()) - { - auto res = it->second; - kernelName = std::get<0>(res); - selectedksplit = std::get<1>(res); - } - else - { - auto it = get_heuristic_kernel(Mdim, Ndim, Kdim, arch_id, splitK, bpreshuffle, config_map); - - kernelName = std::get<0>(it); - selectedksplit = std::get<1>(it); - heuristic_kernel_dict[{Mdim, Ndim, Kdim, splitK, bpreshuffle}] = - std::make_tuple(kernelName, selectedksplit); - } + std::tie(kernelName, selectedksplit) = heuristic_kernel_dict.get_or_create( + DictKey(Mdim, Ndim, Kdim, splitK, bpreshuffle), [&]() { + return get_heuristic_kernel( + Mdim, Ndim, Kdim, arch_id, splitK, bpreshuffle, config_map); + }); } AiterAsmKernel* impl_ptr = nullptr; @@ -257,12 +246,9 @@ torch::Tensor gemm_a8w8_asm(torch::Tensor& A, // A:[M, K] i8 out.zero_(); } gdx = gdx * selectedksplit; - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - 
result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else TORCH_CHECK(false, __func__, " not find kernel " + kernelName); diff --git a/csrc/py_itfs_cu/asm_mla.cu b/csrc/py_itfs_cu/asm_mla.cu index cf9a938796..5c7690d42a 100644 --- a/csrc/py_itfs_cu/asm_mla.cu +++ b/csrc/py_itfs_cu/asm_mla.cu @@ -235,7 +235,7 @@ void mla_decode_stage1_asm_fwd( // Get kernel using config dispatch std::string arch_id = get_gpu_arch(); CFG* config_map = &cfg_mla_asm; - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; int ps = persistent ? 1 : 0; int prefill = 0; // decode stage @@ -321,13 +321,9 @@ void mla_decode_stage1_asm_fwd( const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); - + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else TORCH_CHECK(false, __func__, " not find kernel " + kernelName); @@ -489,7 +485,7 @@ void mla_prefill_ps_asm_fwd( TORCH_CHECK(false, __func__, ": fp8 mla persistent prefill is not supported on gfx942"); } CFG* config_map = &cfg_mla_asm; - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; int ps = 1; // ps_prefill always uses persistent scheduling int prefill = 1; // prefill stage @@ -509,12 +505,9 @@ void mla_prefill_ps_asm_fwd( const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + + impl_ptr = + 
&impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else TORCH_CHECK(false, __func__, " not find kernel " + kernelName); @@ -608,7 +601,7 @@ void mla_prefill_asm_fwd( // Get kernel using config dispatch std::string arch_id = get_gpu_arch(); CFG* config_map = &cfg_mla_asm; - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; int ps = 0; // prefill without persistent scheduling int prefill = 1; // prefill stage @@ -626,12 +619,9 @@ void mla_prefill_asm_fwd( const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else TORCH_CHECK(false, __func__, " not find kernel " + kernelName); diff --git a/csrc/py_itfs_cu/asm_moe_2stage.cu b/csrc/py_itfs_cu/asm_moe_2stage.cu index 7111df37b7..41d221f2cb 100644 --- a/csrc/py_itfs_cu/asm_moe_2stage.cu +++ b/csrc/py_itfs_cu/asm_moe_2stage.cu @@ -169,7 +169,7 @@ void moe_stage1_g1u1( const hipStream_t stream = at::hip::getCurrentHIPStream(); CFG *config_map = get_cfg(input, out, w1, quant_type, sorted_weights.has_value()); - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; int model_dim = input.size(1); int hidden_dim = inter_dim; int sub_X_cnt = sorted_expert_ids.size(0); @@ -190,12 +190,8 @@ void moe_stage1_g1u1( TORCH_CHECK(inter_dim % cfg.tile_n == 0, "ASM kernel " + std::string(name) + " is not supported for inter_dim = " + std::to_string(inter_dim)); - auto result = impl_ptr_map.emplace(name, nullptr); - if (result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return 
AiterAsmKernel(name, co_name); }); } else TORCH_CHECK(false, __func__, " not find kernel " + kernelName); diff --git a/csrc/py_itfs_cu/asm_pa.cu b/csrc/py_itfs_cu/asm_pa.cu index 6dbe063123..7e3a4ad6dc 100644 --- a/csrc/py_itfs_cu/asm_pa.cu +++ b/csrc/py_itfs_cu/asm_pa.cu @@ -271,7 +271,7 @@ torch::Tensor pa_fwd(torch::Tensor& Q, // [num_seqs, num_heads, head_size] }; int qTile = 0; CFG* config_map = &cfg_pa_asm; // only one config csv in hsa//pa, now - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; std::string kernelName = kernelName_.has_value() ? arch_id + kernelName_.value() : ""; int ps = 0; if (kernelName.empty()) @@ -289,12 +289,9 @@ torch::Tensor pa_fwd(torch::Tensor& Q, // [num_seqs, num_heads, head_size] const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else TORCH_CHECK(false, __func__, " not find kernel " + kernelName); @@ -452,7 +449,7 @@ torch::Tensor pa_ps_fwd(torch::Tensor& Q, // [num_seqs, num_heads, head_size] ") exceeds maximum available qTile. 
Please reduce gqa_ratio or max_qlen."); CFG* config_map = &cfg_pa_asm; // only one config csv in hsa//pa, now - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; std::string arch_id = get_gpu_arch(); std::string kernelName = kernelName_.value_or( get_heuristic_kernel(q_type, kv_type, gqa, mtp, msk, hp, block_size, arch_id, ps, qTile, config_map)); @@ -470,12 +467,9 @@ torch::Tensor pa_ps_fwd(torch::Tensor& Q, // [num_seqs, num_heads, head_size] const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); if(cfg.ps) { gdx = get_num_cu_func(); diff --git a/csrc/py_itfs_cu/asm_topksoftmax.cu b/csrc/py_itfs_cu/asm_topksoftmax.cu index 509e0ca509..bc4228f627 100644 --- a/csrc/py_itfs_cu/asm_topksoftmax.cu +++ b/csrc/py_itfs_cu/asm_topksoftmax.cu @@ -95,7 +95,7 @@ void topk_softmax_asm(torch::Tensor& topk_weights, // [num_tokens, topk] args.out_stride = out_stride * 4; CFG* config_map = &cfg_topksoftmax; - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; AiterAsmKernel* impl_ptr = nullptr; auto [kernelName, subm] = get_heuristic_kernel_topksoftmax(arch_id, dtype, MAX_SUBM, num_experts, topk, config_map); @@ -105,12 +105,9 @@ void topk_softmax_asm(torch::Tensor& topk_weights, // [num_tokens, topk] const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); 
}); } else TORCH_CHECK(false, __func__, " not find kernel " + kernelName);