Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 7 additions & 22 deletions csrc/cpp_itfs/mha_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
AiterAsmKernel* impl_ptr_pre = nullptr;
AiterAsmKernel* impl_ptr_dqdkdv = nullptr;
AiterAsmKernel* impl_ptr_post = nullptr;
static std::unordered_map<std::string, std::unique_ptr<AiterAsmKernel>> impl_ptr_map;
static SynchronizedCache<std::string_view, AiterAsmKernel> impl_ptr_map;

auto it_pre = pre_cfgs->find(pre_kernel);
if(it_pre != pre_cfgs->end())
Expand All @@ -379,13 +379,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
const char* co_name = cfg.co_name.c_str();
ts_odo = cfg.ts;

auto result = impl_ptr_map.emplace(name, nullptr);
if(result.second)
{
result.first->second = std::make_unique<AiterAsmKernel>(name, co_name);
}

impl_ptr_pre = result.first->second.get();
impl_ptr_pre =
&impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); });
}
else
{
Expand All @@ -400,13 +395,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
const char* co_name = cfg.co_name.c_str();
ts_kv = cfg.ts;

auto result = impl_ptr_map.emplace(name, nullptr);
if(result.second)
{
result.first->second = std::make_unique<AiterAsmKernel>(name, co_name);
}

impl_ptr_dqdkdv = result.first->second.get();
impl_ptr_dqdkdv =
&impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); });
}
else
{
Expand All @@ -423,13 +413,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
const char* co_name = cfg.co_name.c_str();
ts_dq = cfg.ts;

auto result = impl_ptr_map.emplace(name, nullptr);
if(result.second)
{
result.first->second = std::make_unique<AiterAsmKernel>(name, co_name);
}

impl_ptr_post = result.first->second.get();
impl_ptr_post =
&impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); });
}
else
{
Expand Down
11 changes: 3 additions & 8 deletions csrc/cpp_itfs/mha_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -240,19 +240,14 @@ float fmha_fwd_v3(mha_fwd_args a, const ck_tile::stream_config& s)
};

AiterAsmKernel* impl_ptr = nullptr;
static thread_local std::unordered_map<std::string, std::unique_ptr<AiterAsmKernel>>
impl_ptr_map;
static SynchronizedCache<std::string_view, AiterAsmKernel> impl_ptr_map;

const auto& cfg = it->second;
const char* name = cfg.knl_name.c_str();
std::string co_name = get_kernel_co_name(cfg.co_name, arch_id);

auto result = impl_ptr_map.emplace(name, nullptr);
if(result.second)
{
result.first->second = std::make_unique<AiterAsmKernel>(name, co_name.c_str());
}
impl_ptr = result.first->second.get();
impl_ptr =
&impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name.c_str()); });

fmha_fwd_v3_args args;
size_t arg_size = sizeof(args);
Expand Down
11 changes: 5 additions & 6 deletions csrc/cpp_itfs/moe/asm_moe.cpp.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,18 @@ struct __attribute__((packed)) KernelArgs
};


unsigned char hsaco[{{bin_size}}] = { {{bin_data}} };
const unsigned char hsaco[{{bin_size}}] = { {{bin_data}} };

class FMoeKernel
{
private:
std::unique_ptr<AiterAsmKernelFast> asm_kernel=nullptr;
AiterAsmKernelFast asm_kernel;
uint32_t sub_GU = 512;
bool is_int4 = false;

public:
FMoeKernel()
FMoeKernel() : asm_kernel("{{kernel_name}}", hsaco)
{
asm_kernel=std::make_unique<AiterAsmKernelFast>("{{kernel_name}}", hsaco);
this->sub_GU = {{selected_tile}};
};

Expand Down Expand Up @@ -181,11 +180,11 @@ public:

if constexpr (switchGxy)
{
asm_kernel->launch_kernel({&args, &arg_size, gdy, gdx, gdz, bdx, 1, 1, stream});
asm_kernel.launch_kernel({&args, &arg_size, gdy, gdx, gdz, bdx, 1, 1, stream});
}
else
{
asm_kernel->launch_kernel({&args, &arg_size, gdx, gdy, gdz, bdx, 1, 1, stream});
asm_kernel.launch_kernel({&args, &arg_size, gdx, gdy, gdz, bdx, 1, 1, stream});
}
};
};
Expand Down
198 changes: 133 additions & 65 deletions csrc/include/aiter_hip_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#include <cstdint>
#include <hip/hip_runtime.h>
#include <iostream>
#include <fstream>
#include <mutex>
#include <memory>
#ifdef AITER_EMBEDDED_HSA_HEADER
#include AITER_EMBEDDED_HSA_HEADER
#endif
Expand Down Expand Up @@ -72,55 +75,82 @@ struct AiterAsmKernelArgs

static const std::string get_gpu_arch();

inline void load_asm_kernel(const char* name,
const char* hsaco,
hipModule_t& module,
hipFunction_t& kernel_func)
namespace detail {
struct FatBinaryWrapper
{
const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR");
std::string arch_name = get_gpu_arch();
if(AITER_ASM_DIR != nullptr)
{
std::string hsa_path = std::string(AITER_ASM_DIR) + "/" + arch_name + "/" + hsaco;
AITER_LOG_INFO("hipModuleLoad: " << hsa_path << " GetFunction: " << name);
HIP_CALL(hipModuleLoad(&module, hsa_path.c_str()));
}
else
{
#if defined(AITER_EMBEDDED_HSA_HEADER) && defined(AITER_EMBEDDED_HSA_MAP)
std::string fname = "hsa/" + arch_name + "/" + hsaco;
auto hasco_obj = AITER_EMBEDDED_HSA_MAP.find(fname);
CHECK_COND(hasco_obj != AITER_EMBEDDED_HSA_MAP.end());
CHECK_COND(hasco_obj->second.data() != nullptr);
AITER_LOG_INFO("hipModuleLoad: " << fname << " GetFunction: " << name);
HIP_CALL(hipModuleLoadData(&module, hasco_obj->second.data()));
#endif
}
HIP_CALL(hipModuleGetFunction(&kernel_func, module, name));
AITER_LOG_INFO("hipModuleGetFunction: " << name << " Success");
}
uint32_t magic = 0x48495046; // "HIPF";
uint32_t version = 1;
const void* binary = nullptr;
intptr_t __pad = 0;
};

extern "C" void* __hipRegisterFatBinary(const FatBinaryWrapper* data) noexcept;
extern "C" void __hipUnregisterFatBinary(void* module) noexcept;
extern "C" void __hipRegisterFunction(void* module,
const void* hostFunction,
const char* deviceFunction,
const char* deviceName,
int threadLimit,
void* tid,
void* bid,
void* blockDim,
void* gridDim,
void* wSize) noexcept;
} // namespace detail

class AiterAsmKernel

namespace {

class AiterAsmKernelFast
{
private:
hipModule_t module;
hipFunction_t kernel_func;
void* module = nullptr;

protected:
AiterAsmKernelFast() = default;
void init(const char* kernel_name, const void* hsaco)
{
detail::FatBinaryWrapper fat_bin{};
fat_bin.binary = hsaco;
module = detail::__hipRegisterFatBinary(&fat_bin);
CHECK_COND(module != nullptr);
detail::__hipRegisterFunction(module,
static_cast<void*>(this),
kernel_name,
kernel_name,
-1,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr);
}

public:
AiterAsmKernel(const char* name, const char* hsaco)
AiterAsmKernelFast(const char* kernel_name, const void* hsaco)
{
load_asm_kernel(name, hsaco, module, kernel_func);
init(kernel_name, hsaco);
};

~AiterAsmKernel() { HIP_CALL(hipModuleUnload(module)); }
~AiterAsmKernelFast() { detail::__hipUnregisterFatBinary(module); }

AiterAsmKernelFast(AiterAsmKernelFast&) = delete;
AiterAsmKernelFast(AiterAsmKernelFast&&) = delete;
AiterAsmKernelFast& operator=(AiterAsmKernelFast&) = delete;
AiterAsmKernelFast& operator=(AiterAsmKernelFast&&) = delete;

void launch_kernel(const AiterAsmKernelArgs& kargs)
{
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER,
kargs.args_ptr,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
kargs.arg_size_ptr,
HIP_LAUNCH_PARAM_END};
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER,
kargs.args_ptr,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
kargs.arg_size_ptr,
HIP_LAUNCH_PARAM_END};
hipFunction_t kernel_func = nullptr;
// TODO Ask runtime folks to provide an API for hipLaunchKernel with extra arg
// Don't error check here.
// Failure to load the func would cause hipModuleLaunchKernel to fail anyways.
(void)hipGetFuncBySymbol(&kernel_func, reinterpret_cast<void*>(this));

HIP_CALL(hipModuleLaunchKernel(kernel_func,
kargs.gdx,
Expand All @@ -136,44 +166,58 @@ class AiterAsmKernel
};
};

class AiterAsmKernelFast

class AiterAsmKernel: private AiterAsmKernelFast
{
private:
hipModule_t module;
hipFunction_t kernel_func;
std::unique_ptr<char[]> hsaco_data;

public:
AiterAsmKernelFast(const char* name, void* hsaco)
const void* load_hsaco_file(const char* hsaco_path)
{
HIP_CALL(hipModuleLoadData(&module, hsaco));
HIP_CALL(hipModuleGetFunction(&kernel_func, module, name));
AITER_LOG_INFO("hipModuleGetFunction: " << name << " Success");
};
const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR");
std::string arch_name = get_gpu_arch();
if(AITER_ASM_DIR != nullptr)
{
std::string full_path = std::string(AITER_ASM_DIR) + "/" + arch_name + "/" + hsaco_path;

~AiterAsmKernelFast() { HIP_CALL(hipModuleUnload(module)); }
std::ifstream file(full_path, std::ios::binary | std::ios::ate);

void launch_kernel(const AiterAsmKernelArgs& kargs)
{
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER,
kargs.args_ptr,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
kargs.arg_size_ptr,
HIP_LAUNCH_PARAM_END};
CHECK_COND(file.is_open());

HIP_CALL(hipModuleLaunchKernel(kernel_func,
kargs.gdx,
kargs.gdy,
kargs.gdz,
kargs.bdx,
kargs.bdy,
kargs.bdz,
0,
kargs.stream,
nullptr,
(void**)&config));
size_t file_size = file.tellg();
hsaco_data.reset(new char[file_size]);

file.seekg(0, std::ios::beg);
CHECK_COND(file.read(hsaco_data.get(), file_size));
return hsaco_data.get();
}
else
{
#if defined(AITER_EMBEDDED_HSA_HEADER) && defined(AITER_EMBEDDED_HSA_MAP)
std::string fname = "hsa/" + arch_name + "/" + hsaco;
auto hasco_obj = AITER_EMBEDDED_HSA_MAP.find(fname);
CHECK_COND(hasco_obj != AITER_EMBEDDED_HSA_MAP.end());
CHECK_COND(hasco_obj->second.data() != nullptr);
return hasco_obj->second.data();
#else
CHECK_COND(AITER_ASM_DIR != nullptr);
return nullptr;
#endif
}
}

public:
AiterAsmKernel(const char* kernel_name, const char* hsaco_path)
{
init(kernel_name, load_hsaco_file(hsaco_path));
};

using AiterAsmKernelFast::launch_kernel;
};


} // namespace

static const std::string get_gpu_arch()
{
int device_count;
Expand Down Expand Up @@ -212,3 +256,27 @@ static uint32_t get_num_cu_func()
static const uint32_t num_cu = get_num_cu_local();
return num_cu;
}

/// Mutex-guarded, insert-only cache layered on top of std::unordered_map.
///
/// The only operation exposed is get_or_create(); this type never removes an
/// entry. Because the cache owns a std::mutex it is neither copyable nor
/// movable.
///
/// @tparam Key      key type stored in the map
/// @tparam T        mapped value type, built in place from the factory result
/// @tparam Hash     hash functor forwarded to std::unordered_map
/// @tparam KeyEqual equality functor forwarded to std::unordered_map
template <class Key, class T, class Hash = std::hash<Key>, class KeyEqual = std::equal_to<Key>>
struct SynchronizedCache
{
    /// Returns the value stored under `k`, creating it first when the key is absent.
    ///
    /// @param k       key to look up; forwarded to try_emplace, so anything
    ///                convertible to Key is accepted
    /// @param factory callable taking no arguments whose result initializes the
    ///                new T; it is invoked only when an insertion takes place
    /// @return reference to the element held by the map
    ///
    /// The mutex is held for the whole lookup/insert, including the factory
    /// call, so concurrent callers asking for the same key get the same element.
    /// The lock is released on return: it does not protect later use of the
    /// returned T.
    template <typename K, typename F>
    inline T& get_or_create(K&& k, F&& factory)
    {
        // Lazy adapter: try_emplace only builds the mapped value when the key is
        // missing, and building it from this object is what triggers the
        // conversion below. An existing key therefore never runs the factory.
        struct LazyValue
        {
            F& make;
            operator T() && { return make(); }
        };

        std::scoped_lock<std::mutex> lock(cache_mu);
        return cache.try_emplace(std::forward<K>(k), LazyValue{factory}).first->second;
    }

    private:
    std::mutex cache_mu;                              // serializes every access to `cache`
    std::unordered_map<Key, T, Hash, KeyEqual> cache; // key -> lazily created value
};
Loading
Loading