diff --git a/cmake/deps.cmake b/cmake/deps.cmake index 0af44562e..7ba5b0267 100644 --- a/cmake/deps.cmake +++ b/cmake/deps.cmake @@ -63,7 +63,7 @@ if(PPLNN_DEP_HPCC_PKG) SUBBUILD_DIR ${HPCC_DEPS_DIR}/hpcc-subbuild) else() if(NOT PPLNN_DEP_HPCC_GIT) - set(PPLNN_DEP_HPCC_GIT "https://github.com/openppl-public/hpcc.git") + set(PPLNN_DEP_HPCC_GIT "https://github.com/OpenPPL/hpcc.git") endif() FetchContent_Declare(hpcc GIT_REPOSITORY ${PPLNN_DEP_HPCC_GIT} @@ -113,7 +113,7 @@ if(PPLNN_DEP_PPLCOMMON_PKG) ${PPLNN_DEP_PPLCOMMON_PKG}) else() if(NOT PPLNN_DEP_PPLCOMMON_GIT) - set(PPLNN_DEP_PPLCOMMON_GIT "https://github.com/openppl-public/ppl.common.git") + set(PPLNN_DEP_PPLCOMMON_GIT "https://github.com/OpenPPL/ppl.common.git") endif() hpcc_declare_git_dep(pplcommon ${PPLNN_DEP_PPLCOMMON_GIT} @@ -257,7 +257,7 @@ if(PPLNN_USE_X86_64 OR PPLNN_USE_AARCH64 OR PPLNN_USE_ARMV7 OR PPLNN_USE_RISCV64 ${PPLNN_DEP_PPLCPUKERNEL_PKG}) else() if(NOT PPLNN_DEP_PPLCPUKERNEL_GIT) - set(PPLNN_DEP_PPLCPUKERNEL_GIT "https://github.com/openppl-public/ppl.kernel.cpu.git") + set(PPLNN_DEP_PPLCPUKERNEL_GIT "https://github.com/OpenPPL/ppl.kernel.cpu.git") endif() hpcc_declare_git_dep(ppl.kernel.cpu ${PPLNN_DEP_PPLCPUKERNEL_GIT} @@ -277,7 +277,7 @@ if(PPLNN_USE_CUDA) ${PPLNN_DEP_PPLCUDAKERNEL_PKG}) else() if(NOT PPLNN_DEP_PPLCUDAKERNEL_GIT) - set(PPLNN_DEP_PPLCUDAKERNEL_GIT "https://github.com/openppl-public/ppl.kernel.cuda.git") + set(PPLNN_DEP_PPLCUDAKERNEL_GIT "https://github.com/OpenPPL/ppl.kernel.cuda.git") endif() hpcc_declare_git_dep(ppl.kernel.cuda ${PPLNN_DEP_PPLCUDAKERNEL_GIT} @@ -296,7 +296,7 @@ if(PPLNN_DEP_PPL_LLM_KERNEL_CUDA_PKG) ${PPLNN_DEP_PPL_LLM_KERNEL_CUDA_PKG}) else() if(NOT PPLNN_DEP_PPL_LLM_KERNEL_CUDA_GIT) - set(PPLNN_DEP_PPL_LLM_KERNEL_CUDA_GIT "https://github.com/openppl-public/ppl.llm.kernel.cuda.git") + set(PPLNN_DEP_PPL_LLM_KERNEL_CUDA_GIT "https://github.com/OpenPPL/ppl.llm.kernel.cuda.git") endif() hpcc_declare_git_dep(ppl.llm.kernel.cuda ${PPLNN_DEP_PPL_LLM_KERNEL_CUDA_GIT} diff --git a/include/ppl/nn/engines/llm_cuda/options.h b/include/ppl/nn/engines/llm_cuda/options.h index a85cca593..9bc0a09b5 100644 --- a/include/ppl/nn/engines/llm_cuda/options.h +++ b/include/ppl/nn/engines/llm_cuda/options.h @@ -63,6 +63,95 @@ enum { ENGINE_CONF_DEBUG_DATA_DIR = 3, ENGINE_CONF_MAX, + + ENGINE_CONF_INTERNAL_BEGIN = 1000, + + /** + @brief uint32_t, set shared memory decoding attention algorithm heuristic(1)/off(0), default is heuristic + This algorithm uses shared memory to store softmax logits. It is the fastest algorithm + for decode-phase attention, but the context length is limited by the size of shared memory. + + At least one of `ENGINE_CONF_DECODING_SHM_MHA`, `ENGINE_CONF_DECODING_INF_MHA` and `ENGINE_CONF_DECODING_INF_GQA` must be enabled. + + @note example: + @code{.cpp} + cuda_engine->Configure(ENGINE_CONF_DECODING_SHM_MHA, uint32_t); + @endcode + */ + ENGINE_CONF_DECODING_SHM_MHA = 1000, + + /** + @brief uint32_t, set infinity decoding attention algorithm heuristic(1)/off(0), default is heuristic + This algorithm rescales softmax logits in registers. It is a bit slower than shared memory decoding attention, + but the context length has no limit. + + At least one of `ENGINE_CONF_DECODING_SHM_MHA`, `ENGINE_CONF_DECODING_INF_MHA` and `ENGINE_CONF_DECODING_INF_GQA` must be enabled.
+ + @note example: + @code{.cpp} + cuda_engine->Configure(ENGINE_CONF_DECODING_INF_MHA, uint32_t); + @endcode + */ + ENGINE_CONF_DECODING_INF_MHA = 1001, + + /** + @brief uint32_t, set infinity grouped query decoding attention algorithm heuristic(1)/off(0), default is heuristic + This algorithm rescales softmax logits in shared memory and is optimized with tensor cores for grouped query attention. + It can be very fast when the decoding batch size is large (usually more than 64), and the context length has no limit. + + At least one of `ENGINE_CONF_DECODING_SHM_MHA`, `ENGINE_CONF_DECODING_INF_MHA` and `ENGINE_CONF_DECODING_INF_GQA` must be enabled. + + @note example: + @code{.cpp} + cuda_engine->Configure(ENGINE_CONF_DECODING_INF_GQA, uint32_t); + @endcode + */ + ENGINE_CONF_DECODING_INF_GQA = 1002, + + /** + @brief uint32_t, set split-k decoding attention algorithm always-on(2)/heuristic(1)/off(0), default is heuristic + Applies split-k decoding to all decoding algorithms, accelerating long-context decoding. + Recommended for context lengths >= 1024, but it may slow decoding down when the batch size is too large. + It is suggested to always turn it on for context lengths >= 16k. + + @note example: + @code{.cpp} + cuda_engine->Configure(ENGINE_CONF_DECODING_ATTN_SPLIT_K, uint32_t); + @endcode + */ + ENGINE_CONF_DECODING_ATTN_SPLIT_K = 1003, + + /** + @brief uint32_t, specify decoding attention kernel threads per block as 512/256/heuristic(0), default is heuristic + Sets the number of threads per block used by the decoding attention kernels. 0 selects the value heuristically; + otherwise only 256 and 512 are accepted. + + @note example: + @code{.cpp} + cuda_engine->Configure(ENGINE_CONF_DECODING_ATTN_TPB, uint32_t); + @endcode + */ + ENGINE_CONF_DECODING_ATTN_TPB = 1004, + + /** + @brief uint32_t, set prefill flash attention to read key and value tensors from the int8 kv cache, default is unquantized + fp16 tensors. It is used to support chunked prefill and prompt cache.
+ + + @note example: + @code{.cpp} + cuda_engine->Configure(ENGINE_CONF_CACHE_PREFILL, uint32_t); + @endcode + */ + ENGINE_CONF_CACHE_PREFILL = 1005, + + // TODO: ENGINE_CONF_CUSTOM_LOGGER = 1006, + + ENGINE_CONF_INTERNAL_MAX, + }; /** @brief memory management policies */ @@ -87,6 +176,9 @@ enum { /** online quantization, fp16 tensor and int4 weight */ QUANT_METHOD_ONLINE_I4F16, + + /** online quantization, fp8 tensor and fp8 weight */ + QUANT_METHOD_ONLINE_F8F8, }; /** @brief cublas layout hint, currently for selecting matrix layout for int8 gemm */ @@ -125,6 +217,8 @@ enum { DEV_CONF_MAX, }; + + }}}} // namespace ppl::nn::llm::cuda #endif diff --git a/src/ppl/nn/engines/llm_cuda/engine.cc b/src/ppl/nn/engines/llm_cuda/engine.cc index 8751b5f41..32b9a7359 100644 --- a/src/ppl/nn/engines/llm_cuda/engine.cc +++ b/src/ppl/nn/engines/llm_cuda/engine.cc @@ -95,7 +95,7 @@ RetCode LlmCudaEngine::ConfSetTensorParellelNcclComm(LlmCudaEngine* engine, va_l engine->tensor_parallel_nccl_param_.comm = nccl_comm; NCCL_CHECK(ncclCommCount(nccl_comm, &engine->tensor_parallel_nccl_param_.size), "ncclCommCount"); NCCL_CHECK(ncclCommUserRank(nccl_comm, &engine->tensor_parallel_nccl_param_.rank), "ncclCommUserRank"); - LOG(INFO) << "Engine Conf tp nccl comm world size: " + LOG(INFO) << "Engine Conf tp nccl comm world size: " << engine->tensor_parallel_nccl_param_.size; return RC_SUCCESS; #else @@ -122,6 +122,53 @@ RetCode LlmCudaEngine::ConfDebugDataDir(LlmCudaEngine* engine, va_list args) { return RC_SUCCESS; } +RetCode LlmCudaEngine::ConfCachePrefill(LlmCudaEngine* engine, va_list args) { + engine->config_.enable_cache_prefill = va_arg(args, uint32_t) ? true : false; + // TODO: Change to Custom logger + LOG(DEBUG) << "Engine Conf cache prefill: " << engine->config_.enable_cache_prefill; + return RC_SUCCESS; +} + +RetCode LlmCudaEngine::ConfDecodingShmMha(LlmCudaEngine* engine, va_list args) { + engine->config_.enable_decoding_sharemem_mhca = va_arg(args, uint32_t) ? true : false; + LOG(INFO) << "Engine Conf decoding shared memory mhca: " << engine->config_.enable_decoding_sharemem_mhca; + return RC_SUCCESS; +} + +RetCode LlmCudaEngine::ConfDecodingInfMha(LlmCudaEngine* engine, va_list args) { + engine->config_.enable_decoding_infinity_mhca = va_arg(args, uint32_t) ? true : false; + LOG(INFO) << "Engine Conf decoding infinity mhca: " << engine->config_.enable_decoding_infinity_mhca; + return RC_SUCCESS; +} + +RetCode LlmCudaEngine::ConfDecodingInfGqa(LlmCudaEngine* engine, va_list args) { + engine->config_.enable_decoding_infinity_gqca = va_arg(args, uint32_t) ? 
true : false; + LOG(INFO) << "Engine Conf decoding infinity gqca: " << engine->config_.enable_decoding_infinity_gqca; + return RC_SUCCESS; +} + +RetCode LlmCudaEngine::ConfDecodingAttnSplitK(LlmCudaEngine* engine, va_list args) { + uint32_t split_k = va_arg(args, uint32_t); + if (split_k != 0 && split_k != 1 && split_k != 2) { + LOG(ERROR) << "ENGINE_CONF_DECODING_ATTN_SPLIT_K only accept 0/1/2 but get " << split_k; + return ppl::common::RC_INVALID_VALUE; + } + engine->config_.specify_decoding_attn_split_k = split_k; + LOG(INFO) << "Engine Conf decoding attention split k: " << engine->config_.specify_decoding_attn_split_k; + return RC_SUCCESS; +} + +RetCode LlmCudaEngine::ConfDecodingAttnTpb(LlmCudaEngine* engine, va_list args) { + uint32_t tpb = va_arg(args, uint32_t); + if (tpb != 0 && tpb != 256 && tpb != 512) { + LOG(ERROR) << "ENGINE_CONF_DECODING_ATTN_TPB only accept 0/256/512 but get " << tpb; + return ppl::common::RC_INVALID_VALUE; + } + engine->config_.specify_decoding_attn_tpb = tpb; + LOG(INFO) << "Engine Conf decoding attention tpb: " << engine->config_.specify_decoding_attn_tpb; + return RC_SUCCESS; +} + #ifdef PPLNN_ENABLE_PMX_MODEL RetCode LlmCudaEngine::LoadConstants(const ConstantVisitor& visitor, map* eid2info) { return utils::LoadConstants(visitor, device_.get(), eid2info); @@ -156,13 +203,13 @@ ppl::common::RetCode LlmCudaEngine::SerializeData(const pmx::SerializationContex ppl::common::RetCode LlmCudaEngine::DeserializeData(const void* base, uint64_t size) { auto fb_engine_param = GetEngineParam(base); auto fb_param = fb_engine_param->value_as_EngineOptionsParam(); - + uint32_t cublas_layout_hint = fb_param->cublas_layout_hint(); if (cublas_layout_hint != options_.cublas_layout_hint) { LOG(WARNING) << "deserialize cublas_layout_hint[" << cublas_layout_hint << "] diff from user input[" << options_.cublas_layout_hint << "]"; } options_.cublas_layout_hint = cublas_layout_hint; - + if (fb_param->version() != GetVersion()) { LOG(WARNING) << "engine version[" << GetVersion() << "] diff from pmx version[" << fb_param->version() << "]"; } @@ -176,16 +223,28 @@ LlmCudaEngine::ConfHandlerFunc LlmCudaEngine::conf_handlers_[] = { ConfGraphFusion, ConfTenosrDebug, ConfDebugDataDir, + + ConfDecodingShmMha, + ConfDecodingInfMha, + ConfDecodingInfGqa, + ConfDecodingAttnSplitK, + ConfDecodingAttnTpb, + + ConfCachePrefill, }; RetCode LlmCudaEngine::Configure(uint32_t option, ...) { - if (option >= ENGINE_CONF_MAX) { - LOG(ERROR) << "invalid option[" << option << "] >= [" << (uint32_t)ENGINE_CONF_MAX << "]"; + auto conf_length = sizeof(conf_handlers_) / sizeof(ConfHandlerFunc); + auto uniform_option = option >= ENGINE_CONF_INTERNAL_BEGIN ? 
+ option + ENGINE_CONF_MAX - ENGINE_CONF_INTERNAL_BEGIN : + option; + if (uniform_option >= conf_length) { + LOG(ERROR) << "invalid option[" << option << "]"; return RC_INVALID_VALUE; } va_list args; va_start(args, option); - auto status = conf_handlers_[option](this, args); + auto status = conf_handlers_[uniform_option](this, args); va_end(args); return status; diff --git a/src/ppl/nn/engines/llm_cuda/engine.h b/src/ppl/nn/engines/llm_cuda/engine.h index f4ebd7574..ca132cc5e 100644 --- a/src/ppl/nn/engines/llm_cuda/engine.h +++ b/src/ppl/nn/engines/llm_cuda/engine.h @@ -54,8 +54,17 @@ class LlmCudaEngine final : public EngineImpl { static ppl::common::RetCode ConfTenosrDebug(LlmCudaEngine*, va_list); static ppl::common::RetCode ConfDebugDataDir(LlmCudaEngine*, va_list); + static ppl::common::RetCode ConfDecodingShmMha(LlmCudaEngine*, va_list); + static ppl::common::RetCode ConfDecodingInfMha(LlmCudaEngine*, va_list); + static ppl::common::RetCode ConfDecodingInfGqa(LlmCudaEngine*, va_list); + static ppl::common::RetCode ConfDecodingAttnSplitK(LlmCudaEngine*, va_list); + static ppl::common::RetCode ConfDecodingAttnTpb(LlmCudaEngine*, va_list); + + static ppl::common::RetCode ConfCachePrefill(LlmCudaEngine*, va_list); + typedef ppl::common::RetCode (*ConfHandlerFunc)(LlmCudaEngine*, va_list); - static ConfHandlerFunc conf_handlers_[ENGINE_CONF_MAX]; + static ConfHandlerFunc conf_handlers_[ + ENGINE_CONF_MAX + (ENGINE_CONF_INTERNAL_MAX - ENGINE_CONF_INTERNAL_BEGIN)]; private: EngineOptions options_; diff --git a/src/ppl/nn/engines/llm_cuda/engine_config.h b/src/ppl/nn/engines/llm_cuda/engine_config.h index 26e1c411a..152873614 100644 --- a/src/ppl/nn/engines/llm_cuda/engine_config.h +++ b/src/ppl/nn/engines/llm_cuda/engine_config.h @@ -26,6 +26,13 @@ struct EngineConfig final { bool enable_graph_fusion = true; bool enable_tensor_debug = false; std::string debug_data_dir = "."; + + bool enable_cache_prefill = false; + bool enable_decoding_sharemem_mhca = true; + bool enable_decoding_infinity_mhca = true; + bool enable_decoding_infinity_gqca = true; + int32_t specify_decoding_attn_split_k = 1; + int32_t specify_decoding_attn_tpb = 0; }; }}}} // namespace ppl::nn::llm::cuda diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/dynamic_batching/multi_head_cache_attention_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/dynamic_batching/multi_head_cache_attention_kernel.cc index 4b14bc77d..eb2cdfb51 100644 --- a/src/ppl/nn/engines/llm_cuda/kernels/opmx/dynamic_batching/multi_head_cache_attention_kernel.cc +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/dynamic_batching/multi_head_cache_attention_kernel.cc @@ -16,11 +16,8 @@ // under the License. 
#include "multi_head_cache_attention_kernel.h" - #include "ppl/common/destructor.h" -#include "ppl/kernel/llm/cuda/pmx/multi_head_cache_attention.h" - namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { @@ -95,11 +92,6 @@ ppl::common::RetCode DynamicBatchingMultiHeadCacheAttentionKernel::DoExecute(Ker return ppl::common::RC_UNSUPPORTED; } - if (param_->cache_mode != 0) { - LOG(ERROR) << "currently only support cache_mode == 0"; - return ppl::common::RC_UNSUPPORTED; - } - if (param_->is_alibi) { LOG(ERROR) << "currently only support is_alibi == false"; return ppl::common::RC_UNSUPPORTED; @@ -166,8 +158,16 @@ ppl::common::RetCode DynamicBatchingMultiHeadCacheAttentionKernel::DoExecute(Ker return ppl::common::RC_UNSUPPORTED; } - auto p_ret = ppl::kernel::llm::cuda::pmx::dynamic_batching_multi_head_cache_attention_prepare( - GetStream(), + int64_t cachestart_stride_b = 0; + if (param_->cache_mode == 1) { + if (cachestarts->GetShape()->GetDimCount() != 2) { + LOG(ERROR) << "cachestarts must be a 2d tensor in cache_mode 1"; + return ppl::common::RC_INVALID_VALUE; + } + cachestart_stride_b = cachestarts->GetShape()->GetDim(1); + } + + auto ret = attn_kernel_.heuristic_prepare( GetCudaDevice()->GetDeviceProp(), query->GetShape(), query->GetBufferPtr(), @@ -181,6 +181,7 @@ ppl::common::RetCode DynamicBatchingMultiHeadCacheAttentionKernel::DoExecute(Ker kvstarts->GetBufferPtr(), cachestarts->GetBufferPtr(), start_pos->GetBufferPtr(), + nullptr, param_->is_causal, batch, decodeing_batches_val, @@ -192,38 +193,41 @@ ppl::common::RetCode DynamicBatchingMultiHeadCacheAttentionKernel::DoExecute(Ker param_->num_kv_heads, param_->head_dim, param_->cache_mode, + param_->page_size, cache_stride_s, cache_stride_l, cache_stride_h, cache_stride_kv, + cachestart_stride_b, + GetEngineConfig().enable_cache_prefill, + GetEngineConfig().enable_decoding_sharemem_mhca, + GetEngineConfig().enable_decoding_infinity_mhca, + GetEngineConfig().enable_decoding_infinity_gqca, + GetEngineConfig().specify_decoding_attn_split_k, + GetEngineConfig().specify_decoding_attn_tpb, cache->GetBufferPtr(), scale->GetBufferPtr(), attn_output->GetShape(), attn_output->GetBufferPtr() ); - if (p_ret.first != ppl::common::RC_SUCCESS) { - return p_ret.first; + if (ret != ppl::common::RC_SUCCESS) { + return ret; } - auto &cfg = p_ret.second; - BufferDesc tmpbuffer_desc; - auto status = GetCudaDevice()->AllocTmpBuffer(cfg.temp_buffer_size, &tmpbuffer_desc); + auto status = GetCudaDevice()->AllocTmpBuffer(attn_kernel_.cfg.workspace_size, &tmpbuffer_desc); if (status != ppl::common::RC_SUCCESS) { - LOG(ERROR) << "alloc tmp buffer size[" << cfg.temp_buffer_size << "] for kernel[" << GetName() + LOG(ERROR) << "alloc tmp buffer size[" << attn_kernel_.cfg.workspace_size << "] for kernel[" << GetName() << "] failed: " << ppl::common::GetRetCodeStr(status); return status; } ppl::common::Destructor multi_block_tmpbuffer_guard([this, &tmpbuffer_desc]() -> void { GetCudaDevice()->FreeTmpBuffer(&tmpbuffer_desc); }); - cfg.temp_buffer = tmpbuffer_desc.addr; + attn_kernel_.cfg.workspace = tmpbuffer_desc.addr; - return ppl::kernel::llm::cuda::pmx::dynamic_batching_multi_head_cache_attention( - GetStream(), - cfg - ); + return attn_kernel_.forward(GetStream()); } }}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/dynamic_batching/multi_head_cache_attention_kernel.h b/src/ppl/nn/engines/llm_cuda/kernels/opmx/dynamic_batching/multi_head_cache_attention_kernel.h index 23965bd71..6360e0179 
100644 --- a/src/ppl/nn/engines/llm_cuda/kernels/opmx/dynamic_batching/multi_head_cache_attention_kernel.h +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/dynamic_batching/multi_head_cache_attention_kernel.h @@ -21,6 +21,8 @@ #include "ppl/nn/engines/llm_cuda/kernel.h" #include "ppl/nn/params/opmx/multi_head_cache_attention_param.h" +#include "ppl/kernel/llm/cuda/pmx/multi_head_cache_attention.h" + namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { class DynamicBatchingMultiHeadCacheAttentionKernel : public LlmCudaKernel { @@ -36,6 +38,7 @@ class DynamicBatchingMultiHeadCacheAttentionKernel : public LlmCudaKernel { private: const ppl::nn::opmx::MultiHeadCacheAttentionParam* param_ = nullptr; + ppl::kernel::llm::cuda::pmx::dynamic_batching_multi_head_cache_attention attn_kernel_; }; }}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/column_parallel_linear_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/column_parallel_linear_kernel.cc new file mode 100644 index 000000000..11b3e3398 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/column_parallel_linear_kernel.cc @@ -0,0 +1,125 @@ +#ifdef PPLNN_ENABLE_FP8 + +#include "column_parallel_linear_kernel.h" + +#include "ppl/common/cuda/nccl_utils.h" +#include "ppl/common/destructor.h" + +#include "ppl/kernel/llm/cuda/pmx/f8f8/column_parallel_linear.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + + +ppl::common::RetCode F8F8ColumnParallelLinearKernel::DoExecute(KernelExecContext* ctx) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Entry LlmCudaKernel: [%s]\n", GetName().c_str()); + + PPLNN_LLM_CUDA_REQUIRED_INPUT(input, 0); + PPLNN_LLM_CUDA_REQUIRED_INPUT(weight, 1); + PPLNN_LLM_CUDA_OPTIONAL_INPUT(bias, 2); + PPLNN_LLM_CUDA_REQUIRED_OUTPUT(output, 0); + + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [weight]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(weight); + if (bias) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [bias]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(bias); + } + + PPLNN_LLM_CUDA_DEBUG_TRACE("in_features: %d\n", param_->in_features); + PPLNN_LLM_CUDA_DEBUG_TRACE("out_features: %d\n", param_->out_features); + PPLNN_LLM_CUDA_DEBUG_TRACE("bias_term: %d\n", param_->bias_term); + PPLNN_LLM_CUDA_DEBUG_TRACE("gather_output: %d\n", param_->gather_output); + + PPLNN_LLM_CUDA_RESHAPE_OUTPUTS(); + + PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(output); + PPLNN_LLM_CUDA_DEBUG_TRACE("Output [output]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output); + + auto input_shape = input->GetShape(); + auto weight_shape = weight->GetShape(); + auto output_shape = output->GetShape(); + + TensorShape *bias_shape = nullptr; + void *bias_data = nullptr; + if (param_->bias_term) { + if (!bias) { + LOG(ERROR) << "bias_term == true but bias not found."; + return ppl::common::RC_NOT_FOUND; + } + bias_shape = bias->GetShape(); + bias_data = bias->GetBufferPtr(); + + if (bias_shape->GetDataType() != ppl::common::DATATYPE_FLOAT16) { + LOG(ERROR) << "currently only support fp16 bias"; + return ppl::common::RC_UNSUPPORTED; + } + } + + if (input_shape->GetDataType() != ppl::common::DATATYPE_FLOAT8E4M3) { + LOG(ERROR) << "only support float8 input"; + return ppl::common::RC_UNSUPPORTED; + } + if (weight_shape->GetDataType() != ppl::common::DATATYPE_FLOAT8E4M3) { + LOG(ERROR) << "only support float8 weight"; + return ppl::common::RC_UNSUPPORTED; + } + + auto 
cublas_handle = GetCublasHandle(); + auto nccl_param = GetTensorParallelNcclParam(); + + uint64_t gather_buffer_size = 0; + void *gather_buffer = nullptr; + if (param_->gather_output && nccl_param->size > 1) { + gather_buffer_size = output_shape->CalcBytesIncludingPadding(); + } + + BufferDesc tmp_buffer_desc; + auto status = GetCudaDevice()->AllocTmpBuffer(gather_buffer_size, &tmp_buffer_desc); + if (status != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "alloc tmp buffer size[" << gather_buffer_size << "] for kernel[" << GetName() + << "] failed: " << ppl::common::GetRetCodeStr(status); + return status; + } + ppl::common::Destructor __tmp_buffer_guard([this, &tmp_buffer_desc]() -> void { + GetCudaDevice()->FreeTmpBuffer(&tmp_buffer_desc); + }); + gather_buffer = tmp_buffer_desc.addr; + + status = ppl::kernel::llm::cuda::pmx::f8f8::column_parallel_linear( + GetStream(), + cublas_handle, + nullptr, + input_shape, + input->GetBufferPtr(), + weight_shape, + weight->GetBufferPtr(), + bias_shape, + bias_data, + param_->in_features, + param_->out_features, + nccl_param, + param_->gather_output, + gather_buffer, + GetCudaDevice()->GetCublasWorkspaceSize(), + GetCudaDevice()->GetCublasWorkspace(), + output_shape, + output->GetBufferPtr() + ); + + if (status != ppl::common::RC_SUCCESS) { + return status; + } + + if (input_shape->GetPadding1(0) > 0) { + output_shape->SetPadding1(0, 0); + } + + return ppl::common::RC_SUCCESS; +} + +}}}}} // namespace ppl::nn::llm::cuda::pmx + +#endif \ No newline at end of file diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/column_parallel_linear_kernel.h b/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/column_parallel_linear_kernel.h new file mode 100644 index 000000000..4213ffca9 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/column_parallel_linear_kernel.h @@ -0,0 +1,29 @@ +#ifdef PPLNN_ENABLE_FP8 + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_PMX_F8F8_COLUMN_PARALLEL_LINEAR_KERNEL_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_PMX_F8F8_COLUMN_PARALLEL_LINEAR_KERNEL_H_ + +#include "ppl/nn/engines/llm_cuda/kernel.h" +#include "ppl/nn/params/opmx/column_parallel_linear_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class F8F8ColumnParallelLinearKernel : public LlmCudaKernel { +public: + F8F8ColumnParallelLinearKernel(const ir::Node* node) : LlmCudaKernel(node) {} + + void SetParam(const ppl::nn::opmx::ColumnParallelLinearParam* p) { + param_ = p; + } + +private: + ppl::common::RetCode DoExecute(KernelExecContext*) override; + +private: + const ppl::nn::opmx::ColumnParallelLinearParam* param_ = nullptr; +}; + +}}}}} // namespace ppl::nn::llm::cuda::pmx + +#endif +#endif diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/online_cast_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/online_cast_kernel.cc new file mode 100644 index 000000000..970cd4a77 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/online_cast_kernel.cc @@ -0,0 +1,45 @@ +#ifdef PPLNN_ENABLE_FP8 + +#include "online_cast_kernel.h" + +#include "ppl/kernel/llm/cuda/pmx/f8f8/cast.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +ppl::common::RetCode F8F8OnlineCastKernel::DoExecute(KernelExecContext* ctx) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Entry LlmCudaKernel: [%s]\n", GetName().c_str()); + + PPLNN_LLM_CUDA_REQUIRED_INPUT(input, 0); + PPLNN_LLM_CUDA_REQUIRED_OUTPUT(output, 0); + + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input]:\n"); + 
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input); + + PPLNN_LLM_CUDA_RESHAPE_OUTPUTS(); + + PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(output); + PPLNN_LLM_CUDA_DEBUG_TRACE("Output [output]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output); + + + if (input->GetShape()->GetDataType() != ppl::common::DATATYPE_FLOAT16) { + LOG(ERROR) << "currently only support fp16 input"; + return ppl::common::RC_UNSUPPORTED; + } + + const auto dim_count = input->GetShape()->GetDimCount(); + const int64_t quant_dim = input->GetShape()->GetDim(dim_count - 1); + const int64_t batch = input->GetShape()->CalcElementsToDimensionIncludingPadding(dim_count - 1); + + return ppl::kernel::llm::cuda::pmx::f8f8::cast_fp16( + GetStream(), + input->GetBufferPtr(), + batch, + quant_dim, + output->GetBufferPtr() + ); +} + +}}}}} // namespace ppl::nn::llm::cuda::pmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/online_cast_kernel.h b/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/online_cast_kernel.h new file mode 100644 index 000000000..0834e0273 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/online_cast_kernel.h @@ -0,0 +1,22 @@ +#ifdef PPLNN_ENABLE_FP8 + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_PMX_F8F8_ONLINE_CAST_KERNEL_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_PMX_F8F8_ONLINE_CAST_KERNEL_H_ + +#include "ppl/nn/engines/llm_cuda/kernel.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class F8F8OnlineCastKernel : public LlmCudaKernel { +public: + F8F8OnlineCastKernel(const ir::Node* node) : LlmCudaKernel(node) {} + +private: + ppl::common::RetCode DoExecute(KernelExecContext*) override; + +}; + +}}}}} // namespace ppl::nn::llm::cuda::pmx + +#endif +#endif \ No newline at end of file diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/row_parallel_linear_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/row_parallel_linear_kernel.cc new file mode 100644 index 000000000..028cbcbf5 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/row_parallel_linear_kernel.cc @@ -0,0 +1,106 @@ +#ifdef PPLNN_ENABLE_FP8 + +#include "row_parallel_linear_kernel.h" + +#include "ppl/common/cuda/nccl_utils.h" +#include "ppl/common/destructor.h" + +#include "ppl/kernel/llm/cuda/pmx/f8f8/row_parallel_linear.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +ppl::common::RetCode F8F8RowParallelLinearKernel::DoExecute(KernelExecContext* ctx) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Entry LlmCudaKernel: [%s]\n", GetName().c_str()); + + PPLNN_LLM_CUDA_REQUIRED_INPUT(input, 0); + PPLNN_LLM_CUDA_REQUIRED_INPUT(weight, 1); + PPLNN_LLM_CUDA_OPTIONAL_INPUT(bias, 2); + PPLNN_LLM_CUDA_REQUIRED_OUTPUT(output, 0); + + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [weight]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(weight); + if (bias) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [bias]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(bias); + } + + PPLNN_LLM_CUDA_DEBUG_TRACE("in_features: %d\n", param_->in_features); + PPLNN_LLM_CUDA_DEBUG_TRACE("out_features: %d\n", param_->out_features); + PPLNN_LLM_CUDA_DEBUG_TRACE("bias_term: %d\n", param_->bias_term); + PPLNN_LLM_CUDA_DEBUG_TRACE("input_is_parallel: %d\n", param_->input_is_parallel); + + PPLNN_LLM_CUDA_RESHAPE_OUTPUTS(); + + PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(output); + PPLNN_LLM_CUDA_DEBUG_TRACE("Output [output]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output); + + auto 
input_shape = input->GetShape(); + auto weight_shape = weight->GetShape(); + auto output_shape = output->GetShape(); + + TensorShape *bias_shape = nullptr; + void *bias_data = nullptr; + if (param_->bias_term) { + if (!bias) { + LOG(ERROR) << "bias_term == true but bias not found."; + return ppl::common::RC_NOT_FOUND; + } + bias_shape = bias->GetShape(); + bias_data = bias->GetBufferPtr(); + + if (bias_shape->GetDataType() != ppl::common::DATATYPE_FLOAT16) { + LOG(ERROR) << "currently only support fp16 bias"; + return ppl::common::RC_UNSUPPORTED; + } + } + + if (input_shape->GetDataType() != ppl::common::DATATYPE_FLOAT8E4M3) { + LOG(ERROR) << "only support float8 input"; + return ppl::common::RC_UNSUPPORTED; + } + if (weight_shape->GetDataType() != ppl::common::DATATYPE_FLOAT8E4M3) { + LOG(ERROR) << "only support float8 weight"; + return ppl::common::RC_UNSUPPORTED; + } + + auto cublas_handle = GetCublasHandle(); + auto nccl_param = GetTensorParallelNcclParam(); + + auto status = ppl::kernel::llm::cuda::pmx::f8f8::row_parallel_linear( + GetStream(), + cublas_handle, + nullptr, + input_shape, + input->GetBufferPtr(), + weight_shape, + weight->GetBufferPtr(), + bias_shape, + bias_data, + param_->in_features, + param_->out_features, + nccl_param, + param_->input_is_parallel, + nullptr, + GetCudaDevice()->GetCublasWorkspaceSize(), + GetCudaDevice()->GetCublasWorkspace(), + output_shape, + output->GetBufferPtr() + ); + + if (status != ppl::common::RC_SUCCESS) { + return status; + } + + if (input_shape->GetPadding1(0) > 0) { + output_shape->SetPadding1(0, 0); + } + + return ppl::common::RC_SUCCESS; +} + + +}}}}} // namespace ppl::nn::llm::cuda::pmx +#endif \ No newline at end of file diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/row_parallel_linear_kernel.h b/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/row_parallel_linear_kernel.h new file mode 100644 index 000000000..ddc0944cf --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/row_parallel_linear_kernel.h @@ -0,0 +1,29 @@ +#ifdef PPLNN_ENABLE_FP8 + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_PMX_F8F8_ROW_PARALLEL_LINEAR_KERNEL_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_PMX_F8F8_ROW_PARALLEL_LINEAR_KERNEL_H_ + +#include "ppl/nn/engines/llm_cuda/kernel.h" +#include "ppl/nn/params/opmx/row_parallel_linear_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class F8F8RowParallelLinearKernel : public LlmCudaKernel { +public: + F8F8RowParallelLinearKernel(const ir::Node* node) : LlmCudaKernel(node) {} + + void SetParam(const ppl::nn::opmx::RowParallelLinearParam* p) { + param_ = p; + } + +private: + ppl::common::RetCode DoExecute(KernelExecContext*) override; + +private: + const ppl::nn::opmx::RowParallelLinearParam* param_ = nullptr; +}; + +}}}}} // namespace ppl::nn::llm::cuda + +#endif +#endif diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/column_parallel_linear_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/column_parallel_linear_kernel.cc new file mode 100644 index 000000000..621e65e0e --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/column_parallel_linear_kernel.cc @@ -0,0 +1,190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "column_parallel_linear_kernel.h" + +#include "ppl/common/cuda/nccl_utils.h" +#include "ppl/common/destructor.h" + +#include "ppl/kernel/llm/cuda/pmx/i4f16/column_parallel_linear.h" +#include "ppl/kernel/llm/cuda/pmx/column_parallel_linear.h" +#include "ppl/kernel/llm/cuda/pmx/i4f16/quantize.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + + +ppl::common::RetCode I4F16ColumnParallelLinearKernel::DoExecute(KernelExecContext* ctx) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Entry LlmCudaKernel: [%s]\n", GetName().c_str()); + + PPLNN_LLM_CUDA_REQUIRED_INPUT(input, 0); + PPLNN_LLM_CUDA_REQUIRED_INPUT(weight, 1); + PPLNN_LLM_CUDA_REQUIRED_INPUT(weight_scale, 2); + PPLNN_LLM_CUDA_OPTIONAL_INPUT(bias, 3); + PPLNN_LLM_CUDA_REQUIRED_OUTPUT(output, 0); + + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [weight]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(weight); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [weight_scale]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(weight_scale); + if (bias) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [bias]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(bias); + } + + PPLNN_LLM_CUDA_DEBUG_TRACE("in_features: %d\n", param_->in_features); + PPLNN_LLM_CUDA_DEBUG_TRACE("out_features: %d\n", param_->out_features); + PPLNN_LLM_CUDA_DEBUG_TRACE("bias_term: %d\n", param_->bias_term); + PPLNN_LLM_CUDA_DEBUG_TRACE("gather_output: %d\n", param_->gather_output); + + PPLNN_LLM_CUDA_RESHAPE_OUTPUTS(); + + PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(output); + PPLNN_LLM_CUDA_DEBUG_TRACE("Output [output]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output); + + auto input_shape = input->GetShape(); + auto weight_shape = weight->GetShape(); + auto output_shape = output->GetShape(); + + if (weight_shape->GetDataType() != ppl::common::DATATYPE_INT4X4) { + LOG(ERROR) << "currently only support int4x4 weight"; + return ppl::common::RC_UNSUPPORTED; + } + + TensorShape *bias_shape = nullptr; + void *bias_data = nullptr; + if (param_->bias_term) { + if (!bias) { + LOG(ERROR) << "bias_term == true but bias not found."; + return ppl::common::RC_NOT_FOUND; + } + bias_shape = bias->GetShape(); + bias_data = bias->GetBufferPtr(); + } + + if (ppl::common::DATATYPE_FLOAT16 != input_shape->GetDataType()) { + LOG(ERROR) << "currently only support fp16"; + return ppl::common::RC_UNSUPPORTED; + } + + auto nccl_param = GetTensorParallelNcclParam(); + + const int64_t M = input_shape->CalcElementsToDimensionExcludingPadding(input_shape->GetDimCount() - 1); + bool use_fp16_gemm = false; + { + const int32_t sm_count = GetCudaDevice()->GetDeviceProp().multiProcessorCount; + if (sm_count >= 96) { + use_fp16_gemm = M >= 768; + } else if (sm_count >= 48) { + use_fp16_gemm = M >= 512; + } else { + use_fp16_gemm = M >= 256; + } + } + + uint64_t dequant_weight_buffer_size = 0; + void *dequant_weight_buffer = nullptr; + if 
(use_fp16_gemm) { + dequant_weight_buffer_size = weight_shape->CalcElementsExcludingPadding() * sizeof(int16_t) * 4; + } + + uint64_t gather_buffer_size = 0; + void *gather_buffer = nullptr; + if (param_->gather_output && nccl_param->size > 1) { + gather_buffer_size = output_shape->CalcBytesExcludingPadding(); + } + + const int64_t tmp_buffer_size = dequant_weight_buffer_size + gather_buffer_size; + BufferDesc tmp_buffer_desc; + auto status = GetCudaDevice()->AllocTmpBuffer(tmp_buffer_size, &tmp_buffer_desc); + if (status != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "alloc tmp buffer size[" << tmp_buffer_size << "] for kernel[" << GetName() + << "] failed: " << ppl::common::GetRetCodeStr(status); + return status; + } + ppl::common::Destructor __tmp_buffer_guard([this, &tmp_buffer_desc]() -> void { + GetCudaDevice()->FreeTmpBuffer(&tmp_buffer_desc); + }); + gather_buffer = tmp_buffer_desc.addr; + + if (use_fp16_gemm) { + dequant_weight_buffer = (int8_t*)gather_buffer + gather_buffer_size; + + auto dequant_weight_shape = *weight_shape; + dequant_weight_shape.SetDim(0, weight_shape->GetDim(0) * 4); + dequant_weight_shape.SetDataType(weight_scale->GetShape()->GetDataType()); + + auto rc = ppl::kernel::llm::cuda::pmx::i4f16::minmax_dequantize_fp16( + GetStream(), + weight->GetBufferPtr(), + weight_scale->GetBufferPtr(), + dequant_weight_shape.GetDim(0), + dequant_weight_shape.GetDim(1), + 128, + dequant_weight_buffer + ); + if (ppl::common::RC_SUCCESS != rc) { + return rc; + } + const bool use_workspace = GetCudaDevice()->GetSMVersion() >= 90 && M >= 64; + return ppl::kernel::llm::cuda::pmx::column_parallel_linear( + GetStream(), + GetCublasHandle(), + nullptr, + input_shape, + input->GetBufferPtr(), + &dequant_weight_shape, + dequant_weight_buffer, + bias_shape, + bias_data, + param_->in_features, + param_->out_features, + nccl_param, + param_->gather_output, + gather_buffer, + use_workspace ? GetCudaDevice()->GetCublasWorkspaceSize() : 0, + use_workspace ? GetCudaDevice()->GetCublasWorkspace() : nullptr, + output_shape, + output->GetBufferPtr() + ); + } else { + return ppl::kernel::llm::cuda::pmx::i4f16::column_parallel_linear( + GetStream(), + GetCudaDevice()->GetI4F16GemmHandle(), + input_shape, + input->GetBufferPtr(), + weight_shape, + weight->GetBufferPtr(), + weight_scale->GetBufferPtr(), + bias_shape, + bias_data, + param_->in_features, + param_->out_features, + nccl_param, + param_->gather_output, + gather_buffer, + GetCudaDevice()->GetCublasWorkspaceSize(), + GetCudaDevice()->GetCublasWorkspace(), + output_shape, + output->GetBufferPtr() + ); + } +} + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/column_parallel_linear_kernel.h b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/column_parallel_linear_kernel.h new file mode 100644 index 000000000..3d2b2bb6d --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/column_parallel_linear_kernel.h @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_I4F16_COLUMN_PARALLEL_LINEAR_KERNEL_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_I4F16_COLUMN_PARALLEL_LINEAR_KERNEL_H_ + +#include "ppl/nn/engines/llm_cuda/kernel.h" +#include "ppl/nn/params/opmx/column_parallel_linear_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class I4F16ColumnParallelLinearKernel : public LlmCudaKernel { +public: + I4F16ColumnParallelLinearKernel(const ir::Node* node) : LlmCudaKernel(node) {} + + void SetParam(const ppl::nn::opmx::ColumnParallelLinearParam* p) { + param_ = p; + } + +private: + ppl::common::RetCode DoExecute(KernelExecContext*) override; + +private: + const ppl::nn::opmx::ColumnParallelLinearParam* param_ = nullptr; +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/row_parallel_linear_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/row_parallel_linear_kernel.cc new file mode 100644 index 000000000..7d8add011 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/row_parallel_linear_kernel.cc @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "row_parallel_linear_kernel.h" + +#include "ppl/common/cuda/nccl_utils.h" +#include "ppl/common/destructor.h" + +#include "ppl/kernel/llm/cuda/pmx/i4f16/row_parallel_linear.h" +#include "ppl/kernel/llm/cuda/pmx/row_parallel_linear.h" +#include "ppl/kernel/llm/cuda/pmx/i4f16/quantize.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +ppl::common::RetCode I4F16RowParallelLinearKernel::DoExecute(KernelExecContext* ctx) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Entry LlmCudaKernel: [%s]\n", GetName().c_str()); + + PPLNN_LLM_CUDA_REQUIRED_INPUT(input, 0); + PPLNN_LLM_CUDA_REQUIRED_INPUT(weight, 1); + PPLNN_LLM_CUDA_REQUIRED_INPUT(weight_scale, 2); + PPLNN_LLM_CUDA_OPTIONAL_INPUT(bias, 3); + PPLNN_LLM_CUDA_REQUIRED_OUTPUT(output, 0); + + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [weight]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(weight); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [weight_scale]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(weight_scale); + if (bias) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [bias]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(bias); + } + + PPLNN_LLM_CUDA_DEBUG_TRACE("in_features: %d\n", param_->in_features); + PPLNN_LLM_CUDA_DEBUG_TRACE("out_features: %d\n", param_->out_features); + PPLNN_LLM_CUDA_DEBUG_TRACE("bias_term: %d\n", param_->bias_term); + PPLNN_LLM_CUDA_DEBUG_TRACE("input_is_parallel: %d\n", param_->input_is_parallel); + + PPLNN_LLM_CUDA_RESHAPE_OUTPUTS(); + + PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(output); + PPLNN_LLM_CUDA_DEBUG_TRACE("Output [output]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output); + + auto input_shape = input->GetShape(); + auto weight_shape = weight->GetShape(); + auto output_shape = output->GetShape(); + + if (weight_shape->GetDataType() != ppl::common::DATATYPE_INT4X4) { + LOG(ERROR) << "currently only support int4x4 weight"; + return ppl::common::RC_UNSUPPORTED; + } + + TensorShape *bias_shape = nullptr; + void *bias_data = nullptr; + if (param_->bias_term) { + if (!bias) { + LOG(ERROR) << "bias_term == true but bias not found."; + return ppl::common::RC_NOT_FOUND; + } + bias_shape = bias->GetShape(); + bias_data = bias->GetBufferPtr(); + } + + if (ppl::common::DATATYPE_FLOAT16 != input_shape->GetDataType()) { + LOG(ERROR) << "currently only support fp16"; + return ppl::common::RC_UNSUPPORTED; + } + + auto nccl_param = GetTensorParallelNcclParam(); + + const int64_t M = input_shape->CalcElementsToDimensionExcludingPadding(input_shape->GetDimCount() - 1); + bool use_fp16_gemm = false; + { + const int32_t sm_count = GetCudaDevice()->GetDeviceProp().multiProcessorCount; + if (sm_count >= 96) { + use_fp16_gemm = M >= 768; + } else if (sm_count >= 48) { + use_fp16_gemm = M >= 512; + } else { + use_fp16_gemm = M >= 256; + } + } + + if (use_fp16_gemm) { + uint64_t dequant_weight_buffer_size = 0; + void *dequant_weight_buffer = nullptr; + if (use_fp16_gemm) { + dequant_weight_buffer_size = weight_shape->CalcElementsExcludingPadding() * sizeof(int16_t) * 4; + } + + const int64_t tmp_buffer_size = dequant_weight_buffer_size; + BufferDesc tmp_buffer_desc; + auto status = GetCudaDevice()->AllocTmpBuffer(tmp_buffer_size, &tmp_buffer_desc); + if (status != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "alloc tmp buffer size[" << tmp_buffer_size << "] for kernel[" << GetName() + << "] failed: " << ppl::common::GetRetCodeStr(status); + return status; + } + ppl::common::Destructor __tmp_buffer_guard([this, 
&tmp_buffer_desc]() -> void { + GetCudaDevice()->FreeTmpBuffer(&tmp_buffer_desc); + }); + dequant_weight_buffer = tmp_buffer_desc.addr; + + auto dequant_weight_shape = *weight_shape; + dequant_weight_shape.SetDim(0, weight_shape->GetDim(0) * 4); + dequant_weight_shape.SetDataType(weight_scale->GetShape()->GetDataType()); + + auto rc = ppl::kernel::llm::cuda::pmx::i4f16::minmax_dequantize_fp16( + GetStream(), + weight->GetBufferPtr(), + weight_scale->GetBufferPtr(), + dequant_weight_shape.GetDim(0), + dequant_weight_shape.GetDim(1), + 128, + dequant_weight_buffer + ); + if (ppl::common::RC_SUCCESS != rc) { + return rc; + } + const bool use_workspace = GetCudaDevice()->GetSMVersion() >= 90 && M >= 64; + return ppl::kernel::llm::cuda::pmx::row_parallel_linear( + GetStream(), + GetCublasHandle(), + nullptr, + input_shape, + input->GetBufferPtr(), + &dequant_weight_shape, + dequant_weight_buffer, + bias_shape, + bias_data, + param_->in_features, + param_->out_features, + nccl_param, + param_->input_is_parallel, + nullptr, + use_workspace ? GetCudaDevice()->GetCublasWorkspaceSize() : 0, + use_workspace ? GetCudaDevice()->GetCublasWorkspace() : nullptr, + output_shape, + output->GetBufferPtr() + ); + } else { + return ppl::kernel::llm::cuda::pmx::i4f16::row_parallel_linear( + GetStream(), + GetCudaDevice()->GetI4F16GemmHandle(), + input_shape, + input->GetBufferPtr(), + weight_shape, + weight->GetBufferPtr(), + weight_scale->GetBufferPtr(), + bias_shape, + bias_data, + param_->in_features, + param_->out_features, + nccl_param, + param_->input_is_parallel, + nullptr, + GetCudaDevice()->GetCublasWorkspaceSize(), + GetCudaDevice()->GetCublasWorkspace(), + output_shape, + output->GetBufferPtr() + ); + } +} + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/row_parallel_linear_kernel.h b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/row_parallel_linear_kernel.h new file mode 100644 index 000000000..1608c3f1f --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/row_parallel_linear_kernel.h @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_I4F16_ROW_PARALLEL_LINEAR_KERNEL_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_I4F16_ROW_PARALLEL_LINEAR_KERNEL_H_ + +#include "ppl/nn/engines/llm_cuda/kernel.h" +#include "ppl/nn/params/opmx/row_parallel_linear_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class I4F16RowParallelLinearKernel : public LlmCudaKernel { +public: + I4F16RowParallelLinearKernel(const ir::Node* node) : LlmCudaKernel(node) {} + + void SetParam(const ppl::nn::opmx::RowParallelLinearParam* p) { + param_ = p; + } + +private: + ppl::common::RetCode DoExecute(KernelExecContext*) override; + +private: + const ppl::nn::opmx::RowParallelLinearParam* param_ = nullptr; +}; + +}}}}} // namespace ppl::nn::llm::cuda + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_reshape_split_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_reshape_split_kernel.cc new file mode 100644 index 000000000..2fdf55149 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_reshape_split_kernel.cc @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "online_dequantize_reshape_split_kernel.h" + +#include "ppl/kernel/llm/cuda/pmx/i8i8/quantize.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +ppl::common::RetCode I8I8OnlineDequantizeReshapeSplitKernel::DoExecute(KernelExecContext* ctx) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Entry LlmCudaKernel: [%s]\n", GetName().c_str()); + + PPLNN_LLM_CUDA_REQUIRED_INPUT(input, 0); + PPLNN_LLM_CUDA_REQUIRED_INPUT(scale_outer, 1); + PPLNN_LLM_CUDA_REQUIRED_INPUT(scale_inner, 2); + PPLNN_LLM_CUDA_OPTIONAL_INPUT(bias, 3); + + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [scale_outer]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(scale_outer); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [scale_inner]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(scale_inner); + if (bias) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [bias]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(bias); + } + + PPLNN_LLM_CUDA_DEBUG_TRACE("bias_term: %d\n", param_->bias_term); + PPLNN_LLM_CUDA_DEBUG_TRACE("axis: %d\n", param_->split_param->axis); + PPLNN_LLM_CUDA_DEBUG_TRACE("split:\n"); + int64_t split_dim = 0; + for (size_t i = 0; i < param_->split.size(); ++i) { + PPLNN_LLM_CUDA_DEBUG_TRACE(" %ld\n", param_->split[i]); + split_dim += param_->split[i]; + } + PPLNN_LLM_CUDA_DEBUG_TRACE("shape:\n"); + for (size_t i = 0; i < param_->shape.size(); ++i) { + PPLNN_LLM_CUDA_DEBUG_TRACE(" %ld\n", param_->shape[i]); + } + + PPLNN_LLM_CUDA_RESHAPE_OUTPUTS(); + + if (param_->split.size() != 3) { + LOG(ERROR) << "only support 3 split point"; + return ppl::common::RC_UNSUPPORTED; + } + + if (scale_outer->GetShape()->GetDataType() != scale_inner->GetShape()->GetDataType()) { + LOG(ERROR) << "datatype of scale_outer must be equal to datatype of scale_inner: " + << ppl::common::GetDataTypeStr(scale_outer->GetShape()->GetDataType()) << " vs. 
" + << ppl::common::GetDataTypeStr(scale_inner->GetShape()->GetDataType()); + return ppl::common::RC_INVALID_VALUE; + } + + if (scale_outer->GetShape()->GetDataType() != ppl::common::DATATYPE_FLOAT16) { + LOG(ERROR) << "currently only support dequantize to fp16"; + return ppl::common::RC_UNSUPPORTED; + } + + if (input->GetShape()->GetDataType() != ppl::common::DATATYPE_INT32) { + LOG(ERROR) << "currently only support dequantize int32 data"; + return ppl::common::RC_UNSUPPORTED; + } + + TensorShape *bias_shape = nullptr; + void *bias_data = nullptr; + if (param_->bias_term) { + if (!bias) { + LOG(ERROR) << "bias_term == true but bias not found."; + return ppl::common::RC_NOT_FOUND; + } + bias_shape = bias->GetShape(); + bias_data = bias->GetBufferPtr(); + + if (bias_shape->GetDataType() != ppl::common::DATATYPE_FLOAT16) { + LOG(ERROR) << "currently only support fp16 bias"; + return ppl::common::RC_UNSUPPORTED; + } + } + + std::vector dst_datas(ctx->GetOutputCount()); + + for (uint32_t i = 0; i < ctx->GetOutputCount(); ++i) { + auto output = ctx->GetOutput(i); + PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(output); + PPLNN_LLM_CUDA_DEBUG_TRACE("Output [outputs[%u]]:\n", i); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output); + dst_datas[i] = output->GetBufferPtr(); + } + + const int64_t batch = scale_outer->GetShape()->CalcElementsIncludingPadding(); + const int64_t quant_dim = scale_inner->GetShape()->CalcElementsIncludingPadding(); + const int64_t total_elem = input->GetShape()->CalcElementsIncludingPadding(); + if (total_elem != batch * quant_dim) { + LOG(ERROR) << "input.numel must be equal to scale_outer.numel * scale_inner.numel): " + << batch << " * " << quant_dim << " != " << total_elem; + return ppl::common::RC_INVALID_VALUE; + } + if (quant_dim % split_dim != 0) { + LOG(ERROR) << "channel must be divided by sum(split): " + << quant_dim << " % " << split_dim << " = " << quant_dim % split_dim; + return ppl::common::RC_INVALID_VALUE; + } + const int64_t split_inner_dim = quant_dim / split_dim; + + auto from_layout = GetEngineOptions().cublas_layout_hint == CUBLAS_LAYOUT_AMPERE + ? ppl::kernel::llm::cuda::MATRIX_LAYOUT_COL32 + : ppl::kernel::llm::cuda::MATRIX_LAYOUT_ROW_MAJOR; + + return ppl::kernel::llm::cuda::pmx::i8i8::minmax_dequantize_split3_fp16( + GetStream(), + input->GetBufferPtr(), + bias_data, + scale_outer->GetBufferPtr(), + scale_inner->GetBufferPtr(), + batch, + quant_dim, + param_->split[0] * split_inner_dim, + param_->split[1] * split_inner_dim, + param_->split[2] * split_inner_dim, + ppl::kernel::llm::cuda::pmx::i8i8::token_down_scale, + ppl::kernel::llm::cuda::pmx::i8i8::hidden_down_scale, + from_layout, + dst_datas[0], + dst_datas[1], + dst_datas[2] + ); +} + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_reshape_split_kernel.h b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_reshape_split_kernel.h new file mode 100644 index 000000000..de33cb77b --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_reshape_split_kernel.h @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_I8I8_ONLINE_DEQUANTIZE_RESHAPE_SPLIT_KERNEL_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_I8I8_ONLINE_DEQUANTIZE_RESHAPE_SPLIT_KERNEL_H_ + +#include "ppl/nn/engines/llm_cuda/kernel.h" +#include "ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_reshape_split_op.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class I8I8OnlineDequantizeReshapeSplitKernel : public LlmCudaKernel { +public: + I8I8OnlineDequantizeReshapeSplitKernel(const ir::Node* node) : LlmCudaKernel(node) {} + + void SetParam(const I8I8OnlineDequantizeReshapeSplitOp::Param* p) { + param_ = p; + } + +private: + ppl::common::RetCode DoExecute(KernelExecContext*) override; + +private: + const I8I8OnlineDequantizeReshapeSplitOp::Param* param_ = nullptr; +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_silu_quantize_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_silu_quantize_kernel.cc new file mode 100644 index 000000000..f6db8997d --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_silu_quantize_kernel.cc @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "online_dequantize_silu_quantize_kernel.h" + +#include "ppl/kernel/llm/cuda/pmx/i8i8/quantize.h" + +#include "ppl/common/destructor.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +ppl::common::RetCode I8I8OnlineDequantizeSiLUQuantizeKernel::DoExecute(KernelExecContext* ctx) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Entry LlmCudaKernel: [%s]\n", GetName().c_str()); + + PPLNN_LLM_CUDA_REQUIRED_INPUT(input, 0); + PPLNN_LLM_CUDA_REQUIRED_INPUT(input_scale_outer, 1); + PPLNN_LLM_CUDA_REQUIRED_INPUT(input_scale_inner, 2); + PPLNN_LLM_CUDA_OPTIONAL_INPUT(gate, 3); + PPLNN_LLM_CUDA_OPTIONAL_INPUT(gate_scale_outer, 4); + PPLNN_LLM_CUDA_OPTIONAL_INPUT(gate_scale_inner, 5); + + PPLNN_LLM_CUDA_REQUIRED_OUTPUT(output, 0); + PPLNN_LLM_CUDA_REQUIRED_OUTPUT(scale, 1); + + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input_scale_outer]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input_scale_outer); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input_scale_inner]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input_scale_inner); + if (gate) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [gate]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(gate); + } + if (gate_scale_outer) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [gate_scale_outer]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(gate_scale_outer); + } + if (gate_scale_inner) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [gate_scale_inner]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(gate_scale_inner); + } + + PPLNN_LLM_CUDA_RESHAPE_OUTPUTS(); + + if (input_scale_outer->GetShape()->GetDataType() != input_scale_inner->GetShape()->GetDataType()) { + LOG(ERROR) << "datatype of scale_outer must be equal to datatype of scale_inner: " + << ppl::common::GetDataTypeStr(input_scale_outer->GetShape()->GetDataType()) << " vs. " + << ppl::common::GetDataTypeStr(input_scale_inner->GetShape()->GetDataType()); + return ppl::common::RC_INVALID_VALUE; + } + + if (input_scale_outer->GetShape()->GetDataType() != ppl::common::DATATYPE_FLOAT16) { + LOG(ERROR) << "currently only support dequantize to fp16"; + return ppl::common::RC_UNSUPPORTED; + } + + if (input->GetShape()->GetDataType() != ppl::common::DATATYPE_INT32) { + LOG(ERROR) << "currently only support dequantize int32 data"; + return ppl::common::RC_UNSUPPORTED; + } + + if (gate) { + if (gate_scale_outer == nullptr) { + LOG(ERROR) << "scale_outer of gate not found"; + return ppl::common::RC_NOT_FOUND; + } + + if (gate_scale_inner == nullptr) { + LOG(ERROR) << "scale_inner of gate not found"; + return ppl::common::RC_NOT_FOUND; + } + + if (gate_scale_outer->GetShape()->GetDataType() != gate_scale_inner->GetShape()->GetDataType()) { + LOG(ERROR) << "datatype of scale_outer must be equal to datatype of scale_inner: " + << ppl::common::GetDataTypeStr(gate_scale_outer->GetShape()->GetDataType()) << " vs. 
" + << ppl::common::GetDataTypeStr(gate_scale_inner->GetShape()->GetDataType()); + return ppl::common::RC_INVALID_VALUE; + } + + if (gate_scale_outer->GetShape()->GetDataType() != ppl::common::DATATYPE_FLOAT16) { + LOG(ERROR) << "currently only support dequantize to fp16"; + return ppl::common::RC_UNSUPPORTED; + } + + if (gate->GetShape()->GetDataType() != ppl::common::DATATYPE_INT32) { + LOG(ERROR) << "currently only support dequantize int32 data"; + return ppl::common::RC_UNSUPPORTED; + } + } + + PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(output); + PPLNN_LLM_CUDA_DEBUG_TRACE("Output [output]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output); + + PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(scale); + PPLNN_LLM_CUDA_DEBUG_TRACE("Output [scale]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(scale); + + const int64_t batch = input_scale_outer->GetShape()->CalcElementsIncludingPadding(); + const int64_t quant_dim = input_scale_inner->GetShape()->CalcElementsIncludingPadding(); + const int64_t total_elem = input->GetShape()->CalcElementsIncludingPadding(); + if (total_elem != batch * quant_dim) { + LOG(ERROR) << "input.numel must be equal to scale_outer.numel * scale_inner.numel): " + << batch << " * " << quant_dim << " != " << total_elem; + return ppl::common::RC_INVALID_VALUE; + } + + uint64_t dequant_buffer_size = input->GetShape()->CalcElementsIncludingPadding() * sizeof(int16_t); + void *dequant_buffer = nullptr; + + BufferDesc tmp_buffer_desc; + auto status = GetCudaDevice()->AllocTmpBuffer(dequant_buffer_size, &tmp_buffer_desc); + if (status != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "alloc tmp buffer size[" << dequant_buffer_size << "] for kernel[" << GetName() + << "] failed: " << ppl::common::GetRetCodeStr(status); + return status; + } + ppl::common::Destructor __tmp_buffer_guard([this, &tmp_buffer_desc]() -> void { + GetCudaDevice()->FreeTmpBuffer(&tmp_buffer_desc); + }); + dequant_buffer = tmp_buffer_desc.addr; + + auto tensor_layout = GetEngineOptions().cublas_layout_hint == CUBLAS_LAYOUT_AMPERE + ? ppl::kernel::llm::cuda::MATRIX_LAYOUT_COL32 + : ppl::kernel::llm::cuda::MATRIX_LAYOUT_ROW_MAJOR; + + return ppl::kernel::llm::cuda::pmx::i8i8::minmax_requantize_silu_fp16( + GetStream(), + input->GetBufferPtr(), + input_scale_outer->GetBufferPtr(), + input_scale_inner->GetBufferPtr(), + gate ? gate->GetBufferPtr() : nullptr, + gate ? gate_scale_outer->GetBufferPtr() : nullptr, + gate ? gate_scale_inner->GetBufferPtr() : nullptr, + batch, + quant_dim, + tensor_layout, + ppl::kernel::llm::cuda::pmx::i8i8::token_up_scale, + ppl::kernel::llm::cuda::pmx::i8i8::token_down_scale, + ppl::kernel::llm::cuda::pmx::i8i8::hidden_down_scale, + dequant_buffer, + output->GetBufferPtr(), + scale->GetBufferPtr() + ); +} + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_silu_quantize_kernel.h b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_silu_quantize_kernel.h new file mode 100644 index 000000000..1497a9ca0 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_silu_quantize_kernel.h @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_I8I8_ONLINE_DEQUANTIZE_SILU_QUANTIZE_KERNEL_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_I8I8_ONLINE_DEQUANTIZE_SILU_QUANTIZE_KERNEL_H_ + +#include "ppl/nn/engines/llm_cuda/kernel.h" +#include "ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_op.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class I8I8OnlineDequantizeSiLUQuantizeKernel : public LlmCudaKernel { +public: + I8I8OnlineDequantizeSiLUQuantizeKernel(const ir::Node* node) : LlmCudaKernel(node) {} + +private: + ppl::common::RetCode DoExecute(KernelExecContext*) override; +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_swiglu_quantize_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_swiglu_quantize_kernel.cc new file mode 100644 index 000000000..bc5bc0a94 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_swiglu_quantize_kernel.cc @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
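The two files above add the fused dequantize-SiLU-quantize kernel; the SwiGLU variant that follows uses the same numeric recipe. The int32 GEMM output is dequantized with a per-token scale (`scale_outer`) times a per-channel scale (`scale_inner`), the activation is applied in floating point, and the result is re-quantized to int8 with a fresh per-token scale written to the second output. The sketch below is a scalar host-side reference of that recipe for one token; the function name, the gate handling (omitted here) and the clamping details are assumptions for illustration, not the CUDA kernel invoked through minmax_requantize_silu_fp16.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar reference for one token (row) of dequantize -> SiLU -> per-token requantize.
// The real kernel also takes an optional gate input with its own scale pair and fuses
// everything on the GPU; this sketch only shows the arithmetic.
static void DequantSiluRequantRow(const int32_t* x, float scale_outer,
                                  const std::vector<float>& scale_inner,
                                  std::vector<int8_t>* out, float* out_scale) {
    const std::size_t dim = scale_inner.size();
    std::vector<float> y(dim);
    float absmax = 0.f;
    for (std::size_t i = 0; i < dim; ++i) {
        const float xf = x[i] * scale_outer * scale_inner[i]; // dequantize int32 -> float
        y[i] = xf / (1.f + std::exp(-xf));                    // SiLU(x) = x * sigmoid(x)
        absmax = std::max(absmax, std::fabs(y[i]));
    }
    *out_scale = absmax > 0.f ? absmax / 127.f : 1.f;         // per-token absmax ("minmax") scale
    out->resize(dim);
    for (std::size_t i = 0; i < dim; ++i) {
        const float q = std::round(y[i] / *out_scale);
        (*out)[i] = static_cast<int8_t>(std::max(-127.f, std::min(127.f, q)));
    }
}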
+ +#include "online_dequantize_swiglu_quantize_kernel.h" + +#include "ppl/kernel/llm/cuda/pmx/i8i8/quantize.h" + +#include "ppl/common/destructor.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +ppl::common::RetCode I8I8OnlineDequantizeSwiGLUQuantizeKernel::DoExecute(KernelExecContext* ctx) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Entry LlmCudaKernel: [%s]\n", GetName().c_str()); + + PPLNN_LLM_CUDA_REQUIRED_INPUT(input, 0); + PPLNN_LLM_CUDA_REQUIRED_INPUT(input_scale_outer, 1); + PPLNN_LLM_CUDA_REQUIRED_INPUT(input_scale_inner, 2); + + PPLNN_LLM_CUDA_REQUIRED_OUTPUT(output, 0); + PPLNN_LLM_CUDA_REQUIRED_OUTPUT(scale, 1); + + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input_scale_outer]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input_scale_outer); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input_scale_inner]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input_scale_inner); + + PPLNN_LLM_CUDA_DEBUG_TRACE("beta: %f\n", param_->beta); + + PPLNN_LLM_CUDA_RESHAPE_OUTPUTS(); + + if (input_scale_outer->GetShape()->GetDataType() != input_scale_inner->GetShape()->GetDataType()) { + LOG(ERROR) << "datatype of scale_outer must be equal to datatype of scale_inner: " + << ppl::common::GetDataTypeStr(input_scale_outer->GetShape()->GetDataType()) << " vs. " + << ppl::common::GetDataTypeStr(input_scale_inner->GetShape()->GetDataType()); + return ppl::common::RC_INVALID_VALUE; + } + + if (input_scale_outer->GetShape()->GetDataType() != ppl::common::DATATYPE_FLOAT16) { + LOG(ERROR) << "currently only support dequantize to fp16"; + return ppl::common::RC_UNSUPPORTED; + } + + if (input->GetShape()->GetDataType() != ppl::common::DATATYPE_INT32) { + LOG(ERROR) << "currently only support dequantize int32 data"; + return ppl::common::RC_UNSUPPORTED; + } + + PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(output); + PPLNN_LLM_CUDA_DEBUG_TRACE("Output [output]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output); + + PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(scale); + PPLNN_LLM_CUDA_DEBUG_TRACE("Output [scale]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(scale); + + const int64_t batch = input_scale_outer->GetShape()->CalcElementsIncludingPadding(); + const int64_t dequant_dim = input_scale_inner->GetShape()->CalcElementsIncludingPadding(); + const int64_t total_elem = input->GetShape()->CalcElementsIncludingPadding(); + if (total_elem != batch * dequant_dim) { + LOG(ERROR) << "input.numel must be equal to scale_outer.numel * scale_inner.numel): " + << batch << " * " << dequant_dim << " != " << total_elem; + return ppl::common::RC_INVALID_VALUE; + } + + uint64_t dequant_buffer_size = output->GetShape()->CalcElementsIncludingPadding() * sizeof(int16_t); + void *dequant_buffer = nullptr; + + BufferDesc tmp_buffer_desc; + auto status = GetCudaDevice()->AllocTmpBuffer(dequant_buffer_size, &tmp_buffer_desc); + if (status != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "alloc tmp buffer size[" << dequant_buffer_size << "] for kernel[" << GetName() + << "] failed: " << ppl::common::GetRetCodeStr(status); + return status; + } + ppl::common::Destructor __tmp_buffer_guard([this, &tmp_buffer_desc]() -> void { + GetCudaDevice()->FreeTmpBuffer(&tmp_buffer_desc); + }); + dequant_buffer = tmp_buffer_desc.addr; + + auto tensor_layout = GetEngineOptions().cublas_layout_hint == CUBLAS_LAYOUT_AMPERE + ? 
ppl::kernel::llm::cuda::MATRIX_LAYOUT_COL32 + : ppl::kernel::llm::cuda::MATRIX_LAYOUT_ROW_MAJOR; + + const int64_t dim_count = output->GetShape()->GetDimCount(); + const int64_t quant_dim = output->GetShape()->GetDim(dim_count - 1); + return ppl::kernel::llm::cuda::pmx::i8i8::minmax_requantize_swiglu_fp16( + GetStream(), + input->GetBufferPtr(), + input_scale_outer->GetBufferPtr(), + input_scale_inner->GetBufferPtr(), + batch, + quant_dim, + param_->beta, + tensor_layout, + ppl::kernel::llm::cuda::pmx::i8i8::token_up_scale, + ppl::kernel::llm::cuda::pmx::i8i8::token_down_scale, + ppl::kernel::llm::cuda::pmx::i8i8::hidden_down_scale, + dequant_buffer, + output->GetBufferPtr(), + scale->GetBufferPtr() + ); +} + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_swiglu_quantize_kernel.h b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_swiglu_quantize_kernel.h new file mode 100644 index 000000000..56311585d --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_swiglu_quantize_kernel.h @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_I8I8_ONLINE_DEQUANTIZE_SWIGLU_QUANTIZE_KERNEL_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_I8I8_ONLINE_DEQUANTIZE_SWIGLU_QUANTIZE_KERNEL_H_ + +#include "ppl/nn/engines/llm_cuda/kernel.h" +#include "ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_op.h" +#include "ppl/nn/params/pmx/swish_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class I8I8OnlineDequantizeSwiGLUQuantizeKernel : public LlmCudaKernel { +public: + I8I8OnlineDequantizeSwiGLUQuantizeKernel(const ir::Node* node) : LlmCudaKernel(node) {} + + void SetParam(const ppl::nn::pmx::SwishParam* p) { + param_ = p; + } + +private: + ppl::common::RetCode DoExecute(KernelExecContext*) override; + + const ppl::nn::pmx::SwishParam* param_ = nullptr; +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/pixel_unshuffle_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/pixel_unshuffle_kernel.cc new file mode 100644 index 000000000..a09da60a7 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/pixel_unshuffle_kernel.cc @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "pixel_unshuffle_kernel.h" + +#include "ppl/common/destructor.h" + +#include "ppl/kernel/llm/cuda/pmx/pixel_unshuffle.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +ppl::common::RetCode PixelUnshuffleKernel::DoExecute(KernelExecContext* ctx) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Entry LlmCudaKernel: [%s]\n", GetName().c_str()); + + PPLNN_LLM_CUDA_REQUIRED_INPUT(input, 0); + PPLNN_LLM_CUDA_REQUIRED_OUTPUT(output, 0); + + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input); + + PPLNN_LLM_CUDA_DEBUG_TRACE("scale_factor: %d\n", param_->scale_factor); + PPLNN_LLM_CUDA_DEBUG_TRACE("data_layout: %d\n", param_->data_layout); + + PPLNN_LLM_CUDA_RESHAPE_OUTPUTS(); + + PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(output); + PPLNN_LLM_CUDA_DEBUG_TRACE("Output [output]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output); + + auto input_shape = input->GetShape(); + + if (ppl::common::DATATYPE_FLOAT16 != input_shape->GetDataType()) { + LOG(ERROR) << "currently only support fp16"; + return ppl::common::RC_UNSUPPORTED; + } + + if (param_->data_layout != param_->DATA_LAYOUT_NHWC) { + LOG(ERROR) << "currently only support scaling_type == 'nhwc'"; + return ppl::common::RC_UNSUPPORTED; + } + + return ppl::kernel::llm::cuda::pmx::pixel_unshuffle( + GetStream(), + input_shape, + input->GetBufferPtr(), + param_->scale_factor, + output->GetBufferPtr() + ); + +} + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/pixel_unshuffle_kernel.h b/src/ppl/nn/engines/llm_cuda/kernels/opmx/pixel_unshuffle_kernel.h new file mode 100644 index 000000000..a57d16f4b --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/pixel_unshuffle_kernel.h @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
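PixelUnshuffle above is a space-to-depth operation: it folds each scale_factor x scale_factor spatial block into the channel dimension, and this kernel only accepts fp16 tensors in NHWC layout. A rough shape rule is sketched below; the authoritative rule lives in the op's reshape function, which is outside this diff, so treat the helper name and the layout assumption as illustrative.

#include <array>
#include <cstdint>

// NHWC pixel-unshuffle output shape for scale factor r (illustrative helper only).
inline std::array<int64_t, 4> PixelUnshuffleOutShapeNHWC(const std::array<int64_t, 4>& nhwc, int64_t r) {
    return {nhwc[0], nhwc[1] / r, nhwc[2] / r, nhwc[3] * r * r};
}
// e.g. {1, 32, 32, 1280} with r = 2 -> {1, 16, 16, 5120}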
+ +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_PIXEL_UNSHUFFLE_KERNEL_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_PIXEL_UNSHUFFLE_KERNEL_H_ + +#include "ppl/nn/engines/llm_cuda/kernel.h" +#include "ppl/nn/params/opmx/pixel_unshuffle_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class PixelUnshuffleKernel : public LlmCudaKernel { +public: + PixelUnshuffleKernel(const ir::Node* node) : LlmCudaKernel(node) {} + + void SetParam(const ppl::nn::opmx::PixelUnshuffleParam* p) { + param_ = p; + } + +private: + ppl::common::RetCode DoExecute(KernelExecContext*) override; + +private: + const ppl::nn::opmx::PixelUnshuffleParam* param_ = nullptr; +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/tensor_parallel_rms_norm_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/tensor_parallel_rms_norm_kernel.cc new file mode 100644 index 000000000..45330a7a1 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/tensor_parallel_rms_norm_kernel.cc @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
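Both fused i8i8 kernels earlier in this patch validate the same shape contract before launching: the int32 input must factor exactly into tokens x quantized dim, where the token count comes from scale_outer and the per-channel dim from scale_inner. A compile-time illustration of that contract (the helper is hypothetical, the numbers are only an example):

#include <cstdint>

// input.numel() == scale_outer.numel() * scale_inner.numel()
constexpr bool ScalesFactorInput(int64_t input_numel, int64_t outer_numel, int64_t inner_numel) {
    return input_numel == outer_numel * inner_numel;
}
static_assert(ScalesFactorInput(8 * 4096, 8, 4096), "e.g. 8 tokens x 4096 channels");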
+ +#include "tensor_parallel_rms_norm_kernel.h" + +#include "ppl/kernel/llm/cuda/pmx/tensor_parallel_rms_norm.h" +#include "ppl/common/destructor.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + + +ppl::common::RetCode TensorParallelRMSNormKernel::DoExecute(KernelExecContext* ctx) { + PPLNN_LLM_CUDA_DEBUG_TRACE("Entry LlmCudaKernel: [%s]\n", GetName().c_str()); + + PPLNN_LLM_CUDA_REQUIRED_INPUT(input, 0); + PPLNN_LLM_CUDA_REQUIRED_INPUT(weight, 1); + + PPLNN_LLM_CUDA_REQUIRED_OUTPUT(output, 0); + + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [input]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(input); + PPLNN_LLM_CUDA_DEBUG_TRACE("Input [weight]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(weight); + + PPLNN_LLM_CUDA_DEBUG_TRACE("eps: %f\n", param_->eps); + PPLNN_LLM_CUDA_DEBUG_TRACE("axis: %d\n", param_->axis); + PPLNN_LLM_CUDA_DEBUG_TRACE("scale: %d\n", param_->scale); + + PPLNN_LLM_CUDA_RESHAPE_OUTPUTS(); + + auto input_shape = input->GetShape(); + + if (param_->axis != -1 && param_->axis != input_shape->GetDim(input_shape->GetDimCount() - 1)) { + LOG(ERROR) << "currently only support axis == -1 or input's last dim."; + return ppl::common::RC_UNSUPPORTED; + } + + bool can_trans_input = ctx->IsLastConsumerOfInput(0) && input->GetType() == TENSORTYPE_NORMAL; + + auto input_data = input->GetBufferPtr(); + if (can_trans_input) { + output->TransferBufferFrom(input); + } else { + PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(output); + } + PPLNN_LLM_CUDA_DEBUG_TRACE("Output [output]:\n"); + PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output); + + if (ppl::common::DATATYPE_FLOAT16 != input->GetShape()->GetDataType()) { + LOG(ERROR) << "currently only support fp16"; + return ppl::common::RC_UNSUPPORTED; + } + + const int64_t dim_count = input_shape->GetDimCount(); + const int64_t real_axis = param_->axis > 0 ? param_->axis : (param_->axis + dim_count); + + const int64_t batch = input_shape->CalcElementsToDimensionIncludingPadding(real_axis); + const int64_t norm_dim = input_shape->CalcElementsFromDimensionIncludingPadding(real_axis); + + int workspace_size = batch * sizeof(float); + BufferDesc tmpbuffer_desc; + auto status = GetCudaDevice()->AllocTmpBuffer(workspace_size, &tmpbuffer_desc); + if (status != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "alloc tmp buffer size[" << workspace_size << "] for kernel[" << GetName() + << "] failed: " << ppl::common::GetRetCodeStr(status); + return status; + } + ppl::common::Destructor tp_pow_sum_tmpbuffer_guard([this, &tmpbuffer_desc]() -> void { + GetCudaDevice()->FreeTmpBuffer(&tmpbuffer_desc); + }); + void* tp_pow_sum = tmpbuffer_desc.addr; + + auto nccl_param = GetTensorParallelNcclParam(); + + return ppl::kernel::llm::cuda::pmx::tensor_parallel_rms_norm_fp16( + GetStream(), + input_data, + weight->GetBufferPtr(), + param_->eps, + param_->scale, + batch, + norm_dim, + nccl_param, + tp_pow_sum, + output->GetBufferPtr() + ); +} + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/tensor_parallel_rms_norm_kernel.h b/src/ppl/nn/engines/llm_cuda/kernels/opmx/tensor_parallel_rms_norm_kernel.h new file mode 100644 index 000000000..13aa529f8 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/tensor_parallel_rms_norm_kernel.h @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_TENSOR_PARALLEL_RMS_NORM_KERNEL_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_KERNELS_OPMX_TENSOR_PARALLEL_RMS_NORM_KERNEL_H_ + +#include "ppl/nn/engines/llm_cuda/kernel.h" +#include "ppl/nn/params/opmx/tensor_parallel_rms_norm_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class TensorParallelRMSNormKernel : public LlmCudaKernel { +public: + TensorParallelRMSNormKernel(const ir::Node* node) : LlmCudaKernel(node) {} + + void SetParam(const ppl::nn::opmx::TensorParallelRMSNormParam* p) { + param_ = p; + } + +private: + ppl::common::RetCode DoExecute(KernelExecContext*) override; + +private: + const ppl::nn::opmx::TensorParallelRMSNormParam* param_ = nullptr; +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/llm_cuda_device.cc b/src/ppl/nn/engines/llm_cuda/llm_cuda_device.cc index ae3654431..947a777cd 100644 --- a/src/ppl/nn/engines/llm_cuda/llm_cuda_device.cc +++ b/src/ppl/nn/engines/llm_cuda/llm_cuda_device.cc @@ -19,6 +19,7 @@ #include "ppl/common/cuda/cuda_env.h" #include #include +#include #include "ppl/kernel/llm/cuda/pmx/i4f16/gemm.h" @@ -36,6 +37,10 @@ LlmCudaDevice::LlmCudaDevice() { } LlmCudaDevice::~LlmCudaDevice() { + for (auto alibi_slopes_pair : alibi_slopes_map_) { + cudaFree(alibi_slopes_pair.second); + } + if (own_stream_ && stream_) { cudaStreamSynchronize(stream_); cudaStreamDestroy(stream_); @@ -85,6 +90,12 @@ RetCode LlmCudaDevice::Init(int device_id, bool init_cublas_cudnn, NcclParam* te return RC_INTERNAL_ERROR; } + i4f16_gemm_handle_ = ppl::kernel::llm::cuda::pmx::i4f16::create_gemm_handle(); + if (i4f16_gemm_handle_ == nullptr) { + LOG(ERROR) << "pmx::i4f16::create_gemm_handle failed."; + return RC_INTERNAL_ERROR; + } + /* refer to https://developer.nvidia.com/blog/new-cublas-12-0-features-and-matrix-multiplication-performance-on-nvidia-hopper-gpus/ NV said: NVIDIA Hopper architecture workspace requirements @@ -302,6 +313,52 @@ RetCode LlmCudaDevice::Synchronize() { return RC_SUCCESS; } +std::pair LlmCudaDevice::GetAlibiSlopes(int32_t num_heads) { + auto iter = alibi_slopes_map_.find(num_heads); + if (iter != alibi_slopes_map_.end()) + return {RC_SUCCESS, iter->second}; + + std::vector alibi_slopes; + alibi_slopes.reserve(num_heads); + + int32_t closest_power_of_2 = 1 << (int32_t)floor(log2(num_heads)); + for (int32_t i = 1; i < closest_power_of_2 + 1; ++i) { + alibi_slopes.push_back(exp2f(-8.0f * i / closest_power_of_2)); + LOG(INFO) << "alibi_slope of head(" << alibi_slopes.size() - 1 << ") = " << alibi_slopes.back(); + } + if (closest_power_of_2 < num_heads) { + for (int32_t i = 1; i < 2 * (num_heads - closest_power_of_2) + 1; i += 2) { + alibi_slopes.push_back(exp2f(-4.0f * i / closest_power_of_2)); + LOG(INFO) << "alibi_slope of head(" << alibi_slopes.size() - 1 << ") = " << alibi_slopes.back(); + } + } + + float* cu_alibi_slopes 
= nullptr; + auto cu_ret = cudaMallocAsync(&cu_alibi_slopes, num_heads * sizeof(float), GetStream()); + if (cu_ret != cudaSuccess) { + LOG(ERROR) << "cudaMallocAsync failed: " << (int)cu_ret << ", " << cudaGetErrorString(cu_ret); + return {RC_DEVICE_RUNTIME_ERROR, nullptr}; + } + + cu_ret = cudaMemcpyAsync( + cu_alibi_slopes, alibi_slopes.data(), + num_heads * sizeof(float), + cudaMemcpyHostToDevice, GetStream()); + if (cu_ret != cudaSuccess) { + LOG(ERROR) << "cudaMemcpyAsync H2D failed: " << (int)cu_ret << ", " << cudaGetErrorString(cu_ret); + return {RC_DEVICE_RUNTIME_ERROR, nullptr}; + } + + cu_ret = cudaStreamSynchronize(GetStream()); + if (cu_ret != cudaSuccess) { + LOG(ERROR) << "cudaStreamSynchronize failed: " << (int)cu_ret << ", " << cudaGetErrorString(cu_ret); + return {RC_DEVICE_RUNTIME_ERROR, nullptr}; + } + + alibi_slopes_map_.insert(std::make_pair(num_heads, cu_alibi_slopes)); + return {RC_SUCCESS, cu_alibi_slopes}; +} + /* ------------------------------------------------------------------------- */ RetCode LlmCudaDevice::ConfGetDeviceId(LlmCudaDevice* dev, va_list args) { diff --git a/src/ppl/nn/engines/llm_cuda/llm_cuda_device.h b/src/ppl/nn/engines/llm_cuda/llm_cuda_device.h index 70fc021ed..bc7ae0d70 100644 --- a/src/ppl/nn/engines/llm_cuda/llm_cuda_device.h +++ b/src/ppl/nn/engines/llm_cuda/llm_cuda_device.h @@ -34,6 +34,7 @@ typedef void* cudnnHandle_t; #endif #include +#include namespace ppl { namespace nn { namespace llm { namespace cuda { @@ -121,6 +122,8 @@ class LlmCudaDevice : public Device { return i4f16_gemm_handle_; } + std::pair GetAlibiSlopes(int32_t num_heads); + ppl::kernel::llm::cuda::cublas::AlgoCache* GetCublasAlgoCache() { return &cublas_algo_cache_; } @@ -129,6 +132,10 @@ class LlmCudaDevice : public Device { return &cublas_algo_cache_; } + void SetCudnnHandle(cudnnHandle_t cudnn_handle) { + cudnn_handle_ = cudnn_handle; + }; + cudnnHandle_t GetCudnnHandle() const { return cudnn_handle_; } @@ -191,6 +198,9 @@ class LlmCudaDevice : public Device { cudnnHandle_t cudnn_handle_ = nullptr; void* i4f16_gemm_handle_ = nullptr; + + // num_heads -> cuda raw ptr(alloc by cuda malloc async) + std::map alibi_slopes_map_; }; }}}} // namespace ppl::nn::llm::cuda diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/column_parallel_linear_op.cc b/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/column_parallel_linear_op.cc new file mode 100644 index 000000000..c9a8372de --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/column_parallel_linear_op.cc @@ -0,0 +1,63 @@ +#ifdef PPLNN_ENABLE_FP8 + +#include "column_parallel_linear_op.h" + +#include "ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/column_parallel_linear_kernel.h" +#include "ppl/nn/oputils/opmx/reshape_column_parallel_linear.h" +#include "ppl/nn/common/logger.h" + +using namespace std; +using namespace ppl::common; + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +RetCode F8F8ColumnParallelLinearOp::CommonInit() { + infer_type_and_format_func_ = [this](InputOutputInfo* info) -> RetCode { + auto output_shape = info->GetOutput(0)->GetShape(); + output_shape->SetDataFormat(DATAFORMAT_NDARRAY); + output_shape->SetDataType(DATATYPE_FLOAT16); + return RC_SUCCESS; + }; + infer_dims_func_ = [this](InputOutputInfo* info) -> RetCode { + auto status = ppl::nn::opmx::ReshapeColumnParallelLinear(info, param_.get(), nccl_param_->size); + if (status != RC_SUCCESS) { + return status; + } + + auto input_shape = info->GetInput(0)->GetShape(); + auto output_shape = 
info->GetOutput(0)->GetShape(); + + int64_t first_dim = input_shape->GetDim(0); + + if (first_dim % 16 != 0) { + int64_t aligned_dim = ((first_dim + 15) / 16) * 16; + uint16_t padding = aligned_dim - first_dim; + input_shape->SetPadding1(0, padding); + output_shape->SetPadding1(0, padding); + } + + return RC_SUCCESS; + }; + return RC_SUCCESS; +} + + +RetCode F8F8ColumnParallelLinearOp::DoInit(const OptKernelOptions& options) { + auto status = GenericLoadParam(options, ¶m_); + if (status != RC_SUCCESS) { + LOG(ERROR) << "GenericLoadParam failed: " << GetRetCodeStr(status); + return status; + } + + nccl_param_ = options.device->GetTensorParallelNcclParam(); + + return CommonInit(); +} + +KernelImpl* F8F8ColumnParallelLinearOp::CreateKernelImpl() const { + return CreateKernelImplWithParam(param_.get()); +} + +}}}}} // namespace ppl::nn::llm::cuda::pmx + +#endif \ No newline at end of file diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/column_parallel_linear_op.h b/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/column_parallel_linear_op.h new file mode 100644 index 000000000..0a15e0162 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/column_parallel_linear_op.h @@ -0,0 +1,28 @@ +#ifdef PPLNN_ENABLE_FP8 + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_PMX_F8F8_COLUMN_PARALLEL_LINEAR_OP_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_PMX_F8F8_COLUMN_PARALLEL_LINEAR_OP_H_ + +#include "ppl/nn/engines/llm_cuda/opt_kernel.h" +#include "ppl/nn/params/opmx/column_parallel_linear_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class F8F8ColumnParallelLinearOp final : public LlmCudaOptKernel { +public: + F8F8ColumnParallelLinearOp(const ir::Node* node) : LlmCudaOptKernel(node) {} + + KernelImpl* CreateKernelImpl() const override; + ppl::common::RetCode DoInit(const OptKernelOptions&) override; + +private: + ppl::common::RetCode CommonInit(); + + std::shared_ptr param_; + ppl::common::NcclParam *nccl_param_ = nullptr; +}; + +}}}}} // namespace ppl::nn::llm::cuda::pmx + +#endif +#endif diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/online_cast_op.cc b/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/online_cast_op.cc new file mode 100644 index 000000000..bb0b609b4 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/online_cast_op.cc @@ -0,0 +1,43 @@ +#ifdef PPLNN_ENABLE_FP8 + +#include "online_cast_op.h" + +#include "ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/online_cast_kernel.h" +#include "ppl/nn/common/logger.h" + +using namespace std; +using namespace ppl::common; + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +RetCode F8F8OnlineCastOp::CommonInit() { + infer_type_and_format_func_ = [this](InputOutputInfo* info) -> RetCode { + auto output_shape = info->GetOutput(0)->GetShape(); + + output_shape->SetDataFormat(DATAFORMAT_NDARRAY); + output_shape->SetDataType(DATATYPE_FLOAT8E4M3); + + return RC_SUCCESS; + }; + infer_dims_func_ = [this](InputOutputInfo* info) -> RetCode { + auto input_shape = info->GetInput(0)->GetShape(); + auto output_shape = info->GetOutput(0)->GetShape(); + + output_shape->Reshape(input_shape->GetDims(), input_shape->GetDimCount()); + + return RC_SUCCESS; + }; + return RC_SUCCESS; +} + +RetCode F8F8OnlineCastOp::DoInit(const OptKernelOptions& options) { + return CommonInit(); +} + +KernelImpl* F8F8OnlineCastOp::CreateKernelImpl() const { + return CreateKernelImplWithoutParam(); +} + +}}}}} // namespace ppl::nn::llm::cuda::pmx + +#endif \ No newline at end of file diff --git 
a/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/online_cast_op.h b/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/online_cast_op.h new file mode 100644 index 000000000..05d38ccea --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/online_cast_op.h @@ -0,0 +1,24 @@ +#ifdef PPLNN_ENABLE_FP8 + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_PMX_F8F8_ONLINE_CAST_OP_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_PMX_F8F8_ONLINE_CAST_OP_H_ + +#include "ppl/nn/engines/llm_cuda/opt_kernel.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class F8F8OnlineCastOp final : public LlmCudaOptKernel { +public: + F8F8OnlineCastOp(const ir::Node* node) : LlmCudaOptKernel(node) {} + + KernelImpl* CreateKernelImpl() const override; + ppl::common::RetCode DoInit(const OptKernelOptions&) override; + +private: + ppl::common::RetCode CommonInit(); +}; + +}}}}} // namespace ppl::nn::llm::cuda::pmx + +#endif +#endif diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/row_parallel_linear_op.cc b/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/row_parallel_linear_op.cc new file mode 100644 index 000000000..e2feb405b --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/row_parallel_linear_op.cc @@ -0,0 +1,62 @@ +#ifdef PPLNN_ENABLE_FP8 + +#include "row_parallel_linear_op.h" + +#include "ppl/nn/engines/llm_cuda/kernels/opmx/f8f8/row_parallel_linear_kernel.h" +#include "ppl/nn/oputils/opmx/reshape_row_parallel_linear.h" +#include "ppl/nn/common/logger.h" + +using namespace std; +using namespace ppl::common; + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +RetCode F8F8RowParallelLinearOp::CommonInit() { + infer_type_and_format_func_ = [this](InputOutputInfo* info) -> RetCode { + auto output_shape = info->GetOutput(0)->GetShape(); + output_shape->SetDataFormat(DATAFORMAT_NDARRAY); + output_shape->SetDataType(DATATYPE_FLOAT16); + return RC_SUCCESS; + }; + infer_dims_func_ = [this](InputOutputInfo* info) -> RetCode { + auto status = nn::opmx::ReshapeRowParallelLinear(info, param_.get(), nccl_param_->size); + if (status != RC_SUCCESS) { + return status; + } + + auto input_shape = info->GetInput(0)->GetShape(); + auto output_shape = info->GetOutput(0)->GetShape(); + + int64_t first_dim = input_shape->GetDim(0); + + if (first_dim % 16 != 0) { + int64_t aligned_dim = ((first_dim + 15) / 16) * 16; + uint16_t padding = aligned_dim - first_dim; + input_shape->SetPadding1(0, padding); + output_shape->SetPadding1(0, padding); + } + + return RC_SUCCESS; + }; + return RC_SUCCESS; +} + +RetCode F8F8RowParallelLinearOp::DoInit(const OptKernelOptions& options) { + auto status = GenericLoadParam(options, ¶m_); + if (status != RC_SUCCESS) { + LOG(ERROR) << "GenericLoadParam failed: " << GetRetCodeStr(status); + return status; + } + + nccl_param_ = options.device->GetTensorParallelNcclParam(); + + return CommonInit(); +} + +KernelImpl* F8F8RowParallelLinearOp::CreateKernelImpl() const { + return CreateKernelImplWithParam(param_.get()); +} + +}}}}} // namespace ppl::nn::llm::cuda::pmx + +#endif \ No newline at end of file diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/row_parallel_linear_op.h b/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/row_parallel_linear_op.h new file mode 100644 index 000000000..eed5c83bf --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/f8f8/row_parallel_linear_op.h @@ -0,0 +1,28 @@ +#ifdef PPLNN_ENABLE_FP8 + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_PMX_F8F8_ROW_PARALLEL_LINEAR_OP_H_ +#define 
_ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_PMX_F8F8_ROW_PARALLEL_LINEAR_OP_H_ + +#include "ppl/nn/engines/llm_cuda/opt_kernel.h" +#include "ppl/nn/params/opmx/row_parallel_linear_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class F8F8RowParallelLinearOp final : public LlmCudaOptKernel { +public: + F8F8RowParallelLinearOp(const ir::Node* node) : LlmCudaOptKernel(node) {} + + KernelImpl* CreateKernelImpl() const override; + ppl::common::RetCode DoInit(const OptKernelOptions&) override; + +private: + ppl::common::RetCode CommonInit(); + + std::shared_ptr param_; + ppl::common::NcclParam *nccl_param_ = nullptr; +}; + +}}}}} // namespace ppl::nn::llm::cuda::pmx + +#endif +#endif diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/i4f16/column_parallel_linear_op.cc b/src/ppl/nn/engines/llm_cuda/ops/opmx/i4f16/column_parallel_linear_op.cc new file mode 100644 index 000000000..85a27d54d --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/i4f16/column_parallel_linear_op.cc @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
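The GetAlibiSlopes helper added to LlmCudaDevice earlier in this patch computes the standard ALiBi head slopes on the host, uploads them once per head count, and caches the device pointer. A standalone reproduction of the slope formula, with a worked example in the trailing comment (values written as powers of two):

#include <cmath>
#include <cstdint>
#include <vector>

// Standard ALiBi slopes, matching the loops in LlmCudaDevice::GetAlibiSlopes.
std::vector<float> AlibiSlopes(int32_t num_heads) {
    std::vector<float> slopes;
    slopes.reserve(num_heads);
    const int32_t cp2 = 1 << (int32_t)std::floor(std::log2((float)num_heads));
    for (int32_t i = 1; i <= cp2; ++i)
        slopes.push_back(std::exp2(-8.0f * i / cp2));
    if (cp2 < num_heads) {
        for (int32_t i = 1; i < 2 * (num_heads - cp2) + 1; i += 2)
            slopes.push_back(std::exp2(-4.0f * i / cp2));
    }
    return slopes;
}
// num_heads = 12: the closest power of two is 8, so the first 8 slopes are
// 2^-1 ... 2^-8, and the remaining 4 heads get 2^-0.5, 2^-1.5, 2^-2.5, 2^-3.5.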
+ +#include "column_parallel_linear_op.h" + +#include "ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/column_parallel_linear_kernel.h" +#include "ppl/nn/oputils/opmx/reshape_column_parallel_linear.h" +#include "ppl/nn/common/logger.h" + +#ifdef PPLNN_ENABLE_PMX_MODEL +#include "ppl/nn/engines/llm_cuda/engine.h" +#include "ppl/nn/models/pmx/utils.h" +#include "ppl/nn/engines/llm_cuda/pmx/generated/llm_cuda_op_params_generated.h" +#endif + +using namespace std; +using namespace ppl::common; + + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +RetCode I4F16ColumnParallelLinearOp::CommonInit() { + infer_type_and_format_func_ = GenericInferTypeAndFormat; + infer_dims_func_ = [this](InputOutputInfo* info) -> RetCode { + return nn::opmx::ReshapeColumnParallelLinear(info, param_.get(), nccl_param_->size, 1, 4); + }; + return RC_SUCCESS; +} + +RetCode I4F16ColumnParallelLinearOp::DoInit(const OptKernelOptions& options) { + auto status = GenericLoadParam(options, ¶m_); + if (status != RC_SUCCESS) { + LOG(ERROR) << "GenericLoadParam failed: " << GetRetCodeStr(status); + return status; + } + + nccl_param_ = options.device->GetTensorParallelNcclParam(); + + return CommonInit(); +} + +KernelImpl* I4F16ColumnParallelLinearOp::CreateKernelImpl() const { + return CreateKernelImplWithParam(param_.get()); +} + +#ifdef PPLNN_ENABLE_PMX_MODEL +RetCode I4F16ColumnParallelLinearOp::SerializeData(const ppl::nn::pmx::SerializationContext&, utils::DataStream* ds) const { + flatbuffers::FlatBufferBuilder builder; + auto fb_param = opmx::CreateColumnParallelLinearParam(builder, + param_.get()->in_features, + param_.get()->out_features, + param_.get()->bias_term, + param_.get()->gather_output); + auto fb_op_param = opmx::CreateOpParam(builder, opmx::OpParamType_ColumnParallelLinearParam, fb_param.Union()); + opmx::FinishOpParamBuffer(builder, fb_op_param); + return ds->Write(builder.GetBufferPointer(), builder.GetSize()); +} + +RetCode I4F16ColumnParallelLinearOp::DeserializeData(const ppl::nn::pmx::DeserializationContext& ctx, const void* base, uint64_t size) { + auto fb_op_param = opmx::GetOpParam(base); + auto fb_param = fb_op_param->value_as_ColumnParallelLinearParam(); + param_ = make_shared(); + param_.get()->in_features = fb_param->in_features(); + param_.get()->out_features = fb_param->out_features(); + param_.get()->bias_term = fb_param->bias_term(); + param_.get()->gather_output = fb_param->gather_output(); + + nccl_param_ = dynamic_cast(ctx.engine)->GetTensorParallelNcclParam(); + + return CommonInit(); +} +#endif + + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/i4f16/column_parallel_linear_op.h b/src/ppl/nn/engines/llm_cuda/ops/opmx/i4f16/column_parallel_linear_op.h new file mode 100644 index 000000000..a4fc073c6 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/i4f16/column_parallel_linear_op.h @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_I4F16_COLUMN_PARALLEL_LINEAR_OP_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_I4F16_COLUMN_PARALLEL_LINEAR_OP_H_ + +#include "ppl/nn/engines/llm_cuda/opt_kernel.h" +#include "ppl/nn/params/opmx/column_parallel_linear_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class I4F16ColumnParallelLinearOp final : public LlmCudaOptKernel { +public: + I4F16ColumnParallelLinearOp(const ir::Node* node) : LlmCudaOptKernel(node) {} + + KernelImpl* CreateKernelImpl() const override; + ppl::common::RetCode DoInit(const OptKernelOptions&) override; + +#ifdef PPLNN_ENABLE_PMX_MODEL + ppl::common::RetCode SerializeData(const ppl::nn::pmx::SerializationContext&, utils::DataStream*) const override; + ppl::common::RetCode DeserializeData(const ppl::nn::pmx::DeserializationContext&, const void*, uint64_t) override; +#endif + +private: + ppl::common::RetCode CommonInit(); + + std::shared_ptr param_; + ppl::common::NcclParam *nccl_param_ = nullptr; +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/i4f16/row_parallel_linear_op.cc b/src/ppl/nn/engines/llm_cuda/ops/opmx/i4f16/row_parallel_linear_op.cc new file mode 100644 index 000000000..8fdfb2757 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/i4f16/row_parallel_linear_op.cc @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
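The FP8 (f8f8) parallel-linear ops earlier in this patch pad the leading (token) dimension of both the input and output shapes up to a multiple of 16 when it is not already aligned, presumably to satisfy the FP8 GEMM's alignment requirements. The round-up arithmetic is just:

#include <cstdint>

// Pad the leading dim up to the next multiple of 16, as done in the F8F8*ParallelLinearOp
// shape-inference lambdas; 'padding' is what gets recorded via SetPadding1(0, padding).
inline int64_t AlignedDim16(int64_t first_dim) { return ((first_dim + 15) / 16) * 16; }
inline int64_t PaddingTo16(int64_t first_dim)  { return AlignedDim16(first_dim) - first_dim; }
// e.g. first_dim = 50 -> aligned_dim = 64, padding = 14; first_dim = 64 -> padding = 0.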
+ +#include "row_parallel_linear_op.h" + +#include "ppl/nn/engines/llm_cuda/kernels/opmx/i4f16/row_parallel_linear_kernel.h" +#include "ppl/nn/oputils/opmx/reshape_row_parallel_linear.h" +#include "ppl/nn/common/logger.h" + +#ifdef PPLNN_ENABLE_PMX_MODEL +#include "ppl/nn/engines/llm_cuda/engine.h" +#include "ppl/nn/models/pmx/utils.h" +#include "ppl/nn/engines/llm_cuda/pmx/generated/llm_cuda_op_params_generated.h" +#endif + +using namespace std; +using namespace ppl::common; + + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +RetCode I4F16RowParallelLinearOp::CommonInit() { + infer_type_and_format_func_ = GenericInferTypeAndFormat; + infer_dims_func_ = [this](InputOutputInfo* info) -> RetCode { + return nn::opmx::ReshapeRowParallelLinear(info, param_.get(), nccl_param_->size, 1, 4); + }; + + return RC_SUCCESS; +} + +RetCode I4F16RowParallelLinearOp::DoInit(const OptKernelOptions& options) { + auto status = GenericLoadParam(options, ¶m_); + if (status != RC_SUCCESS) { + LOG(ERROR) << "GenericLoadParam failed: " << GetRetCodeStr(status); + return status; + } + + nccl_param_ = options.device->GetTensorParallelNcclParam(); + + return CommonInit(); +} + +KernelImpl* I4F16RowParallelLinearOp::CreateKernelImpl() const { + return CreateKernelImplWithParam(param_.get()); +} + +#ifdef PPLNN_ENABLE_PMX_MODEL +RetCode I4F16RowParallelLinearOp::SerializeData(const ppl::nn::pmx::SerializationContext& ctx, utils::DataStream* ds) const { + flatbuffers::FlatBufferBuilder builder; + auto fb_param = opmx::CreateRowParallelLinearParam(builder, + param_.get()->in_features, + param_.get()->out_features, + param_.get()->bias_term, + param_.get()->input_is_parallel); + auto fb_op_param = opmx::CreateOpParam(builder, opmx::OpParamType_RowParallelLinearParam, fb_param.Union()); + opmx::FinishOpParamBuffer(builder, fb_op_param); + return ds->Write(builder.GetBufferPointer(), builder.GetSize()); +} + +RetCode I4F16RowParallelLinearOp::DeserializeData(const ppl::nn::pmx::DeserializationContext& ctx, const void* base, uint64_t size) { + auto fb_op_param = opmx::GetOpParam(base); + auto fb_param = fb_op_param->value_as_RowParallelLinearParam(); + param_ = make_shared(); + param_.get()->in_features = fb_param->in_features(); + param_.get()->out_features = fb_param->out_features(); + param_.get()->bias_term = fb_param->bias_term(); + param_.get()->input_is_parallel = fb_param->input_is_parallel(); + + nccl_param_ = dynamic_cast(ctx.engine)->GetTensorParallelNcclParam(); + + return CommonInit(); +} +#endif + + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/i4f16/row_parallel_linear_op.h b/src/ppl/nn/engines/llm_cuda/ops/opmx/i4f16/row_parallel_linear_op.h new file mode 100644 index 000000000..7102530e0 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/i4f16/row_parallel_linear_op.h @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_I4F16_ROW_PARALLEL_LINEAR_OP_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_I4F16_ROW_PARALLEL_LINEAR_OP_H_ + +#include "ppl/nn/engines/llm_cuda/opt_kernel.h" +#include "ppl/nn/params/opmx/row_parallel_linear_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class I4F16RowParallelLinearOp final : public LlmCudaOptKernel { +public: + I4F16RowParallelLinearOp(const ir::Node* node) : LlmCudaOptKernel(node) {} + + KernelImpl* CreateKernelImpl() const override; + ppl::common::RetCode DoInit(const OptKernelOptions&) override; + +#ifdef PPLNN_ENABLE_PMX_MODEL + ppl::common::RetCode SerializeData(const ppl::nn::pmx::SerializationContext&, utils::DataStream*) const override; + ppl::common::RetCode DeserializeData(const ppl::nn::pmx::DeserializationContext&, const void*, uint64_t) override; +#endif + +private: + ppl::common::RetCode CommonInit(); + + std::shared_ptr param_; + ppl::common::NcclParam *nccl_param_ = nullptr; +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_reshape_split_op.cc b/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_reshape_split_op.cc new file mode 100644 index 000000000..a6b37ad76 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_reshape_split_op.cc @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
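The i4f16 (and f8f8) ColumnParallelLinear / RowParallelLinear ops above all divide the weight across the tensor-parallel group via nccl_param_->size. In the usual Megatron-style scheme these reshape helpers implement, a column-parallel layer shards out_features across ranks (optionally gathering the outputs), while a row-parallel layer shards in_features and sums the partial results with an all-reduce. A per-rank shard-size sketch, assuming even divisibility; the helper names are hypothetical:

#include <cstdint>

// Per-rank feature counts under tensor parallelism of size world_size (illustrative only).
inline int64_t ColumnParallelOutFeatures(int64_t out_features, int64_t world_size) {
    return out_features / world_size;   // each rank produces a slice of the output features
}
inline int64_t RowParallelInFeatures(int64_t in_features, int64_t world_size) {
    return in_features / world_size;    // each rank consumes a slice of the input features
}
// e.g. out_features = 11008, world_size = 2 -> 5504 output columns per rank.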
+ +#include "online_dequantize_reshape_split_op.h" + +#include "ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_reshape_split_kernel.h" +#include "ppl/nn/oputils/onnx/reshape_split.h" +#include "ppl/nn/common/logger.h" + +#ifdef PPLNN_ENABLE_PMX_MODEL +#include "ppl/nn/engines/llm_cuda/engine.h" +#include "ppl/nn/models/pmx/utils.h" +#include "ppl/nn/engines/llm_cuda/pmx/generated/llm_cuda_op_i8i8_params_generated.h" +#endif + +using namespace std; +using namespace ppl::common; + + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +RetCode I8I8OnlineDequantizeReshapeSplitOp::CommonInit() { + infer_type_and_format_func_ = [this](InputOutputInfo* info) -> RetCode { + auto scale_outer_shape = info->GetInput(2)->GetShape(); + + for (uint32_t i = 0; i < info->GetOutputCount(); ++i) { + auto output_shape = info->GetOutput(i)->GetShape(); + output_shape->SetDataFormat(DATAFORMAT_NDARRAY); + output_shape->SetDataType(scale_outer_shape->GetDataType()); + } + + return RC_SUCCESS; + }; + infer_dims_func_ = [this](InputOutputInfo* info) -> RetCode { + std::vector output_dims(param_.shape.begin(), param_.shape.end()); + auto input_shape = info->GetInput(0)->GetShape(); + + const int32_t axis = param_.split_param->axis < 0 + ? param_.split_param->axis + param_.shape.size() + : param_.split_param->axis; + + // fill zeros + for (int32_t i = 0; i < axis; ++i) { + output_dims[i] = input_shape->GetDim(i); + } + // set split dim + for (uint32_t i = 0; i < info->GetOutputCount(); ++i) { + output_dims[axis] = param_.split[i]; + auto output_shape = info->GetOutput(i)->GetShape(); + output_shape->Reshape(output_dims); + } + + return ppl::common::RC_SUCCESS; + }; + + return ppl::common::RC_SUCCESS; +} + +RetCode I8I8OnlineDequantizeReshapeSplitOp::DoInit(const OptKernelOptions& options) { + auto status = GenericLoadParam(options, ¶m_.split_param); + if (status != RC_SUCCESS) { + LOG(ERROR) << "GenericLoadParam failed: " << GetRetCodeStr(status); + return status; + } + + return CommonInit(); +} + +KernelImpl* I8I8OnlineDequantizeReshapeSplitOp::CreateKernelImpl() const { + return CreateKernelImplWithParam(¶m_); +} + +#ifdef PPLNN_ENABLE_PMX_MODEL +ppl::common::RetCode I8I8OnlineDequantizeReshapeSplitOp::SerializeData(const ppl::nn::pmx::SerializationContext&, utils::DataStream* ds) const { + flatbuffers::FlatBufferBuilder builder; + auto fb_split_point = builder.CreateVector(param_.split_param.get()->split_point); + auto fb_split = builder.CreateVector(param_.split); + auto fb_shape = builder.CreateVector(param_.shape); + auto fb_param = opmx::i8i8::CreateOnlineDequantizeReshapeSplitParam(builder, + param_.split_param.get()->axis, + fb_split_point, + fb_split, + fb_shape, + param_.bias_term); + auto fb_op_param = opmx::i8i8::CreateOpParam(builder, opmx::i8i8::OpParamType_OnlineDequantizeReshapeSplitParam, fb_param.Union()); + opmx::i8i8::FinishOpParamBuffer(builder, fb_op_param); + return ds->Write(builder.GetBufferPointer(), builder.GetSize()); +} + +ppl::common::RetCode I8I8OnlineDequantizeReshapeSplitOp::DeserializeData(const ppl::nn::pmx::DeserializationContext& ctx, const void* base, uint64_t size) { + auto fb_op_param = opmx::i8i8::GetOpParam(base); + auto fb_param = fb_op_param->value_as_OnlineDequantizeReshapeSplitParam(); + param_.split_param = make_shared(); + param_.split_param.get()->axis = fb_param->axis(); + param_.bias_term = fb_param->bias_term(); + ppl::nn::opmx::utils::Fbvec2Stdvec(fb_param->split_point(), &(param_.split_param.get()->split_point)); + 
ppl::nn::opmx::utils::Fbvec2Stdvec(fb_param->split(), &(param_.split)); + ppl::nn::opmx::utils::Fbvec2Stdvec(fb_param->shape(), &(param_.shape)); + + return CommonInit(); +} +#endif + + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_reshape_split_op.h b/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_reshape_split_op.h new file mode 100644 index 000000000..5b0872417 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_reshape_split_op.h @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_I8I8_ONLINE_DEQUANTIZE_RESHAPE_SPLIT_OP_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_I8I8_ONLINE_DEQUANTIZE_RESHAPE_SPLIT_OP_H_ + +#include "online_dequantize_op.h" +#include "ppl/nn/params/onnx/split_param.h" +#include "ppl/nn/engines/llm_cuda/opt_kernel.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class I8I8OnlineDequantizeReshapeSplitOp final : public LlmCudaOptKernel { +public: + I8I8OnlineDequantizeReshapeSplitOp(const ir::Node* node) : LlmCudaOptKernel(node) {} + + struct Param { + std::shared_ptr split_param; + std::vector split; + std::vector shape; + bool bias_term = false; + }; + + KernelImpl* CreateKernelImpl() const override; + ppl::common::RetCode DoInit(const OptKernelOptions&) override; + +#ifdef PPLNN_ENABLE_PMX_MODEL + ppl::common::RetCode SerializeData(const ppl::nn::pmx::SerializationContext&, utils::DataStream*) const override; + ppl::common::RetCode DeserializeData(const ppl::nn::pmx::DeserializationContext&, const void*, uint64_t) override; +#endif + + Param* GetParam() { return ¶m_; } + const Param* GetParam() const { return ¶m_; } + +private: + ppl::common::RetCode CommonInit(); + + Param param_; +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_silu_quantize_op.cc b/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_silu_quantize_op.cc new file mode 100644 index 000000000..a5421d4b1 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_silu_quantize_op.cc @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "online_dequantize_silu_quantize_op.h" + +#include "ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_silu_quantize_kernel.h" +#include "ppl/nn/common/logger.h" + +#ifdef PPLNN_ENABLE_PMX_MODEL +#include "ppl/nn/engines/llm_cuda/engine.h" +#include "ppl/nn/models/pmx/utils.h" +#include "ppl/nn/engines/llm_cuda/pmx/generated/llm_cuda_op_i8i8_params_generated.h" +#endif + +using namespace std; +using namespace ppl::common; + + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +RetCode I8I8OnlineDequantizeSiLUQuantizeOp::CommonInit() { + infer_type_and_format_func_ = [this](InputOutputInfo* info) -> RetCode { + auto input_scale_outer_shape = info->GetInput(1)->GetShape(); + auto output_shape = info->GetOutput(0)->GetShape(); + auto scale_shape = info->GetOutput(1)->GetShape(); + + output_shape->SetDataFormat(DATAFORMAT_NDARRAY); + output_shape->SetDataType(DATATYPE_INT8); + + scale_shape->SetDataFormat(DATAFORMAT_NDARRAY); + scale_shape->SetDataType(input_scale_outer_shape->GetDataType()); + + return RC_SUCCESS; + }; + infer_dims_func_ = [this](InputOutputInfo* info) -> RetCode { + auto input_shape = info->GetInput(0)->GetShape(); + auto output_shape = info->GetOutput(0)->GetShape(); + auto scale_shape = info->GetOutput(1)->GetShape(); + + output_shape->Reshape(input_shape->GetDims(), input_shape->GetDimCount()); + scale_shape->Reshape(input_shape->GetDims(), input_shape->GetDimCount() - 1); + + return RC_SUCCESS; + }; + + return RC_SUCCESS; +} + +RetCode I8I8OnlineDequantizeSiLUQuantizeOp::DoInit(const OptKernelOptions& options) { + return CommonInit(); +} + +KernelImpl* I8I8OnlineDequantizeSiLUQuantizeOp::CreateKernelImpl() const { + return CreateKernelImplWithoutParam(); +} + +#ifdef PPLNN_ENABLE_PMX_MODEL +ppl::common::RetCode I8I8OnlineDequantizeSiLUQuantizeOp::SerializeData(const ppl::nn::pmx::SerializationContext&, utils::DataStream* ds) const { + return RC_SUCCESS; +} + +ppl::common::RetCode I8I8OnlineDequantizeSiLUQuantizeOp::DeserializeData(const ppl::nn::pmx::DeserializationContext& ctx, const void* base, uint64_t size) { + return CommonInit(); +} +#endif + + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_silu_quantize_op.h b/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_silu_quantize_op.h new file mode 100644 index 000000000..8f73daa5d --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_silu_quantize_op.h @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_I8I8_ONLINE_DEQUANTIZE_SILU_QUANTIZE_OP_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_I8I8_ONLINE_DEQUANTIZE_SILU_QUANTIZE_OP_H_ + +#include "ppl/nn/engines/llm_cuda/opt_kernel.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class I8I8OnlineDequantizeSiLUQuantizeOp final : public LlmCudaOptKernel { +public: + I8I8OnlineDequantizeSiLUQuantizeOp(const ir::Node* node) : LlmCudaOptKernel(node) {} + + KernelImpl* CreateKernelImpl() const override; + ppl::common::RetCode DoInit(const OptKernelOptions&) override; + +#ifdef PPLNN_ENABLE_PMX_MODEL + ppl::common::RetCode SerializeData(const ppl::nn::pmx::SerializationContext&, utils::DataStream*) const override; + ppl::common::RetCode DeserializeData(const ppl::nn::pmx::DeserializationContext&, const void*, uint64_t) override; +#endif + +private: + ppl::common::RetCode CommonInit(); + +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_swiglu_quantize_op.cc b/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_swiglu_quantize_op.cc new file mode 100644 index 000000000..691665791 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_swiglu_quantize_op.cc @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
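+
+// I8I8OnlineDequantizeSwiGLUQuantize fuses dequantize -> SwiGLU -> quantize into one op:
+// the int8 GEMM output is dequantized with its outer/inner scales, gated with SwiGLU
+// (beta taken from the shared SwishParam), and re-quantized to int8 together with a fresh
+// scale tensor. Output dims come from ReshapeGLU (GLU-style, so the last dimension is
+// expected to be halved), the emitted scale drops the last input dimension, and PMX
+// serialization only needs to persist beta; everything else is rebuilt in CommonInit().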
+ +#include "online_dequantize_swiglu_quantize_op.h" + +#include "ppl/nn/engines/llm_cuda/kernels/opmx/i8i8/online_dequantize_swiglu_quantize_kernel.h" +#include "ppl/nn/oputils/opmx/reshape_glu.h" +#include "ppl/nn/common/logger.h" + +#ifdef PPLNN_ENABLE_PMX_MODEL +#include "ppl/nn/models/pmx/utils.h" +#include "ppl/nn/engines/llm_cuda/pmx/generated/llm_cuda_op_params_generated.h" +#endif + +using namespace std; +using namespace ppl::common; + + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +RetCode I8I8OnlineDequantizeSwiGLUQuantizeOp::CommonInit() { + + infer_type_and_format_func_ = [this](InputOutputInfo* info) -> RetCode { + auto input_scale_outer_shape = info->GetInput(1)->GetShape(); + auto output_shape = info->GetOutput(0)->GetShape(); + auto scale_shape = info->GetOutput(1)->GetShape(); + + output_shape->SetDataFormat(DATAFORMAT_NDARRAY); + output_shape->SetDataType(DATATYPE_INT8); + + scale_shape->SetDataFormat(DATAFORMAT_NDARRAY); + scale_shape->SetDataType(input_scale_outer_shape->GetDataType()); + + return RC_SUCCESS; + }; + infer_dims_func_ = [this](InputOutputInfo* info) -> RetCode { + auto input_shape = info->GetInput(0)->GetShape(); + auto scale_shape = info->GetOutput(1)->GetShape(); + + scale_shape->Reshape(input_shape->GetDims(), input_shape->GetDimCount() - 1); + + return ppl::nn::opmx::ReshapeGLU(info); + }; + + return RC_SUCCESS; +} + +RetCode I8I8OnlineDequantizeSwiGLUQuantizeOp::DoInit(const OptKernelOptions& options) { + auto status = GenericLoadParam(options, ¶m_); + if (status != RC_SUCCESS) { + LOG(ERROR) << "GenericLoadParam failed: " << GetRetCodeStr(status); + return status; + } + + return CommonInit(); +} + +KernelImpl* I8I8OnlineDequantizeSwiGLUQuantizeOp::CreateKernelImpl() const { + return CreateKernelImplWithParam(param_.get()); +} + +#ifdef PPLNN_ENABLE_PMX_MODEL +ppl::common::RetCode I8I8OnlineDequantizeSwiGLUQuantizeOp::SerializeData(const ppl::nn::pmx::SerializationContext&, utils::DataStream* ds) const { + flatbuffers::FlatBufferBuilder builder; + auto fb_param = opmx::CreateSwishParam(builder, param_.get()->beta); + auto fb_op_param = opmx::CreateOpParam(builder, opmx::OpParamType_SwishParam, fb_param.Union()); + opmx::FinishOpParamBuffer(builder, fb_op_param); + return ds->Write(builder.GetBufferPointer(), builder.GetSize()); +} + +ppl::common::RetCode I8I8OnlineDequantizeSwiGLUQuantizeOp::DeserializeData(const ppl::nn::pmx::DeserializationContext& ctx, const void* base, uint64_t size) { + auto fb_op_param = opmx::GetOpParam(base); + auto fb_param = fb_op_param->value_as_SwishParam(); + param_ = make_shared(); + param_.get()->beta = fb_param->beta(); + + return CommonInit(); +} +#endif + + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_swiglu_quantize_op.h b/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_swiglu_quantize_op.h new file mode 100644 index 000000000..a36e3f2c7 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_swiglu_quantize_op.h @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_I8I8_ONLINE_DEQUANTIZE_SWIGLU_QUANTIZE_OP_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_I8I8_ONLINE_DEQUANTIZE_SWIGLU_QUANTIZE_OP_H_ + +#include "ppl/nn/engines/llm_cuda/opt_kernel.h" +#include "ppl/nn/params/pmx/swish_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class I8I8OnlineDequantizeSwiGLUQuantizeOp final : public LlmCudaOptKernel { +public: + I8I8OnlineDequantizeSwiGLUQuantizeOp(const ir::Node* node) : LlmCudaOptKernel(node) {} + + KernelImpl* CreateKernelImpl() const override; + ppl::common::RetCode DoInit(const OptKernelOptions&) override; + +#ifdef PPLNN_ENABLE_PMX_MODEL + ppl::common::RetCode SerializeData(const ppl::nn::pmx::SerializationContext&, utils::DataStream*) const override; + ppl::common::RetCode DeserializeData(const ppl::nn::pmx::DeserializationContext&, const void*, uint64_t) override; +#endif + +private: + ppl::common::RetCode CommonInit(); + + std::shared_ptr param_; +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/pixel_unshuffle_op.cc b/src/ppl/nn/engines/llm_cuda/ops/opmx/pixel_unshuffle_op.cc new file mode 100644 index 000000000..c46202061 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/pixel_unshuffle_op.cc @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
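+
+// PixelUnshuffleOp wraps the opmx PixelUnshuffle kernel for the LLM CUDA engine.
+// Type/format inference is generic; output dims are delegated to ReshapePixelUnshuffle
+// using the op's param (scale_factor, data_layout), i.e. the usual space-to-channel
+// rearrangement that is the inverse of pixel shuffle. For PMX models both fields are
+// round-tripped through the PixelUnshuffleParam flatbuffer table.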
+ +#include "pixel_unshuffle_op.h" + +#include "ppl/nn/engines/llm_cuda/kernels/opmx/pixel_unshuffle_kernel.h" +#include "ppl/nn/oputils/opmx/reshape_pixel_unshuffle.h" +#include "ppl/nn/common/logger.h" + +#ifdef PPLNN_ENABLE_PMX_MODEL +#include "ppl/nn/models/pmx/utils.h" +#include "ppl/nn/engines/llm_cuda/pmx/generated/llm_cuda_op_params_generated.h" +#endif + +using namespace std; +using namespace ppl::common; + + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +RetCode PixelUnshuffleOp::CommonInit() { + infer_type_and_format_func_ = GenericInferTypeAndFormat; + infer_dims_func_ = [this](InputOutputInfo* info) -> RetCode { + return nn::opmx::ReshapePixelUnshuffle(info, param_.get()); + }; + return RC_SUCCESS; +} + +RetCode PixelUnshuffleOp::DoInit(const OptKernelOptions& options) { + auto status = GenericLoadParam(options, ¶m_); + if (status != RC_SUCCESS) { + LOG(ERROR) << "GenericLoadParam failed: " << GetRetCodeStr(status); + return status; + } + + return CommonInit(); +} + +KernelImpl* PixelUnshuffleOp::CreateKernelImpl() const { + return CreateKernelImplWithParam(param_.get()); +} + +#ifdef PPLNN_ENABLE_PMX_MODEL +ppl::common::RetCode PixelUnshuffleOp::SerializeData(const ppl::nn::pmx::SerializationContext& ctx, utils::DataStream* ds) const { + flatbuffers::FlatBufferBuilder builder; + auto fb_param = opmx::CreatePixelUnshuffleParam(builder, + param_.get()->scale_factor, + param_.get()->data_layout); + auto fb_op_param = opmx::CreateOpParam(builder, opmx::OpParamType_PixelUnshuffleParam, fb_param.Union()); + opmx::FinishOpParamBuffer(builder, fb_op_param); + return ds->Write(builder.GetBufferPointer(), builder.GetSize()); +} + +ppl::common::RetCode PixelUnshuffleOp::DeserializeData(const ppl::nn::pmx::DeserializationContext& ctx, const void* base, uint64_t size) { + auto fb_op_param = opmx::GetOpParam(base); + auto fb_param = fb_op_param->value_as_PixelUnshuffleParam(); + param_ = make_shared(); + param_.get()->scale_factor = fb_param->scale_factor(); + param_.get()->data_layout = fb_param->data_layout(); + + return CommonInit(); +} +#endif + + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/pixel_unshuffle_op.h b/src/ppl/nn/engines/llm_cuda/ops/opmx/pixel_unshuffle_op.h new file mode 100644 index 000000000..6874f4bc4 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/pixel_unshuffle_op.h @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_PIXEL_UNSHUFFLE_OP_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_PIXEL_UNSHUFFLE_OP_H_ + +#include "ppl/nn/engines/llm_cuda/opt_kernel.h" +#include "ppl/nn/params/opmx/pixel_unshuffle_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class PixelUnshuffleOp final : public LlmCudaOptKernel { +public: + PixelUnshuffleOp(const ir::Node* node) : LlmCudaOptKernel(node) {} + + KernelImpl* CreateKernelImpl() const override; + ppl::common::RetCode DoInit(const OptKernelOptions&) override; + +#ifdef PPLNN_ENABLE_PMX_MODEL + ppl::common::RetCode SerializeData(const ppl::nn::pmx::SerializationContext&, utils::DataStream*) const override; + ppl::common::RetCode DeserializeData(const ppl::nn::pmx::DeserializationContext&, const void*, uint64_t) override; +#endif + +private: + ppl::common::RetCode CommonInit(); + + std::shared_ptr param_; +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/tensor_parallel_rms_norm_op.cc b/src/ppl/nn/engines/llm_cuda/ops/opmx/tensor_parallel_rms_norm_op.cc new file mode 100644 index 000000000..dd846d606 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/tensor_parallel_rms_norm_op.cc @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
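+
+// TensorParallelRMSNormOp: RMS normalization applied to tensor-parallel shards.
+// Type/format and dims inference use the engine's generic helpers, so the op only has to
+// carry its parameters (axis, eps, scale), which PMX serialization round-trips through
+// the TensorParallelRMSNormParam flatbuffer table.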
+ +#include "tensor_parallel_rms_norm_op.h" + +#include "ppl/nn/engines/llm_cuda/kernels/opmx/tensor_parallel_rms_norm_kernel.h" +#include "ppl/nn/common/logger.h" + +#ifdef PPLNN_ENABLE_PMX_MODEL +#include "ppl/nn/models/pmx/utils.h" +#include "ppl/nn/engines/llm_cuda/pmx/generated/llm_cuda_op_params_generated.h" +#endif + +using namespace std; +using namespace ppl::common; + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +RetCode TensorParallelRMSNormOp::CommonInit() { + infer_type_and_format_func_ = GenericInferTypeAndFormat; + infer_dims_func_ = GenericInferDims; + return RC_SUCCESS; +} + +RetCode TensorParallelRMSNormOp::DoInit(const OptKernelOptions& options) { + auto status = GenericLoadParam(options, ¶m_); + if (status != RC_SUCCESS) { + LOG(ERROR) << "GenericLoadParam failed: " << GetRetCodeStr(status); + return status; + } + + return CommonInit(); +} + +KernelImpl* TensorParallelRMSNormOp::CreateKernelImpl() const { + return CreateKernelImplWithParam(param_.get()); +} + +#ifdef PPLNN_ENABLE_PMX_MODEL +ppl::common::RetCode TensorParallelRMSNormOp::SerializeData(const ppl::nn::pmx::SerializationContext& ctx, utils::DataStream* ds) const { + flatbuffers::FlatBufferBuilder builder; + auto fb_param = opmx::CreateTensorParallelRMSNormParam(builder, + param_.get()->axis, + param_.get()->eps, + param_.get()->scale); + auto fb_op_param = opmx::CreateOpParam(builder, opmx::OpParamType_TensorParallelRMSNormParam, fb_param.Union()); + opmx::FinishOpParamBuffer(builder, fb_op_param); + return ds->Write(builder.GetBufferPointer(), builder.GetSize()); +} + +ppl::common::RetCode TensorParallelRMSNormOp::DeserializeData(const ppl::nn::pmx::DeserializationContext& ctx, const void* base, uint64_t size) { + auto fb_op_param = opmx::GetOpParam(base); + auto fb_param = fb_op_param->value_as_TensorParallelRMSNormParam(); + param_ = make_shared(); + param_.get()->axis = fb_param->axis(); + param_.get()->eps = fb_param->eps(); + param_.get()->scale = fb_param->scale(); + + return CommonInit(); +} +#endif + + +}}}}} // namespace ppl::nn::llm::cuda::opmx diff --git a/src/ppl/nn/engines/llm_cuda/ops/opmx/tensor_parallel_rms_norm_op.h b/src/ppl/nn/engines/llm_cuda/ops/opmx/tensor_parallel_rms_norm_op.h new file mode 100644 index 000000000..c5c65598b --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/ops/opmx/tensor_parallel_rms_norm_op.h @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_TENSOR_PARALLEL_RMS_NORM_OP_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_OPS_OPMX_TENSOR_PARALLEL_RMS_NORM_OP_H_ + +#include "ppl/nn/engines/llm_cuda/opt_kernel.h" +#include "ppl/nn/params/opmx/tensor_parallel_rms_norm_param.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace opmx { + +class TensorParallelRMSNormOp final : public LlmCudaOptKernel { +public: + TensorParallelRMSNormOp(const ir::Node* node) : LlmCudaOptKernel(node) {} + + KernelImpl* CreateKernelImpl() const override; + ppl::common::RetCode DoInit(const OptKernelOptions&) override; + +#ifdef PPLNN_ENABLE_PMX_MODEL + ppl::common::RetCode SerializeData(const ppl::nn::pmx::SerializationContext&, utils::DataStream*) const override; + ppl::common::RetCode DeserializeData(const ppl::nn::pmx::DeserializationContext&, const void*, uint64_t) override; +#endif + +private: + ppl::common::RetCode CommonInit(); + + std::shared_ptr param_; +}; + +}}}}} // namespace ppl::nn::llm::cuda::opmx + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/opt_graph.cc b/src/ppl/nn/engines/llm_cuda/opt_graph.cc index 2cbdeaccd..0e3d1459b 100644 --- a/src/ppl/nn/engines/llm_cuda/opt_graph.cc +++ b/src/ppl/nn/engines/llm_cuda/opt_graph.cc @@ -120,9 +120,35 @@ RetCode OptGraph::Optimize( } } + if (engine_options.quant_method == QUANT_METHOD_ONLINE_F8F8) { + +#ifndef PPLNN_ENABLE_FP8 + LOG(ERROR) << "FP8 requires CUDA version 12 or higher, and the PPLNN_ENABLE_FP8 compile option must be enabled."; + return RC_UNSUPPORTED; +#endif + + LOG(INFO) << "Processing F8F8Cast..."; + auto prc = OptPassManager::GetInstance()->Apply("", "F8F8Cast", options); + if (prc.retcode != RC_SUCCESS) { + LOG(ERROR) << "F8F8Cast failed: " << GetRetCodeStr(prc.retcode); + return prc.retcode; + } + if (!prc.graph_modified) { + LOG(INFO) << "F8F8Cast: nothing has been changed."; + } + } + + if (engine_options.quant_method == QUANT_METHOD_ONLINE_I4F16) { - LOG(INFO) << "I4F16Quantization has not been implemented"; - return ppl::common::RC_UNSUPPORTED; + LOG(INFO) << "Processing I4F16Quantization..."; + auto prc = OptPassManager::GetInstance()->Apply("", "I4F16Quantization", options); + if (prc.retcode != RC_SUCCESS) { + LOG(ERROR) << "I4F16Quantization failed: " << GetRetCodeStr(prc.retcode); + return prc.retcode; + } + if (!prc.graph_modified) { + LOG(INFO) << "I4F16Quantization: nothing has been changed."; + } } rc = utils::LoadConstants(*graph_, device, &partition_info_->constants); diff --git a/src/ppl/nn/engines/llm_cuda/opt_pass_manager.cc b/src/ppl/nn/engines/llm_cuda/opt_pass_manager.cc index b83bcd3e5..48efafc60 100644 --- a/src/ppl/nn/engines/llm_cuda/opt_pass_manager.cc +++ b/src/ppl/nn/engines/llm_cuda/opt_pass_manager.cc @@ -17,6 +17,13 @@ #include "passes/i8i8/quantization_pass.h" #include "passes/i8i8/fuse_rms_norm_pass.h" +#include "passes/i8i8/fuse_split_pass.h" +#include "passes/i8i8/fuse_silu_pass.h" +#include "passes/i8i8/fuse_swiglu_pass.h" + +#include "passes/i4f16/quantization_pass.h" + +#include "passes/f8f8/cast_pass.h" #include "opt_pass_manager.h" @@ -78,6 +85,15 @@ OptPassManager::~OptPassManager() { OptPassManager::OptPassManager() { Register("", "I8I8Quantization", i8i8::QuantizationPass); Register("i8i8.fuse", "FuseRMSNorm", i8i8::FuseRMSNormPass); + Register("i8i8.fuse", "FuseSplit", i8i8::FuseSplitPass); + Register("i8i8.fuse", "FuseSiLU", i8i8::FuseSiLUPass); + Register("i8i8.fuse", "FuseSwiGLU", i8i8::FuseSwiGLUPass); + + Register("", "I4F16Quantization", 
i4f16::QuantizationPass); + +#ifdef PPLNN_ENABLE_FP8 + Register("", "F8F8Cast", f8f8::CastPass); +#endif } diff --git a/src/ppl/nn/engines/llm_cuda/passes/f8f8/cast_pass.cc b/src/ppl/nn/engines/llm_cuda/passes/f8f8/cast_pass.cc new file mode 100644 index 000000000..108936d8d --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/passes/f8f8/cast_pass.cc @@ -0,0 +1,361 @@ +#ifdef PPLNN_ENABLE_FP8 + +#include "cast_pass.h" + +#include "ppl/nn/params/opmx/column_parallel_linear_param.h" + +#include "ppl/nn/engines/llm_cuda/ops/opmx/f8f8/online_cast_op.h" +#include "ppl/nn/engines/llm_cuda/ops/opmx/f8f8/column_parallel_linear_op.h" +#include "ppl/nn/engines/llm_cuda/ops/opmx/f8f8/row_parallel_linear_op.h" + +#include "ppl/kernel/llm/cuda/pmx/f8f8/cast.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace f8f8 { + +static std::string GetCastedEdgeName(const std::string& tensor_name) { + return tensor_name + ".casted"; +} + +static std::string GetCastNodeName(const std::string& tensor_name) { + return "Cast." + tensor_name; +} + +struct CastWeightResult final { + ppl::common::RetCode retcode; + bool casted; +}; + + +static CastWeightResult CastWeight( + ir::Node* linear_node, + const OptKernelOptions& options, + const int64_t in_features, + const int64_t out_features) +{ + auto topo = options.graph->topo.get(); + auto constants = &options.graph->data->constants; + auto shapes = &options.graph->data->shapes; + auto loaded_constants = &options.partition_info->constants; + + auto weight_edge = topo->GetEdge(linear_node->GetInput(1)); + + std::set consumer_white_list = { + "ColumnParallelLinear", + "RowParallelLinear", + }; + + for (auto iter = weight_edge->CreateConsumerIter(); iter.IsValid(); iter.Forward()) { + auto &consumer_type = topo->GetNode(iter.Get())->GetType(); + if (consumer_white_list.find(consumer_type.name) == consumer_white_list.end()) { + LOG(WARNING) << "failed to f8f8 cast weight[" << weight_edge->GetName() << "], " + << "met unsupported consumer type [" << consumer_type.domain << ":" << consumer_type.name << "]"; + return {ppl::common::RC_SUCCESS, false}; + } + } + + // check wether this weight has been processed + if (loaded_constants->find(weight_edge->GetId()) != loaded_constants->end()) { + return {ppl::common::RC_SUCCESS, true}; + } + + auto weight_shape = &shapes->at(weight_edge->GetId()); + + if (weight_shape->data_type != ppl::common::DATATYPE_FLOAT16) { + LOG(WARNING) << "only support f8f8 cast for fp16 weight"; + return {ppl::common::RC_SUCCESS, false}; + } + + ppl::common::RetCode rc; + + // alloc buffer for casted weight + RuntimeConstantInfo casted_weight_buffer; + casted_weight_buffer.GetShape()->Reshape({out_features, in_features}); + casted_weight_buffer.GetShape()->SetDataType(ppl::common::DATATYPE_FLOAT8E4M3); + casted_weight_buffer.GetShape()->SetDataFormat(ppl::common::DATAFORMAT_NDARRAY); + casted_weight_buffer.SetDevice(options.device); + rc = casted_weight_buffer.ReallocBuffer(); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "realloc buffer for cast weight[" << weight_edge->GetName() << "] failed: " << ppl::common::GetRetCodeStr(rc); + return {ppl::common::RC_OUT_OF_MEMORY, false}; + } + + // alloc buffer for origin weight at last + // NOTE: it must be alloced at last to avoid memory fragmentation when it freed after being casted + RuntimeConstantInfo weight_buffer; + weight_buffer.Reshape(*casted_weight_buffer.GetShape()); + weight_buffer.GetShape()->SetDataType(weight_shape->data_type); + 
weight_buffer.SetDevice(options.device); + + // use zero copy to reduce GPU memory fragmentation + void* weight_pinned_host_buffer = nullptr; + auto cuda_err = cudaMallocHost(&weight_pinned_host_buffer, weight_buffer.GetShape()->CalcBytesIncludingPadding(), cudaHostAllocMapped); + if (cudaSuccess != cuda_err) { + LOG(ERROR) << "realloc pinned buffer for weight[" << weight_edge->GetName() << "] failed: " << cudaGetErrorString(cuda_err); + return {ppl::common::RC_OUT_OF_MEMORY, false}; + } + void *weight_pinned_dev_buffer = nullptr; + cuda_err = cudaHostGetDevicePointer(&weight_pinned_dev_buffer, weight_pinned_host_buffer, 0); + if (cudaSuccess != cuda_err) { + LOG(ERROR) << "get device pinned buffer for weight[" << weight_edge->GetName() << "] failed: " << cudaGetErrorString(cuda_err); + return {ppl::common::RC_DEVICE_MEMORY_ERROR, false}; + } + weight_buffer.SetBuffer(weight_pinned_dev_buffer); + + // copy fp16 data to pinned memory for cast + auto weight_host = &constants->at(weight_edge->GetId()); + memcpy(weight_pinned_host_buffer, weight_host->data.GetData(), weight_host->data.GetSize()); + constants->erase(weight_edge->GetId()); + + // call cast kernel here + rc = ppl::kernel::llm::cuda::pmx::f8f8::cast_fp16( + options.device->GetStream(), + weight_buffer.GetBufferPtr(), + out_features, + in_features, + casted_weight_buffer.GetBufferPtr()); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "do cast for weight[" << weight_edge->GetName() << "] failed: " << ppl::common::GetRetCodeStr(rc); + return {ppl::common::RC_DEVICE_RUNTIME_ERROR, false}; + } + rc = options.device->Synchronize(); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "synchronize cast for weight[" << weight_edge->GetName() << "] failed: " << ppl::common::GetRetCodeStr(rc); + return {ppl::common::RC_DEVICE_RUNTIME_ERROR, false}; + } + + // change weight datatype + weight_shape->data_type = ppl::common::DATATYPE_FLOAT8E4M3; + + // emplace GPU buffer to runtime constants + loaded_constants->emplace(weight_edge->GetId(), std::move(casted_weight_buffer)); + + cuda_err = cudaFreeHost(weight_pinned_host_buffer); + if (cudaSuccess != cuda_err) { + LOG(ERROR) << "free pinned buffer for weight[" << weight_edge->GetName() << "] failed: " << cudaGetErrorString(cuda_err); + return {ppl::common::RC_DEVICE_MEMORY_ERROR, false}; + } + + return {ppl::common::RC_SUCCESS, true}; +} + + +static ppl::common::RetCode CastLinear( + ir::Node* linear_node, + const OptKernelOptions& options, + const int64_t in_features, + const int64_t out_features) +{ + auto topo = options.graph->topo.get(); + auto kernels = &options.partition_info->kernels; + + auto input_edge = topo->GetEdge(linear_node->GetInput(0)); + + // sometime there are 2 linear node consume same input, + // such as llama's FeedForward: y = w2(silu(w1(x)) * w3(x)). + // q_input_exits is for checking if input of linear has been casted. 
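+    // Both the casted edge and the cast node are keyed by the original tensor name
+    // (GetCastedEdgeName / GetCastNodeName), so when a second linear consumes the same
+    // activation, topo->AddEdge / topo->AddNode return the existing objects and this pass
+    // reuses them instead of inserting a duplicate OnlineCast, e.g.:
+    //   x -> Cast.x -> x.casted -> w1
+    //                         \-> w3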
+ auto casted_input_exits = false; + auto casted_input_name = GetCastedEdgeName(input_edge->GetName()); + auto edge_ret_pair = topo->AddEdge(casted_input_name); + if (!edge_ret_pair.second) { + LOG(DEBUG) << "casted edge[" << casted_input_name << "] for input[" << input_edge->GetName() << "] exists"; + casted_input_exits = true; + } else { + LOG(DEBUG) << "add casted edge[" << casted_input_name << "] for input[" << input_edge->GetName() << "] success"; + } + auto casted_input_edge = edge_ret_pair.first; + + auto cast_node_exists = false; + auto cast_node_name = GetCastNodeName(input_edge->GetName()); + auto node_ret_pair = topo->AddNode(cast_node_name); + auto cast_node = node_ret_pair.first; + if (!node_ret_pair.second) { + // we shoud check the cast method of cast_node + if (cast_node->GetType().domain != "opmx.f8f8" || cast_node->GetType().name != "OnlineCast") { + LOG(ERROR) << "cast node[" << cast_node_name << "] for input[" << input_edge->GetName() << "] exists, " + << "expect for [opmx.f8f8:OnlineCast] but given [" + << cast_node->GetType().domain << ":" << cast_node->GetType().name << "]"; + return ppl::common::RC_EXISTS; + } + LOG(DEBUG) << "cast node[" << cast_node_name << "] for input[" << input_edge->GetName() << "] exists"; + cast_node_exists = true; + } else { + LOG(DEBUG) << "add cast node[" << cast_node_name << "] for input[" << input_edge->GetName() << "] success"; + cast_node->SetType({"opmx.f8f8", "OnlineCast", 1}); + } + + bool input_has_casted = cast_node_exists && casted_input_exits; + if (!input_has_casted && (cast_node_exists || casted_input_exits)) { + LOG(ERROR) << "input[" << input_edge->GetName() << "] has not been completely f8f8 casted: " + << "cast_node(" << cast_node_exists << "), " + << "cast_input(" << casted_input_exits << ")"; + } + + { + // rearrange node and edge + // before: input_edge --> linear_node -> output_edge + // weight_edge -> | + // after: input_edge -> cast_node -> casted_input_edge ---> linear_node -> output_edge + // | casted_weight_edge -/ + // bias_edge -/ + linear_node->ReplaceInput(input_edge->GetId(), casted_input_edge->GetId()); + + if (!input_has_casted) { + cast_node->AddInput(input_edge->GetId()); + cast_node->AddOutput(casted_input_edge->GetId()); + } + + input_edge->DelConsumer(linear_node->GetId()); + input_edge->AddConsumer(cast_node->GetId()); + + casted_input_edge->AddConsumer(linear_node->GetId()); + casted_input_edge->SetProducer(cast_node->GetId()); + } + + { + if (!input_has_casted) { + auto cast_kernel = std::unique_ptr(new opmx::F8F8OnlineCastOp(cast_node)); + auto rc = cast_kernel->Init(options); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "init kernel[" << cast_kernel->GetNode()->GetName() << " failed: " << ppl::common::GetRetCodeStr(rc); + return rc; + } + kernels->emplace(cast_node->GetId(), std::move(cast_kernel)); + } + } + + return ppl::common::RC_SUCCESS; +} + + +static OptPassStatus CastColunmParallelLinear(ir::Node* linear_node, const OptKernelOptions& options) { + OptPassStatus status = {ppl::common::RC_SUCCESS, false}; + + auto param = std::static_pointer_cast(options.graph->data->attrs[linear_node->GetId()]); + const auto in_features = param->in_features; + const auto out_features_per_part = param->out_features / options.device->GetTensorParallelNcclParam()->size; + if ((in_features % 16 != 0) || (out_features_per_part % 16 != 0 )) { + LOG(WARNING) << "in_features and out_features_per_part should be aligned with 16 for f8f8 cast, " + <<"ColumnParallelLinear[" << linear_node->GetName() << "], 
whose weight is (" + << out_features_per_part << ", " << in_features << ") will not be casted"; + return status; + } + + + { + LOG(DEBUG) << "processing f8f8 for ColumnParallelLinear[" << linear_node->GetName() << "]"; + auto cast_ret = CastWeight(linear_node, options, in_features, out_features_per_part); + if (cast_ret.retcode != ppl::common::RC_SUCCESS) { + status.retcode = cast_ret.retcode; + status.graph_modified = true; + return status; + } + if (cast_ret.casted == false) { + return status; + } + + status.graph_modified = true; + status.retcode = CastLinear(linear_node, options, in_features, out_features_per_part); + if (ppl::common::RC_SUCCESS != status.retcode) { + return status; + } + } + + if (status.graph_modified) { + // change ColunmParallelLinear to f8f8.ColunmParallelLinear + linear_node->SetType({"opmx.f8f8", "ColumnParallelLinear", 1}); + auto cast_linear_kernel = new opmx::F8F8ColumnParallelLinearOp(linear_node); + status.retcode = cast_linear_kernel->Init(options); + if (ppl::common::RC_SUCCESS != status.retcode) { + LOG(ERROR) << "init kernel[" << cast_linear_kernel->GetNode()->GetName() + << " failed: " << ppl::common::GetRetCodeStr(status.retcode); + return status; + } + options.partition_info->kernels[linear_node->GetId()].reset(cast_linear_kernel); + LOG(DEBUG) << "process f8f8 for ColumnParallelLinear[" << linear_node->GetName() << "] success"; + } + + return status; +} + +static OptPassStatus CastRowParallelLinear(ir::Node* linear_node, const OptKernelOptions& options) { + OptPassStatus status = {ppl::common::RC_SUCCESS, false}; + + auto param = std::static_pointer_cast(options.graph->data->attrs[linear_node->GetId()]); + const auto in_features_per_part = param->in_features / options.device->GetTensorParallelNcclParam()->size; + const auto out_features = param->out_features; + if ((in_features_per_part % 16 != 0) || (out_features % 16 != 0 )) { + LOG(WARNING) << "in_features_per_part and out_features should be aligned with 16 for f8f8 cast, " + <<"ColumnParallelLinear[" << linear_node->GetName() << "], whose weight is (" + << out_features << ", " << in_features_per_part << ") will not be casted"; + return status; + } + + { + LOG(DEBUG) << "processing f8f8 for RowParallelLinear[" << linear_node->GetName() << "]"; + auto cast_ret = CastWeight(linear_node, options, in_features_per_part, out_features); + if (cast_ret.retcode != ppl::common::RC_SUCCESS) { + status.retcode = cast_ret.retcode; + status.graph_modified = true; + return status; + } + if (cast_ret.casted == false) { + return status; + } + + status.graph_modified = true; + status.retcode = CastLinear(linear_node, options, in_features_per_part, out_features); + if (ppl::common::RC_SUCCESS != status.retcode) { + return status; + } + } + + if (status.graph_modified) { + // change RowParallelLinear to f8f8.RowParallelLinear + linear_node->SetType({"opmx.f8f8", "RowParallelLinear", 1}); + auto cast_linear_kernel = new opmx::F8F8RowParallelLinearOp(linear_node); + status.retcode = cast_linear_kernel->Init(options); + if (ppl::common::RC_SUCCESS != status.retcode) { + LOG(ERROR) << "init kernel[" << cast_linear_kernel->GetNode()->GetName() + << " failed: " << ppl::common::GetRetCodeStr(status.retcode); + return status; + } + options.partition_info->kernels[linear_node->GetId()].reset(cast_linear_kernel); + LOG(DEBUG) << "process f8f8 for RowParallelLinear[" << linear_node->GetName() << "] success"; + } + + return status; +} + +OptPassStatus CastPass(const OptKernelOptions& options) +{ + OptPassStatus status = 
{ppl::common::RC_SUCCESS, false}; + + if (options.device->GetSMVersion() < 89) { + LOG(WARNING) << "f8f8 cast only support sm >= 89 now"; + return status; + } + + for (auto it = options.graph->topo->CreateNodeIter(); it->IsValid(); it->Forward()) { + auto node = it->Get(); + if (node->GetType().domain == "opmx" && node->GetType().name == "ColumnParallelLinear") { + auto ret = CastColunmParallelLinear(node, options); + status.graph_modified = status.graph_modified || ret.graph_modified; + status.retcode = ret.retcode; + if (ppl::common::RC_SUCCESS != status.retcode) + return status; + } + if (node->GetType().domain == "opmx" && node->GetType().name == "RowParallelLinear") { + auto ret = CastRowParallelLinear(node, options); + status.graph_modified = status.graph_modified || ret.graph_modified; + status.retcode = ret.retcode; + if (ppl::common::RC_SUCCESS != status.retcode) + return status; + } + } + + return status; +} + +}}}}} + +#endif \ No newline at end of file diff --git a/src/ppl/nn/engines/llm_cuda/passes/f8f8/cast_pass.h b/src/ppl/nn/engines/llm_cuda/passes/f8f8/cast_pass.h new file mode 100644 index 000000000..48153dae4 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/passes/f8f8/cast_pass.h @@ -0,0 +1,16 @@ +#ifdef PPLNN_ENABLE_FP8 + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_PASSES_F8F8_CAST_PASS_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_PASSES_F8F8_CAST_PASS_H_ + +#include "ppl/nn/engines/llm_cuda/opt_pass.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace f8f8 { + +OptPassStatus CastPass(const OptKernelOptions& options); + +}}}}} + +#endif +#endif + diff --git a/src/ppl/nn/engines/llm_cuda/passes/i4f16/quantization_pass.cc b/src/ppl/nn/engines/llm_cuda/passes/i4f16/quantization_pass.cc new file mode 100644 index 000000000..a08194f9b --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/passes/i4f16/quantization_pass.cc @@ -0,0 +1,365 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
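+
+// I4F16 weight-only quantization pass. For every opmx ColumnParallelLinear /
+// RowParallelLinear whose fp16 weight satisfies the alignment checks below, the weight is
+// quantized on device into 4-bit groups (group size 128 along in_features, 4 output
+// channels packed per INT4X4 element) with per-group fp16 scales. The scale is registered
+// as a new constant edge named "<weight>.scale" and inserted as the linear's third input,
+// and the node is retyped to the opmx.i4f16 variant; activations remain fp16.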
+ +#include "quantization_pass.h" + +#include "ppl/nn/params/opmx/column_parallel_linear_param.h" + +#include "ppl/nn/engines/llm_cuda/ops/opmx/i4f16/column_parallel_linear_op.h" +#include "ppl/nn/engines/llm_cuda/ops/opmx/i4f16/row_parallel_linear_op.h" + +#include "ppl/kernel/llm/cuda/pmx/i4f16/quantize.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace i4f16 { + +static std::string GetScaleName(const std::string& tensor_name) { + return tensor_name + ".scale"; +} + +struct QuantizeWeightResult final { + ppl::common::RetCode retcode; + ir::Edge* scale_edge; +}; + +// return scale edge +static QuantizeWeightResult QuantizeWeight( + ir::Node* linear_node, + const OptKernelOptions& options, + const int64_t in_features, + const int64_t out_features) +{ + auto topo = options.graph->topo.get(); + auto constants = &options.graph->data->constants; + auto shapes = &options.graph->data->shapes; + auto loaded_constants = &options.partition_info->constants; + + auto weight_edge = topo->GetEdge(linear_node->GetInput(1)); + auto scale_name = GetScaleName(weight_edge->GetName()); + + std::set consumer_white_list = { + "ColumnParallelLinear", + "RowParallelLinear", + }; + + for (auto iter = weight_edge->CreateConsumerIter(); iter.IsValid(); iter.Forward()) { + auto &consumer_type = topo->GetNode(iter.Get())->GetType(); + if (consumer_white_list.find(consumer_type.name) == consumer_white_list.end()) { + LOG(WARNING) << "failed to i4f16 quantize weight[" << weight_edge->GetName() << "], " + << "met unsupported consumer type [" << consumer_type.domain << ":" << consumer_type.name << "]"; + return {ppl::common::RC_SUCCESS, nullptr}; + } + } + + const int64_t out_features_pack_size = 4; + const int64_t weight_quant_group_size = 128; + + if (out_features % out_features_pack_size != 0) { + LOG(WARNING) << "only support out_features(" << out_features << ") aligned with " << out_features_pack_size; + return {ppl::common::RC_SUCCESS, nullptr}; + } + + // N must be aligned to 128 for int4 gemm api + if (out_features % weight_quant_group_size != 0) { + LOG(WARNING) << "only support out_features(" << out_features << ") aligned with " << weight_quant_group_size; + return {ppl::common::RC_SUCCESS, nullptr}; + } + + if (in_features % weight_quant_group_size != 0) { + LOG(WARNING) << "only support in_features(" << in_features << ") aligned with " << weight_quant_group_size; + return {ppl::common::RC_SUCCESS, nullptr}; + } + + // check wether this weight has been processed + if (loaded_constants->find(weight_edge->GetId()) != loaded_constants->end()) { + auto scale_edge = topo->GetEdge(scale_name); + return {ppl::common::RC_SUCCESS, scale_edge}; + } + + auto weight_shape = &shapes->at(weight_edge->GetId()); + if (weight_shape->data_type != ppl::common::DATATYPE_FLOAT16) { + LOG(WARNING) << "only support i4f16 quantize for fp16 weight"; + return {ppl::common::RC_SUCCESS, nullptr}; + } + + // add constant scale edge and check + auto ret_pair = topo->AddEdge(scale_name); + if (!ret_pair.second) { + LOG(ERROR) << "add scale edge[" << scale_name << "] for weight[" << weight_edge->GetName() << "] failed"; + return {ppl::common::RC_EXISTS, nullptr}; + } + LOG(DEBUG) << "add scale edge[" << scale_name << "] for weight[" << weight_edge->GetName() << "] success"; + auto scale_edge = ret_pair.first; + topo->MarkAsConstant(scale_edge->GetId()); + + ppl::common::RetCode rc; + + // alloc buffer for quantized weight + RuntimeConstantInfo quantized_weight_buffer; + 
quantized_weight_buffer.GetShape()->Reshape({out_features / out_features_pack_size, in_features}); + quantized_weight_buffer.GetShape()->SetDataType(ppl::common::DATATYPE_INT4X4); + quantized_weight_buffer.GetShape()->SetDataFormat(ppl::common::DATAFORMAT_NDARRAY); + quantized_weight_buffer.SetDevice(options.device); + rc = quantized_weight_buffer.ReallocBuffer(); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "realloc buffer for quantize weight[" << weight_edge->GetName() << "] failed: " << ppl::common::GetRetCodeStr(rc); + return {ppl::common::RC_OUT_OF_MEMORY, nullptr}; + } + + // alloc buffer for scale + RuntimeConstantInfo scale_buffer; + scale_buffer.GetShape()->Reshape({in_features / weight_quant_group_size, out_features}); + scale_buffer.GetShape()->SetDataType(weight_shape->data_type); + scale_buffer.GetShape()->SetDataFormat(ppl::common::DATAFORMAT_NDARRAY); + scale_buffer.SetDevice(options.device); + rc = scale_buffer.ReallocBuffer(); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "realloc buffer for scale of weight[" << weight_edge->GetName() << "] failed: " << ppl::common::GetRetCodeStr(rc); + return {ppl::common::RC_OUT_OF_MEMORY, nullptr}; + } + + // alloc buffer for origin weight at last + // NOTE: it must be alloced at last to avoid memory fragmentation when it freed after being quantized + RuntimeConstantInfo weight_buffer; + weight_buffer.GetShape()->Reshape({out_features, in_features}); + weight_buffer.GetShape()->SetDataType(weight_shape->data_type); + weight_buffer.GetShape()->SetDataFormat(ppl::common::DATAFORMAT_NDARRAY); + weight_buffer.SetDevice(options.device); + + // use zero copy to reduce GPU memory fragmentation + void* weight_pinned_host_buffer = nullptr; + auto cuda_err = cudaMallocHost(&weight_pinned_host_buffer, weight_buffer.GetShape()->CalcBytesIncludingPadding(), cudaHostAllocMapped); + if (cudaSuccess != cuda_err) { + LOG(ERROR) << "realloc pinned buffer for weight[" << weight_edge->GetName() << "] failed: " << cudaGetErrorString(cuda_err); + return {ppl::common::RC_OUT_OF_MEMORY, nullptr}; + } + void *weight_pinned_dev_buffer = nullptr; + cuda_err = cudaHostGetDevicePointer(&weight_pinned_dev_buffer, weight_pinned_host_buffer, 0); + if (cudaSuccess != cuda_err) { + LOG(ERROR) << "get device pinned buffer for weight[" << weight_edge->GetName() << "] failed: " << cudaGetErrorString(cuda_err); + return {ppl::common::RC_DEVICE_MEMORY_ERROR, nullptr}; + } + weight_buffer.SetBuffer(weight_pinned_dev_buffer); + + // copy fp16 data to pinned memory for quantize + auto weight_host = &constants->at(weight_edge->GetId()); + memcpy(weight_pinned_host_buffer, weight_host->data.GetData(), weight_host->data.GetSize()); + constants->erase(weight_edge->GetId()); + + rc = ppl::kernel::llm::cuda::pmx::i4f16::minmax_quantize_fp16( + options.device->GetStream(), + weight_buffer.GetBufferPtr(), + out_features, + in_features, + weight_quant_group_size, + quantized_weight_buffer.GetBufferPtr(), + scale_buffer.GetBufferPtr()); + + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "do quantize for weight[" << weight_edge->GetName() << "] failed: " << ppl::common::GetRetCodeStr(rc); + topo->DelEdge(scale_edge->GetId()); + return {ppl::common::RC_DEVICE_RUNTIME_ERROR, nullptr}; + } + rc = options.device->Synchronize(); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "synchronize quantize for weight[" << weight_edge->GetName() << "] failed: " << ppl::common::GetRetCodeStr(rc); + topo->DelEdge(scale_edge->GetId()); + return 
{ppl::common::RC_DEVICE_RUNTIME_ERROR, nullptr}; + } + + // fill constant scale shape + auto scale_shape = &shapes->emplace(scale_edge->GetId(), std::move(ir::Shape())).first->second; + scale_shape->data_type = weight_shape->data_type; + scale_shape->data_format = ppl::common::DATAFORMAT_NDARRAY; + scale_shape->dims = {in_features / weight_quant_group_size, out_features}; + + // change weight shape and datatype + weight_shape->data_type = ppl::common::DATATYPE_INT4X4; + weight_shape->dims = {out_features / out_features_pack_size, in_features}; + + // emplace GPU buffer to runtime constants + loaded_constants->emplace(weight_edge->GetId(), std::move(quantized_weight_buffer)); + loaded_constants->emplace(scale_edge->GetId(), std::move(scale_buffer)); + + cuda_err = cudaFreeHost(weight_pinned_host_buffer); + if (cudaSuccess != cuda_err) { + LOG(ERROR) << "free pinned buffer for weight[" << weight_edge->GetName() << "] failed: " << cudaGetErrorString(cuda_err); + return {ppl::common::RC_DEVICE_MEMORY_ERROR, nullptr}; + } + + return {ppl::common::RC_SUCCESS, scale_edge}; +} + +static ppl::common::RetCode QuantizeLinearWeightOnly( + ir::Node* linear_node, + const OptKernelOptions& options, + const int64_t in_features, + const int64_t out_features) +{ + auto topo = options.graph->topo.get(); + + auto weight_edge = topo->GetEdge(linear_node->GetInput(1)); + auto weight_scale_edge = topo->GetEdge(GetScaleName(weight_edge->GetName())); + if (weight_scale_edge == nullptr) { + LOG(ERROR) << "scale edge[" << GetScaleName(weight_edge->GetName()) << "] not found"; + return ppl::common::RC_NOT_FOUND; + } + + { + // rearrange node and edge + // before: input_edge --> linear_node -> output_edge + // weight_edge -| + // bias_edge -/ + // after: input_edge --> linear_node -> output_edge + // weight_edge -| + // weight_scale_edge -| + // bias_edge -/ + linear_node->InsertInput(2, weight_scale_edge->GetId()); + weight_scale_edge->AddConsumer(linear_node->GetId()); + } + + return ppl::common::RC_SUCCESS; +} + +static OptPassStatus QuantizeColunmParallelLinear(ir::Node* linear_node, const OptKernelOptions& options) { + OptPassStatus status = {ppl::common::RC_SUCCESS, false}; + + auto param = std::static_pointer_cast(options.graph->data->attrs[linear_node->GetId()]); + const auto in_features = param->in_features; + const auto out_features_per_part = param->out_features / options.device->GetTensorParallelNcclParam()->size; + if ((in_features % 32 != 0) || (out_features_per_part % 32 != 0 )) { + LOG(WARNING) << "in_features and out_features_per_part should be aligned with 32 for i4f16 quantization, " + <<"ColumnParallelLinear[" << linear_node->GetName() << "], whose weight is (" + << out_features_per_part << ", " << in_features << ") will not be quantized"; + return status; + } + + { + LOG(DEBUG) << "processing i4f16 for ColumnParallelLinear[" << linear_node->GetName() << "]"; + auto quantize_ret = QuantizeWeight(linear_node, options, in_features, out_features_per_part); + if (quantize_ret.retcode != ppl::common::RC_SUCCESS) { + status.retcode = quantize_ret.retcode; + status.graph_modified = true; + return status; + } + if (quantize_ret.scale_edge == nullptr) { + return status; + } + + status.graph_modified = true; + status.retcode = QuantizeLinearWeightOnly(linear_node, options, in_features, out_features_per_part); + if (ppl::common::RC_SUCCESS != status.retcode) { + return status; + } + } + + if (status.graph_modified) { + // change ColunmParallelLinear to i4f16.ColunmParallelLinear + 
linear_node->SetType({"opmx.i4f16", "ColumnParallelLinear", 1}); + auto q_linear_kernel = new opmx::I4F16ColumnParallelLinearOp(linear_node); + status.retcode = q_linear_kernel->Init(options); + if (ppl::common::RC_SUCCESS != status.retcode) { + LOG(ERROR) << "init kernel[" << q_linear_kernel->GetNode()->GetName() + << " failed: " << ppl::common::GetRetCodeStr(status.retcode); + return status; + } + options.partition_info->kernels[linear_node->GetId()].reset(q_linear_kernel); + LOG(DEBUG) << "process i4f16 for ColumnParallelLinear[" << linear_node->GetName() << "] success"; + } + + return status; +} + +static OptPassStatus QuantizeRowParallelLinear(ir::Node* linear_node, const OptKernelOptions& options) { + OptPassStatus status = {ppl::common::RC_SUCCESS, false}; + + auto param = std::static_pointer_cast(options.graph->data->attrs[linear_node->GetId()]); + const auto in_features_per_part = param->in_features / options.device->GetTensorParallelNcclParam()->size; + const auto out_features = param->out_features; + if ((in_features_per_part % 32 != 0) || (out_features % 32 != 0 )) { + LOG(WARNING) << "in_features_per_part and out_features should be aligned with 32 for i4f16 quantization, " + <<"ColumnParallelLinear[" << linear_node->GetName() << "], whose weight is (" + << out_features << ", " << in_features_per_part << ") will not be quantized"; + return status; + } + + { + LOG(DEBUG) << "processing i4f16 for RowParallelLinear[" << linear_node->GetName() << "]"; + auto quantize_ret = QuantizeWeight(linear_node, options, in_features_per_part, out_features); + if (quantize_ret.retcode != ppl::common::RC_SUCCESS) { + status.retcode = quantize_ret.retcode; + status.graph_modified = true; + return status; + } + if (quantize_ret.scale_edge == nullptr) { + return status; + } + + status.graph_modified = true; + status.retcode = QuantizeLinearWeightOnly(linear_node, options, in_features_per_part, out_features); + if (ppl::common::RC_SUCCESS != status.retcode) { + return status; + } + } + + if (status.graph_modified) { + // change RowParallelLinear to i4f16.RowParallelLinear + linear_node->SetType({"opmx.i4f16", "RowParallelLinear", 1}); + auto q_linear_kernel = new opmx::I4F16RowParallelLinearOp(linear_node); + status.retcode = q_linear_kernel->Init(options); + if (ppl::common::RC_SUCCESS != status.retcode) { + LOG(ERROR) << "init kernel[" << q_linear_kernel->GetNode()->GetName() + << " failed: " << ppl::common::GetRetCodeStr(status.retcode); + return status; + } + options.partition_info->kernels[linear_node->GetId()].reset(q_linear_kernel); + LOG(DEBUG) << "process i4f16 for RowParallelLinear[" << linear_node->GetName() << "] success"; + } + + return status; +} + +OptPassStatus QuantizationPass(const OptKernelOptions& options) +{ + OptPassStatus status = {ppl::common::RC_SUCCESS, false}; + + if (options.device->GetSMVersion() < 80) { + LOG(WARNING) << "i4f16 quantize only support sm >= 80 now"; + return status; + } + + for (auto it = options.graph->topo->CreateNodeIter(); it->IsValid(); it->Forward()) { + auto node = it->Get(); + if (node->GetType().domain == "opmx" && node->GetType().name == "ColumnParallelLinear") { + auto ret = QuantizeColunmParallelLinear(node, options); + status.graph_modified = status.graph_modified || ret.graph_modified; + status.retcode = ret.retcode; + if (ppl::common::RC_SUCCESS != status.retcode) + return status; + } + if (node->GetType().domain == "opmx" && node->GetType().name == "RowParallelLinear") { + auto ret = QuantizeRowParallelLinear(node, options); + 
status.graph_modified = status.graph_modified || ret.graph_modified; + status.retcode = ret.retcode; + if (ppl::common::RC_SUCCESS != status.retcode) + return status; + } + } + + return status; +} + +}}}}} diff --git a/src/ppl/nn/engines/llm_cuda/passes/i4f16/quantization_pass.h b/src/ppl/nn/engines/llm_cuda/passes/i4f16/quantization_pass.h new file mode 100644 index 000000000..6f9f5fd3a --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/passes/i4f16/quantization_pass.h @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_PASSES_I4F16_QUANTIZATION_PASS_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_PASSES_I4F16_QUANTIZATION_PASS_H_ + +#include "ppl/nn/engines/llm_cuda/opt_pass.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace i4f16 { + +OptPassStatus QuantizationPass(const OptKernelOptions& options); + +}}}}} + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_silu_pass.cc b/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_silu_pass.cc new file mode 100644 index 000000000..c8b84b4cd --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_silu_pass.cc @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
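+
+// Fuses the i8i8 pattern  OnlineDequantize -> SiLU -> OnlineQuantize  into a single
+// opmx.i8i8 OnlineDequantizeSiLUQuantize node:
+//   before: (x_q, s_outer, s_inner) -> Dequantize -> SiLU -> Quantize -> (y_q, y_scale)
+//   after : (x_q, s_outer, s_inner[, gate inputs]) -> OnlineDequantizeSiLUQuantize -> (y_q, y_scale)
+// The rewrite only fires when the dequantize has no bias and every edge in the pattern has
+// a single consumer; an optional gate input is folded in the same way when it is produced
+// by a bias-free OnlineDequantize.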
+ +#include "fuse_silu_pass.h" + +#include "ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_silu_quantize_op.h" +#include "ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_op.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace i8i8 { + +OptPassStatus FuseSiLUPass(const OptKernelOptions& options) +{ + OptPassStatus status = {ppl::common::RC_SUCCESS, false}; + + for (auto it = options.graph->topo->CreateNodeIter(); it->IsValid(); it->Forward()) { + auto silu_node = it->Get(); + if (silu_node->GetType().domain == "opmx" && silu_node->GetType().name == "SiLU") { + + auto topo = options.graph->topo.get(); + auto kernels = &options.partition_info->kernels; + + auto output_edge = topo->GetEdge(silu_node->GetOutput(0)); + auto input_edge = topo->GetEdge(silu_node->GetInput(0)); + auto gate_edge = silu_node->GetInputCount() > 1 ? topo->GetEdge(silu_node->GetInput(1)) : nullptr; + + auto q_node = topo->GetNode(output_edge->CreateConsumerIter().Get()); + auto dq_node = topo->GetNode(input_edge->GetProducer()); + + if (output_edge->CalcConsumerCount() == 1 + && q_node + && q_node->GetType().domain == "opmx.i8i8" + && q_node->GetType().name == "OnlineQuantize" + && input_edge->CalcConsumerCount() == 1 + && dq_node + && dq_node->GetType().domain == "opmx.i8i8" + && dq_node->GetType().name == "OnlineDequantize" + ) + { + auto dq_kernel = (opmx::I8I8OnlineDequantizeOp*)options.partition_info->kernels[dq_node->GetId()].get(); + if (dq_kernel->GetParam()->bias_term) + continue; + + if (gate_edge) { + auto gate_dq_node = topo->GetNode(gate_edge->GetProducer()); + auto gate_dq_kernel = gate_dq_node ? (opmx::I8I8OnlineDequantizeOp*)options.partition_info->kernels[gate_dq_node->GetId()].get() : nullptr; + if (!(gate_edge->CalcConsumerCount() == 1 + && gate_dq_node + && gate_dq_node->GetType().domain == "opmx.i8i8" + && gate_dq_node->GetType().name == "OnlineDequantize" + && gate_dq_kernel->GetParam()->bias_term == false)) + { + continue; + } + } + + status.graph_modified = status.graph_modified || true; + + auto q_output_edge = topo->GetEdge(q_node->GetOutput(0)); + auto scale_edge = topo->GetEdge(q_node->GetOutput(1)); + + auto dq_input_edge = topo->GetEdge(dq_node->GetInput(0)); + auto dq_scale_outer_edge = topo->GetEdge(dq_node->GetInput(1)); + auto dq_scale_inner_edge = topo->GetEdge(dq_node->GetInput(2)); + + // form: (dq_input_edge, scale_outer, scale_inner) -> dq_node -> input_edge -> silu_node -> output_edge -> q_node -> (q_output_edge, scale_edge) + // to : (dq_input_edge, scale_outer, scale_inner) -> silu_node -> (q_output_edge, scale_edge) + silu_node->ReplaceOutput(output_edge->GetId(), q_output_edge->GetId()); + silu_node->AddOutput(scale_edge->GetId()); + + silu_node->ClearInputs(); + silu_node->AddInput(dq_input_edge->GetId()); + silu_node->AddInput(dq_scale_outer_edge->GetId()); + silu_node->AddInput(dq_scale_inner_edge->GetId()); + + dq_input_edge->DelConsumer(dq_node->GetId()); + dq_input_edge->AddConsumer(silu_node->GetId()); + + dq_scale_outer_edge->DelConsumer(dq_node->GetId()); + dq_scale_outer_edge->AddConsumer(silu_node->GetId()); + + dq_scale_inner_edge->DelConsumer(dq_node->GetId()); + dq_scale_inner_edge->AddConsumer(silu_node->GetId()); + + q_output_edge->SetProducer(silu_node->GetId()); + scale_edge->SetProducer(silu_node->GetId()); + + kernels->erase(q_node->GetId()); + topo->DelEdge(output_edge->GetId()); + topo->DelNode(q_node->GetId()); + + kernels->erase(dq_node->GetId()); + topo->DelEdge(input_edge->GetId()); + 
topo->DelNode(dq_node->GetId()); + + if (gate_edge) { + dq_node = topo->GetNode(gate_edge->GetProducer()); + dq_kernel = (opmx::I8I8OnlineDequantizeOp*)options.partition_info->kernels[dq_node->GetId()].get(); + + dq_input_edge = topo->GetEdge(dq_node->GetInput(0)); + dq_scale_outer_edge = topo->GetEdge(dq_node->GetInput(1)); + dq_scale_inner_edge = topo->GetEdge(dq_node->GetInput(2)); + + silu_node->AddInput(dq_input_edge->GetId()); + silu_node->AddInput(dq_scale_outer_edge->GetId()); + silu_node->AddInput(dq_scale_inner_edge->GetId()); + + dq_input_edge->DelConsumer(dq_node->GetId()); + dq_input_edge->AddConsumer(silu_node->GetId()); + + dq_scale_outer_edge->DelConsumer(dq_node->GetId()); + dq_scale_outer_edge->AddConsumer(silu_node->GetId()); + + dq_scale_inner_edge->DelConsumer(dq_node->GetId()); + dq_scale_inner_edge->AddConsumer(silu_node->GetId()); + + kernels->erase(dq_node->GetId()); + topo->DelEdge(gate_edge->GetId()); + topo->DelNode(dq_node->GetId()); + } + + silu_node->SetType({"opmx.i8i8", "OnlineDequantizeSiLUQuantize", 1}); + auto dq_silu_q_kernel = new opmx::I8I8OnlineDequantizeSiLUQuantizeOp(silu_node); + status.retcode = dq_silu_q_kernel->Init(options); + if (ppl::common::RC_SUCCESS != status.retcode) { + LOG(ERROR) << "init kernel[" << dq_silu_q_kernel->GetNode()->GetName() + << " failed: " << ppl::common::GetRetCodeStr(status.retcode); + return status; + } + options.partition_info->kernels[silu_node->GetId()].reset(dq_silu_q_kernel); + LOG(DEBUG) << "process fuse for SiLU[" << silu_node->GetName() << "] success"; + } + } + } + + return status; +} + +}}}}} diff --git a/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_silu_pass.h b/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_silu_pass.h new file mode 100644 index 000000000..9ccea99f8 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_silu_pass.h @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_PASSES_I8I8_FUSE_SILU_PASS_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_PASSES_I8I8_FUSE_SILU_PASS_H_ + +#include "ppl/nn/engines/llm_cuda/opt_pass.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace i8i8 { + +OptPassStatus FuseSiLUPass(const OptKernelOptions& options); + +}}}}} + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_split_pass.cc b/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_split_pass.cc new file mode 100644 index 000000000..62c073d97 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_split_pass.cc @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "fuse_split_pass.h" + +#include "ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_reshape_split_op.h" +#include "ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_op.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace i8i8 { + +OptPassStatus FuseSplitPass(const OptKernelOptions& options) +{ + OptPassStatus status = {ppl::common::RC_SUCCESS, false}; + + for (auto it = options.graph->topo->CreateNodeIter(); it->IsValid(); it->Forward()) { + auto split_node = it->Get(); + if (split_node->GetType().domain == "" && split_node->GetType().name == "Split") { + + auto topo = options.graph->topo.get(); + auto kernels = &options.partition_info->kernels; + auto data = options.graph->data.get(); + + auto input_edge = topo->GetEdge(split_node->GetInput(0)); + auto split_edge = topo->GetEdge(split_node->GetInput(1)); + auto reshape_node = topo->GetNode(input_edge->GetProducer()); + + // is this a static split? + std::vector constant_split_data; + { + auto split_data_it = data->constants.find(split_edge->GetId()); + const int64_t* split_data = nullptr; + if (split_data_it != data->constants.end()) { + split_data = (const int64_t*)split_data_it->second.data.GetData(); + } + + if (split_data != nullptr) { + auto split_shape_it = data->shapes.find(split_edge->GetId()); + if (split_shape_it != data->shapes.end()) { + auto& split_shape = split_shape_it->second; + constant_split_data.assign(split_data, split_data + split_shape.dims[0]); + } + } + } + if (constant_split_data.size() == 0) + continue; + + // find it's param + ppl::nn::onnx::SplitParam* split_param = nullptr; + { + auto param_ref = data->attrs.find(split_node->GetId()); + if (param_ref != data->attrs.end()) { + split_param = (ppl::nn::onnx::SplitParam*)param_ref->second.get(); + } + } + if (split_param == nullptr) + continue; + + if (!reshape_node) + continue; + auto rs_input_edge = topo->GetEdge(reshape_node->GetInput(0)); + auto shape_edge = topo->GetEdge(reshape_node->GetInput(1)); + auto dq_node = topo->GetNode(rs_input_edge->GetProducer()); + + if (input_edge->CalcConsumerCount() == 1 + && reshape_node->GetType().domain == "" + && reshape_node->GetType().name == "Reshape" + && dq_node + && dq_node->GetType().domain == "opmx.i8i8" + && dq_node->GetType().name == "OnlineDequantize") + { + // is this a static reshape? 
+ std::vector<int64_t> constant_shape_data; + { + auto shape_data_it = data->constants.find(shape_edge->GetId()); + const int64_t* shape_data = nullptr; + if (shape_data_it != data->constants.end()) { + shape_data = (const int64_t*)shape_data_it->second.data.GetData(); + } + + if (shape_data != nullptr) { + auto shape_shape_it = data->shapes.find(shape_edge->GetId()); + if (shape_shape_it != data->shapes.end()) { + auto& shape_shape = shape_shape_it->second; + constant_shape_data.assign(shape_data, shape_data + shape_shape.dims[0]); + } + } + } + if (constant_shape_data.size() == 0) + continue; + + // We only accept a reshape target shape like (0, 0, ..., 0, A, B, C, D, E, ...), + // where the split axis points at A. + // A must be -1 or positive, and B, C, D, E, ... must all be positive. + int64_t prefix_dim = 0; + int64_t suffix_dim = 1; + int64_t dim_count = (int64_t)constant_shape_data.size(); + int64_t axis = split_param->axis < 0 ? split_param->axis + dim_count : split_param->axis; + if (constant_shape_data[axis] == 0) + continue; + for (int64_t i = 0; i < axis; ++i) + prefix_dim += constant_shape_data[i]; + for (int32_t i = axis + 1; i < dim_count; ++i) + suffix_dim *= constant_shape_data[i]; + if (prefix_dim != 0 || suffix_dim <= 0) + continue; + + // now we can apply the optimization + status.graph_modified = status.graph_modified || true; + + auto dq_kernel = (opmx::I8I8OnlineDequantizeOp*)options.partition_info->kernels[dq_node->GetId()].get(); + auto dq_input_edge = topo->GetEdge(dq_node->GetInput(0)); + auto dq_scale_outer_edge = topo->GetEdge(dq_node->GetInput(1)); + auto dq_scale_inner_edge = topo->GetEdge(dq_node->GetInput(2)); + auto dq_bias_edge = dq_kernel->GetParam()->bias_term ? topo->GetEdge(dq_node->GetInput(3)) : nullptr; + + // form: dq_input_edge -> dq_node -> rs_input_edge -> reshape_node -> input_edge -> split_node + // dq_scale_outer_edge -| shape_edge -/ split_edge -/ + // dq_scale_inner_edge -| + // dq_bias_edge -/ + // to : dq_input_edge -> split_node + // dq_scale_outer_edge -| + // dq_scale_inner_edge -| + // dq_bias_edge -/ + split_node->ReplaceInput(input_edge->GetId(), dq_input_edge->GetId()); + split_node->ReplaceInput(split_edge->GetId(), dq_scale_outer_edge->GetId()); + split_node->AddInput(dq_scale_inner_edge->GetId()); + if (dq_bias_edge) + split_node->AddInput(dq_bias_edge->GetId()); + + dq_input_edge->DelConsumer(dq_node->GetId()); + dq_input_edge->AddConsumer(split_node->GetId()); + + dq_scale_outer_edge->DelConsumer(dq_node->GetId()); + dq_scale_outer_edge->AddConsumer(split_node->GetId()); + + dq_scale_inner_edge->DelConsumer(dq_node->GetId()); + dq_scale_inner_edge->AddConsumer(split_node->GetId()); + + shape_edge->DelConsumer(reshape_node->GetId()); // do not delete this edge + + split_edge->DelConsumer(split_node->GetId()); // do not delete this edge + + if (dq_bias_edge) { + dq_bias_edge->DelConsumer(dq_node->GetId()); + dq_bias_edge->AddConsumer(split_node->GetId()); + } + + kernels->erase(dq_node->GetId()); + kernels->erase(reshape_node->GetId()); + topo->DelEdge(input_edge->GetId()); + topo->DelNode(dq_node->GetId()); + topo->DelEdge(rs_input_edge->GetId()); + topo->DelNode(reshape_node->GetId()); + + split_node->SetType({"opmx.i8i8", "OnlineDequantizeReshapeSplit", 1}); + auto dq_split_kernel = new opmx::I8I8OnlineDequantizeReshapeSplitOp(split_node); + // set params + dq_split_kernel->GetParam()->bias_term = dq_bias_edge != nullptr; + dq_split_kernel->GetParam()->shape.assign(constant_shape_data.begin(), constant_shape_data.end()); + dq_split_kernel->GetParam()->split.assign(constant_split_data.begin(),
constant_split_data.end()); + + status.retcode = dq_split_kernel->Init(options); + if (ppl::common::RC_SUCCESS != status.retcode) { + LOG(ERROR) << "init kernel[" << dq_split_kernel->GetNode()->GetName() + << " failed: " << ppl::common::GetRetCodeStr(status.retcode); + return status; + } + options.partition_info->kernels[split_node->GetId()].reset(dq_split_kernel); + LOG(DEBUG) << "process fuse for Split[" << split_node->GetName() << "] success"; + } + } + } + + return status; +} + +}}}}} diff --git a/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_split_pass.h b/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_split_pass.h new file mode 100644 index 000000000..601ed169c --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_split_pass.h @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_PASSES_I8I8_FUSE_SPLIT_PASS_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_PASSES_I8I8_FUSE_SPLIT_PASS_H_ + +#include "ppl/nn/engines/llm_cuda/opt_pass.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace i8i8 { + +OptPassStatus FuseSplitPass(const OptKernelOptions& options); + +}}}}} + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_swiglu_pass.cc b/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_swiglu_pass.cc new file mode 100644 index 000000000..c15822f58 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_swiglu_pass.cc @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
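+ +// FuseSwiGLUPass folds the i8i8 pattern OnlineDequantize -> SwiGLU -> OnlineQuantize
+// into a single opmx.i8i8.OnlineDequantizeSwiGLUQuantize node, mirroring the SiLU
+// fusion above. The dequantize producer must not carry a bias term and the SwiGLU
+// output may only feed the quantize node.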
+ +#include "fuse_swiglu_pass.h" + +#include "ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_swiglu_quantize_op.h" +#include "ppl/nn/engines/llm_cuda/ops/opmx/i8i8/online_dequantize_op.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace i8i8 { + +OptPassStatus FuseSwiGLUPass(const OptKernelOptions& options) +{ + OptPassStatus status = {ppl::common::RC_SUCCESS, false}; + + for (auto it = options.graph->topo->CreateNodeIter(); it->IsValid(); it->Forward()) { + auto swiglu_node = it->Get(); + if (swiglu_node->GetType().domain == "opmx" && swiglu_node->GetType().name == "SwiGLU") { + + auto topo = options.graph->topo.get(); + auto kernels = &options.partition_info->kernels; + + auto output_edge = topo->GetEdge(swiglu_node->GetOutput(0)); + auto input_edge = topo->GetEdge(swiglu_node->GetInput(0)); + + auto q_node = topo->GetNode(output_edge->CreateConsumerIter().Get()); + auto dq_node = topo->GetNode(input_edge->GetProducer()); + + if (output_edge->CalcConsumerCount() == 1 + && q_node + && q_node->GetType().domain == "opmx.i8i8" + && q_node->GetType().name == "OnlineQuantize" + && input_edge->CalcConsumerCount() == 1 + && dq_node + && dq_node->GetType().domain == "opmx.i8i8" + && dq_node->GetType().name == "OnlineDequantize" + ) + { + auto dq_kernel = (opmx::I8I8OnlineDequantizeOp*)options.partition_info->kernels[dq_node->GetId()].get(); + if (dq_kernel->GetParam()->bias_term) + continue; + + status.graph_modified = status.graph_modified || true; + + auto q_output_edge = topo->GetEdge(q_node->GetOutput(0)); + auto scale_edge = topo->GetEdge(q_node->GetOutput(1)); + + auto dq_input_edge = topo->GetEdge(dq_node->GetInput(0)); + auto dq_scale_outer_edge = topo->GetEdge(dq_node->GetInput(1)); + auto dq_scale_inner_edge = topo->GetEdge(dq_node->GetInput(2)); + + // form: (dq_input_edge, scale_outer, scale_inner) -> dq_node -> input_edge -> swiglu_node -> output_edge -> q_node -> (q_output_edge, scale_edge) + // to : (dq_input_edge, scale_outer, scale_inner) -> swiglu_node -> (q_output_edge, scale_edge) + swiglu_node->ReplaceOutput(output_edge->GetId(), q_output_edge->GetId()); + swiglu_node->AddOutput(scale_edge->GetId()); + + swiglu_node->ClearInputs(); + swiglu_node->AddInput(dq_input_edge->GetId()); + swiglu_node->AddInput(dq_scale_outer_edge->GetId()); + swiglu_node->AddInput(dq_scale_inner_edge->GetId()); + + dq_input_edge->DelConsumer(dq_node->GetId()); + dq_input_edge->AddConsumer(swiglu_node->GetId()); + + dq_scale_outer_edge->DelConsumer(dq_node->GetId()); + dq_scale_outer_edge->AddConsumer(swiglu_node->GetId()); + + dq_scale_inner_edge->DelConsumer(dq_node->GetId()); + dq_scale_inner_edge->AddConsumer(swiglu_node->GetId()); + + q_output_edge->SetProducer(swiglu_node->GetId()); + scale_edge->SetProducer(swiglu_node->GetId()); + + kernels->erase(q_node->GetId()); + topo->DelEdge(output_edge->GetId()); + topo->DelNode(q_node->GetId()); + + kernels->erase(dq_node->GetId()); + topo->DelEdge(input_edge->GetId()); + topo->DelNode(dq_node->GetId()); + + swiglu_node->SetType({"opmx.i8i8", "OnlineDequantizeSwiGLUQuantize", 1}); + auto dq_swiglu_q_kernel = new opmx::I8I8OnlineDequantizeSwiGLUQuantizeOp(swiglu_node); + status.retcode = dq_swiglu_q_kernel->Init(options); + if (ppl::common::RC_SUCCESS != status.retcode) { + LOG(ERROR) << "init kernel[" << dq_swiglu_q_kernel->GetNode()->GetName() + << " failed: " << ppl::common::GetRetCodeStr(status.retcode); + return status; + } + options.partition_info->kernels[swiglu_node->GetId()].reset(dq_swiglu_q_kernel); + 
LOG(DEBUG) << "process fuse for SwiGLU[" << swiglu_node->GetName() << "] success"; + } + } + } + + return status; +} + +}}}}} diff --git a/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_swiglu_pass.h b/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_swiglu_pass.h new file mode 100644 index 000000000..6df17bd21 --- /dev/null +++ b/src/ppl/nn/engines/llm_cuda/passes/i8i8/fuse_swiglu_pass.h @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_PASSES_I8I8_FUSE_SWIGLU_PASS_H_ +#define _ST_HPC_PPL_NN_ENGINES_LLM_CUDA_PASSES_I8I8_FUSE_SWIGLU_PASS_H_ + +#include "ppl/nn/engines/llm_cuda/opt_pass.h" + +namespace ppl { namespace nn { namespace llm { namespace cuda { namespace i8i8 { + +OptPassStatus FuseSwiGLUPass(const OptKernelOptions& options); + +}}}}} + +#endif diff --git a/src/ppl/nn/engines/llm_cuda/pmx/generated/llm_cuda_op_i8i8_params_generated.h b/src/ppl/nn/engines/llm_cuda/pmx/generated/llm_cuda_op_i8i8_params_generated.h index bb698121f..811863eb3 100644 --- a/src/ppl/nn/engines/llm_cuda/pmx/generated/llm_cuda_op_i8i8_params_generated.h +++ b/src/ppl/nn/engines/llm_cuda/pmx/generated/llm_cuda_op_i8i8_params_generated.h @@ -23,35 +23,41 @@ namespace i8i8 { struct OnlineDequantizeParam; struct OnlineDequantizeParamBuilder; +struct OnlineDequantizeReshapeSplitParam; +struct OnlineDequantizeReshapeSplitParamBuilder; + struct OpParam; struct OpParamBuilder; enum OpParamType : uint8_t { OpParamType_NONE = 0, OpParamType_OnlineDequantizeParam = 1, + OpParamType_OnlineDequantizeReshapeSplitParam = 2, OpParamType_MIN = OpParamType_NONE, - OpParamType_MAX = OpParamType_OnlineDequantizeParam + OpParamType_MAX = OpParamType_OnlineDequantizeReshapeSplitParam }; -inline const OpParamType (&EnumValuesOpParamType())[2] { +inline const OpParamType (&EnumValuesOpParamType())[3] { static const OpParamType values[] = { OpParamType_NONE, - OpParamType_OnlineDequantizeParam + OpParamType_OnlineDequantizeParam, + OpParamType_OnlineDequantizeReshapeSplitParam }; return values; } inline const char * const *EnumNamesOpParamType() { - static const char * const names[3] = { + static const char * const names[4] = { "NONE", "OnlineDequantizeParam", + "OnlineDequantizeReshapeSplitParam", nullptr }; return names; } inline const char *EnumNameOpParamType(OpParamType e) { - if (flatbuffers::IsOutRange(e, OpParamType_NONE, OpParamType_OnlineDequantizeParam)) return ""; + if (flatbuffers::IsOutRange(e, OpParamType_NONE, OpParamType_OnlineDequantizeReshapeSplitParam)) return ""; const size_t index = static_cast(e); return EnumNamesOpParamType()[index]; } @@ -64,6 +70,10 @@ template<> struct OpParamTypeTraits struct OpParamTypeTraits { + static const OpParamType enum_value = OpParamType_OnlineDequantizeReshapeSplitParam; +}; + 
bool VerifyOpParamType(flatbuffers::Verifier &verifier, const void *obj, OpParamType type); bool VerifyOpParamTypeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); @@ -108,6 +118,109 @@ inline flatbuffers::Offset CreateOnlineDequantizeParam( return builder_.Finish(); } +struct OnlineDequantizeReshapeSplitParam FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OnlineDequantizeReshapeSplitParamBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_AXIS = 4, + VT_SPLIT_POINT = 6, + VT_SPLIT = 8, + VT_SHAPE = 10, + VT_BIAS_TERM = 12 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + const flatbuffers::Vector *split_point() const { + return GetPointer *>(VT_SPLIT_POINT); + } + const flatbuffers::Vector *split() const { + return GetPointer *>(VT_SPLIT); + } + const flatbuffers::Vector *shape() const { + return GetPointer *>(VT_SHAPE); + } + bool bias_term() const { + return GetField(VT_BIAS_TERM, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_AXIS, 4) && + VerifyOffset(verifier, VT_SPLIT_POINT) && + verifier.VerifyVector(split_point()) && + VerifyOffset(verifier, VT_SPLIT) && + verifier.VerifyVector(split()) && + VerifyOffset(verifier, VT_SHAPE) && + verifier.VerifyVector(shape()) && + VerifyField(verifier, VT_BIAS_TERM, 1) && + verifier.EndTable(); + } +}; + +struct OnlineDequantizeReshapeSplitParamBuilder { + typedef OnlineDequantizeReshapeSplitParam Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement(OnlineDequantizeReshapeSplitParam::VT_AXIS, axis, 0); + } + void add_split_point(flatbuffers::Offset> split_point) { + fbb_.AddOffset(OnlineDequantizeReshapeSplitParam::VT_SPLIT_POINT, split_point); + } + void add_split(flatbuffers::Offset> split) { + fbb_.AddOffset(OnlineDequantizeReshapeSplitParam::VT_SPLIT, split); + } + void add_shape(flatbuffers::Offset> shape) { + fbb_.AddOffset(OnlineDequantizeReshapeSplitParam::VT_SHAPE, shape); + } + void add_bias_term(bool bias_term) { + fbb_.AddElement(OnlineDequantizeReshapeSplitParam::VT_BIAS_TERM, static_cast(bias_term), 0); + } + explicit OnlineDequantizeReshapeSplitParamBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOnlineDequantizeReshapeSplitParam( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0, + flatbuffers::Offset> split_point = 0, + flatbuffers::Offset> split = 0, + flatbuffers::Offset> shape = 0, + bool bias_term = false) { + OnlineDequantizeReshapeSplitParamBuilder builder_(_fbb); + builder_.add_shape(shape); + builder_.add_split(split); + builder_.add_split_point(split_point); + builder_.add_axis(axis); + builder_.add_bias_term(bias_term); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateOnlineDequantizeReshapeSplitParamDirect( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0, + const std::vector *split_point = nullptr, + const std::vector *split = nullptr, + const std::vector *shape = nullptr, + bool bias_term = false) { + auto split_point__ = split_point ? _fbb.CreateVector(*split_point) : 0; + auto split__ = split ? _fbb.CreateVector(*split) : 0; + auto shape__ = shape ? 
_fbb.CreateVector(*shape) : 0; + return ppl::nn::llm::cuda::opmx::i8i8::CreateOnlineDequantizeReshapeSplitParam( + _fbb, + axis, + split_point__, + split__, + shape__, + bias_term); +} + struct OpParam FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef OpParamBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { @@ -124,6 +237,9 @@ struct OpParam FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const ppl::nn::llm::cuda::opmx::i8i8::OnlineDequantizeParam *value_as_OnlineDequantizeParam() const { return value_type() == ppl::nn::llm::cuda::opmx::i8i8::OpParamType_OnlineDequantizeParam ? static_cast(value()) : nullptr; } + const ppl::nn::llm::cuda::opmx::i8i8::OnlineDequantizeReshapeSplitParam *value_as_OnlineDequantizeReshapeSplitParam() const { + return value_type() == ppl::nn::llm::cuda::opmx::i8i8::OpParamType_OnlineDequantizeReshapeSplitParam ? static_cast(value()) : nullptr; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_VALUE_TYPE, 1) && @@ -137,6 +253,10 @@ template<> inline const ppl::nn::llm::cuda::opmx::i8i8::OnlineDequantizeParam *O return value_as_OnlineDequantizeParam(); } +template<> inline const ppl::nn::llm::cuda::opmx::i8i8::OnlineDequantizeReshapeSplitParam *OpParam::value_as() const { + return value_as_OnlineDequantizeReshapeSplitParam(); +} + struct OpParamBuilder { typedef OpParam Table; flatbuffers::FlatBufferBuilder &fbb_; @@ -177,6 +297,10 @@ inline bool VerifyOpParamType(flatbuffers::Verifier &verifier, const void *obj, auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case OpParamType_OnlineDequantizeReshapeSplitParam: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } diff --git a/src/ppl/nn/engines/llm_cuda/pmx/schema/llm_cuda_op_i8i8_params.fbs b/src/ppl/nn/engines/llm_cuda/pmx/schema/llm_cuda_op_i8i8_params.fbs index 0eab0c098..9c4d64657 100644 --- a/src/ppl/nn/engines/llm_cuda/pmx/schema/llm_cuda_op_i8i8_params.fbs +++ b/src/ppl/nn/engines/llm_cuda/pmx/schema/llm_cuda_op_i8i8_params.fbs @@ -4,8 +4,17 @@ table OnlineDequantizeParam { bias_term: bool; } +table OnlineDequantizeReshapeSplitParam { + axis: int32; + split_point: [int32]; + split: [int64]; + shape: [int64]; + bias_term: bool; +} + union OpParamType { OnlineDequantizeParam, + OnlineDequantizeReshapeSplitParam, } table OpParam { diff --git a/src/ppl/nn/engines/llm_cuda/register_ops.cc b/src/ppl/nn/engines/llm_cuda/register_ops.cc index b8710d7d8..ef64ab1fe 100644 --- a/src/ppl/nn/engines/llm_cuda/register_ops.cc +++ b/src/ppl/nn/engines/llm_cuda/register_ops.cc @@ -40,12 +40,14 @@ #include "ppl/nn/engines/llm_cuda/ops/opmx/moe_select_op.h" #include "ppl/nn/engines/llm_cuda/ops/opmx/multi_head_attention_op.h" #include "ppl/nn/engines/llm_cuda/ops/opmx/parallel_embedding_op.h" +#include "ppl/nn/engines/llm_cuda/ops/opmx/pixel_unshuffle_op.h" #include "ppl/nn/engines/llm_cuda/ops/opmx/rms_norm_op.h" #include "ppl/nn/engines/llm_cuda/ops/opmx/rotary_position_embedding_op.h" #include "ppl/nn/engines/llm_cuda/ops/opmx/row_parallel_linear_op.h" #include "ppl/nn/engines/llm_cuda/ops/opmx/silu_op.h" #include "ppl/nn/engines/llm_cuda/ops/opmx/swiglu_op.h" #include "ppl/nn/engines/llm_cuda/ops/opmx/swish_op.h" +#include "ppl/nn/engines/llm_cuda/ops/opmx/tensor_parallel_rms_norm_op.h" #include "ppl/nn/engines/llm_cuda/ops/opmx/vision_embedding_op.h" #include 
"ppl/nn/engines/llm_cuda/ops/opmx/dynamic_batching/key_value_cache_op.h" @@ -162,6 +164,7 @@ void RegisterBuiltinOpImpls() { // O // P RegisterOptKernelCreator("opmx", "ParallelEmbedding", 1, 1); + RegisterOptKernelCreator("opmx", "PixelUnshuffle", 1, 1); // Q // R RegisterOptKernelCreator("opmx", "RMSNorm", 1, 1); @@ -172,6 +175,7 @@ void RegisterBuiltinOpImpls() { RegisterOptKernelCreator("opmx", "SwiGLU", 1, 1); RegisterOptKernelCreator("opmx", "Swish", 1, 1); // T + RegisterOptKernelCreator("opmx", "TensorParallelRMSNorm", 1, 1); // U // V RegisterOptKernelCreator("opmx", "VisionEmbedding", 1, 1); diff --git a/src/ppl/nn/models/onnx/param_parser_manager.cc b/src/ppl/nn/models/onnx/param_parser_manager.cc index 530b2956b..3f001f796 100644 --- a/src/ppl/nn/models/onnx/param_parser_manager.cc +++ b/src/ppl/nn/models/onnx/param_parser_manager.cc @@ -88,9 +88,11 @@ #include "ppl/nn/models/onnx/parsers/opmx/parse_multi_head_attention_param.h" #include "ppl/nn/models/onnx/parsers/opmx/parse_multi_head_cache_attention_param.h" #include "ppl/nn/models/onnx/parsers/opmx/parse_parallel_embedding_param.h" +#include "ppl/nn/models/onnx/parsers/opmx/parse_pixel_unshuffle_param.h" #include "ppl/nn/models/onnx/parsers/opmx/parse_reshape_param.h" #include "ppl/nn/models/onnx/parsers/opmx/parse_rotary_position_embedding_param.h" #include "ppl/nn/models/onnx/parsers/opmx/parse_row_parallel_linear_param.h" +#include "ppl/nn/models/onnx/parsers/opmx/parse_tensor_parallel_rms_norm_param.h" #include "ppl/nn/models/onnx/parsers/opmx/parse_vision_embedding_param.h" using namespace std; @@ -319,6 +321,8 @@ ParamParserManager::ParamParserManager() { // P PPL_REGISTER_OP_WITH_PARAM("opmx", "ParallelEmbedding", 1, 1, ppl::nn::opmx::ParallelEmbeddingParam, ppl::nn::opmx::ParseParallelEmbeddingParam, nullptr); + PPL_REGISTER_OP_WITH_PARAM("opmx", "PixelUnshuffle", 1, 1, ppl::nn::opmx::PixelUnshuffleParam, + ppl::nn::opmx::ParsePixelUnshuffleParam, nullptr); // R PPL_REGISTER_OP_WITH_PARAM("opmx", "Reshape", 1, 1, ppl::nn::onnx::ReshapeParam, ppl::nn::opmx::ParseReshapeParam, nullptr); @@ -335,6 +339,10 @@ ParamParserManager::ParamParserManager() { PPL_REGISTER_OP_WITH_PARAM("opmx", "Swish", 1, 1, ppl::nn::pmx::SwishParam, ppl::nn::pmx::ParseSwishParam, nullptr); + // T + PPL_REGISTER_OP_WITH_PARAM("opmx", "TensorParallelRMSNorm", 1, 1, ppl::nn::opmx::TensorParallelRMSNormParam, ppl::nn::opmx::ParseTensorParallelRMSNormParam, + nullptr); + // V PPL_REGISTER_OP_WITH_PARAM("opmx", "VisionEmbedding", 1, 1, ppl::nn::opmx::VisionEmbeddingParam, ppl::nn::opmx::ParseVisionEmbeddingParam, nullptr); diff --git a/src/ppl/nn/models/onnx/parsers/opmx/parse_pixel_unshuffle_param.cc b/src/ppl/nn/models/onnx/parsers/opmx/parse_pixel_unshuffle_param.cc new file mode 100644 index 000000000..adbbd6e5c --- /dev/null +++ b/src/ppl/nn/models/onnx/parsers/opmx/parse_pixel_unshuffle_param.cc @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "ppl/nn/models/onnx/parsers/opmx/parse_pixel_unshuffle_param.h" +#include "ppl/nn/common/logger.h" +#include "ppl/nn/models/onnx/utils.h" +using namespace std; +using namespace ppl::common; +using namespace ppl::nn::opmx; + +namespace ppl { namespace nn { namespace opmx { + +RetCode ParsePixelUnshuffleParam(const ::onnx::NodeProto& pb_node, const onnx::ParamParserExtraArgs& args, ir::Node* node, + ir::Attr* arg) { + auto param = static_cast<PixelUnshuffleParam*>(arg); + + if (!onnx::utils::GetNodeAttr(pb_node, "scale_factor", &param->scale_factor, -1)) { + LOG(ERROR) << node->GetName() << ": missing scale_factor"; + return RC_INVALID_VALUE; + } + + string data_layout; + if (!onnx::utils::GetNodeAttr(pb_node, "data_layout", &data_layout, "")) { + LOG(ERROR) << node->GetName() << ": missing data_layout"; + return RC_INVALID_VALUE; + } + + if (data_layout == "nhwc") { + param->data_layout = ppl::nn::opmx::PixelUnshuffleParam::DATA_LAYOUT_NHWC; + } else { + LOG(ERROR) << "unsupported data_layout: " << data_layout; + return RC_UNSUPPORTED; + } + + return RC_SUCCESS; +} + +}}} // namespace ppl::nn::opmx diff --git a/src/ppl/nn/models/onnx/parsers/opmx/parse_pixel_unshuffle_param.h b/src/ppl/nn/models/onnx/parsers/opmx/parse_pixel_unshuffle_param.h new file mode 100644 index 000000000..042bb5c20 --- /dev/null +++ b/src/ppl/nn/models/onnx/parsers/opmx/parse_pixel_unshuffle_param.h @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
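+ +// Declares the ONNX attribute parser for the opmx PixelUnshuffle operator. The parser
+// (see parse_pixel_unshuffle_param.cc above) requires the "scale_factor" and
+// "data_layout" attributes and currently accepts only the "nhwc" layout.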
+ +#ifndef _ST_HPC_PPL_NN_MODELS_ONNX_PARSERS_OPMX_PARSE_PIXEL_UNSHUFFLE_PARAM_H_ +#define _ST_HPC_PPL_NN_MODELS_ONNX_PARSERS_OPMX_PARSE_PIXEL_UNSHUFFLE_PARAM_H_ + +#include "ppl/common/retcode.h" +#include "ppl/nn/params/opmx/pixel_unshuffle_param.h" +#include "ppl/nn/models/onnx/param_parser_extra_args.h" +#include "onnx.pb.h" + +namespace ppl { namespace nn { namespace opmx { + +ppl::common::RetCode ParsePixelUnshuffleParam(const ::onnx::NodeProto&, const onnx::ParamParserExtraArgs&, ir::Node*, + ir::Attr*); + +}}} // namespace ppl::nn::opmx + +#endif diff --git a/src/ppl/nn/models/onnx/parsers/opmx/parse_tensor_parallel_rms_norm_param.cc b/src/ppl/nn/models/onnx/parsers/opmx/parse_tensor_parallel_rms_norm_param.cc new file mode 100644 index 000000000..88619e1ba --- /dev/null +++ b/src/ppl/nn/models/onnx/parsers/opmx/parse_tensor_parallel_rms_norm_param.cc @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "ppl/nn/models/onnx/parsers/opmx/parse_tensor_parallel_rms_norm_param.h" +#include "ppl/nn/common/logger.h" +#include "ppl/nn/models/onnx/utils.h" +using namespace std; +using namespace ppl::common; +using namespace ppl::nn::opmx; + +namespace ppl { namespace nn { namespace opmx { + +RetCode ParseTensorParallelRMSNormParam(const ::onnx::NodeProto& pb_node, const onnx::ParamParserExtraArgs& args, ir::Node* node, + ir::Attr* arg) { + auto param = static_cast<TensorParallelRMSNormParam*>(arg); + onnx::utils::GetNodeAttr(pb_node, "scale", &param->scale, 1.0); + onnx::utils::GetNodeAttr(pb_node, "axis", &param->axis, -1); + onnx::utils::GetNodeAttr(pb_node, "eps", &param->eps, 1e-5); + return RC_SUCCESS; +} + +}}} // namespace ppl::nn::opmx diff --git a/src/ppl/nn/models/onnx/parsers/opmx/parse_tensor_parallel_rms_norm_param.h b/src/ppl/nn/models/onnx/parsers/opmx/parse_tensor_parallel_rms_norm_param.h new file mode 100644 index 000000000..d427058fc --- /dev/null +++ b/src/ppl/nn/models/onnx/parsers/opmx/parse_tensor_parallel_rms_norm_param.h @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
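+ +// Declares the ONNX attribute parser for the opmx TensorParallelRMSNorm operator. The
+// parser (see parse_tensor_parallel_rms_norm_param.cc above) reads the optional "scale",
+// "axis" and "eps" attributes, defaulting to 1.0, -1 and 1e-5 respectively.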
+ +#ifndef _ST_HPC_PPL_NN_MODELS_ONNX_PARSERS_OPMX_PARSE_TENSOR_PARALLEL_RMS_NORM_PARAM_H_ +#define _ST_HPC_PPL_NN_MODELS_ONNX_PARSERS_OPMX_PARSE_TENSOR_PARALLEL_RMS_NORM_PARAM_H_ + +#include "ppl/common/retcode.h" +#include "ppl/nn/params/opmx/tensor_parallel_rms_norm_param.h" +#include "ppl/nn/models/onnx/param_parser_extra_args.h" +#include "onnx.pb.h" + +namespace ppl { namespace nn { namespace opmx { + +ppl::common::RetCode ParseTensorParallelRMSNormParam(const ::onnx::NodeProto&, const onnx::ParamParserExtraArgs&, ir::Node*, + ir::Attr*); + +}}} // namespace ppl::nn::opmx + +#endif diff --git a/src/ppl/nn/oputils/opmx/reshape_pixel_unshuffle.cc b/src/ppl/nn/oputils/opmx/reshape_pixel_unshuffle.cc new file mode 100644 index 000000000..e5993301f --- /dev/null +++ b/src/ppl/nn/oputils/opmx/reshape_pixel_unshuffle.cc @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "reshape_pixel_unshuffle.h" + +#include "ppl/nn/runtime/tensor_impl.h" +#include "ppl/nn/common/logger.h" + +using namespace ppl::common; + +namespace ppl { namespace nn { namespace opmx { + +ppl::common::RetCode ReshapePixelUnshuffle(InputOutputInfo* info, const ir::Attr* arg) { + auto param = static_cast(arg); + + const TensorShape& in_shape = *info->GetInput(0)->GetShape(); + const int32_t out_dim_count = in_shape.GetDimCount(); + + if (in_shape.GetDim(in_shape.GetDimCount() - 3) % param->scale_factor) { + LOG(ERROR) << info->GetNode()->GetName() << " W dim[" << in_shape.GetDim(in_shape.GetDimCount() - 3) + << "] is not divisible by scale factor[" << param->scale_factor << "]"; + return RC_INVALID_VALUE; + } + + if (in_shape.GetDim(in_shape.GetDimCount() - 2) % param->scale_factor) { + LOG(ERROR) << info->GetNode()->GetName() << " H dim[" << in_shape.GetDim(in_shape.GetDimCount() - 2) + << "] is not divisible by scale factor[" << param->scale_factor << "]"; + return RC_INVALID_VALUE; + } + + std::vector out_dims(out_dim_count); + for (int32_t i = 0; i < out_dim_count - 3; i++) { + out_dims[i] = in_shape.GetDim(i); + } + out_dims[out_dim_count - 3] = in_shape.GetDim(out_dim_count - 3) / param->scale_factor; + out_dims[out_dim_count - 2] = in_shape.GetDim(out_dim_count - 2) / param->scale_factor; + out_dims[out_dim_count - 1] = in_shape.GetDim(out_dim_count - 1) * param->scale_factor * param->scale_factor; + + info->GetOutput(0)->GetShape()->Reshape(out_dims); + + return RC_SUCCESS; +} + +}}} // namespace ppl::nn::opmx diff --git a/src/ppl/nn/oputils/opmx/reshape_pixel_unshuffle.h b/src/ppl/nn/oputils/opmx/reshape_pixel_unshuffle.h new file mode 100644 index 000000000..f524063d8 --- /dev/null +++ b/src/ppl/nn/oputils/opmx/reshape_pixel_unshuffle.h @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor 
license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef _ST_HPC_PPL_NN_OPUTILS_OPMX_PIXEL_UNSHUFFLE_H_ +#define _ST_HPC_PPL_NN_OPUTILS_OPMX_PIXEL_UNSHUFFLE_H_ + +#include "ppl/common/retcode.h" +#include "ppl/nn/params/opmx/pixel_unshuffle_param.h" +#include "ppl/nn/common/input_output_info.h" +#include "ppl/nn/ir/attr.h" + +namespace ppl { namespace nn { namespace opmx { + +ppl::common::RetCode ReshapePixelUnshuffle(InputOutputInfo*, const ir::Attr*); + +}}} // namespace ppl::nn::opmx + +#endif \ No newline at end of file diff --git a/src/ppl/nn/params/opmx/pixel_unshuffle_param.h b/src/ppl/nn/params/opmx/pixel_unshuffle_param.h new file mode 100644 index 000000000..52eb033f6 --- /dev/null +++ b/src/ppl/nn/params/opmx/pixel_unshuffle_param.h @@ -0,0 +1,27 @@ +#ifndef _ST_HPC_PPL_NN_PARAMS_OPMX_PIXEL_UNSHUFFLE_PARAM_H_ +#define _ST_HPC_PPL_NN_PARAMS_OPMX_PIXEL_UNSHUFFLE_PARAM_H_ + +#include "ppl/nn/ir/attr.h" +#include +#include + +namespace ppl { namespace nn { namespace opmx { + +struct PixelUnshuffleParam final : public ir::TypedAttr { + enum { + DATA_LAYOUT_NONE = 0, + DATA_LAYOUT_NHWC = 1, + }; + + int32_t scale_factor; + int32_t data_layout; + + bool operator==(const PixelUnshuffleParam& p) const { + return (scale_factor == p.scale_factor + && data_layout == p.data_layout); + } +}; + +}}} // namespace ppl::nn::opmx + +#endif diff --git a/src/ppl/nn/params/opmx/tensor_parallel_rms_norm_param.h b/src/ppl/nn/params/opmx/tensor_parallel_rms_norm_param.h new file mode 100644 index 000000000..1ebcc3bd7 --- /dev/null +++ b/src/ppl/nn/params/opmx/tensor_parallel_rms_norm_param.h @@ -0,0 +1,22 @@ +#ifndef _ST_HPC_PPL_NN_PARAMS_OPMX_TENSOR_PARALLEL_RMS_NORM_PARAM_H_ +#define _ST_HPC_PPL_NN_PARAMS_OPMX_TENSOR_PARALLEL_RMS_NORM_PARAM_H_ + +#include "ppl/nn/ir/attr.h" +#include +#include + +namespace ppl { namespace nn { namespace opmx { + +struct TensorParallelRMSNormParam final : public ir::TypedAttr { + int32_t axis; + float eps; + float scale; + + bool operator==(const TensorParallelRMSNormParam& p) const { + return (axis == p.axis && fabs(eps - p.eps) <= 1e-05 && fabs(scale - p.scale) <= 1e-05); + } +}; + +}}} // namespace ppl::nn::opmx + +#endif diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 0b4a56393..588474543 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -36,6 +36,15 @@ if(PPLNN_USE_LLM_CUDA) target_link_libraries(benchmark_llama pplnn_static) target_include_directories(benchmark_llama PRIVATE ${HPCC_DEPS_DIR}/rapidjson/include) + # keep it for compare + file(GLOB_RECURSE BENCHMARK_LLAMA_DEPRECATED_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/benchmark_llama_deprecated/*.cc) + add_executable(benchmark_llama_deprecated + ${BENCHMARK_LLAMA_DEPRECATED_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/simple_flags.cc) + target_link_libraries(benchmark_llama_deprecated pplnn_static) + 
target_include_directories(benchmark_llama_deprecated PRIVATE ${HPCC_DEPS_DIR}/rapidjson/include) + if(PPLNN_CUDA_ENABLE_NCCL) find_package(OpenMP REQUIRED) add_executable(test_nccl ${CMAKE_CURRENT_SOURCE_DIR}/test_nccl.cc) diff --git a/tools/benchmark_llama/app.cc b/tools/benchmark_llama/app.cc index 60480208d..f987f6c61 100644 --- a/tools/benchmark_llama/app.cc +++ b/tools/benchmark_llama/app.cc @@ -27,7 +27,7 @@ Define_uint32_opt("--benchmark-loops", g_flag_benchmark_loops, 4, "benchmark loo Define_string_opt("--quant-method", g_flag_quant_method, "none", "llm cuda quantization mehtod, only accept " - "\"none\", \"online_i8i8\" and \"online_i4f16\", " + "\"none\", \"online_i8i8\" , \"online_f8f8\" and \"online_i4f16\", " "default: \"none\""); Define_string_opt("--cublas-layout-hint", g_cublas_layout_hint, "default", "matrix layout hint for cublas(currently only effect int8 gemm), only accept " @@ -36,6 +36,26 @@ Define_string_opt("--cublas-layout-hint", g_cublas_layout_hint, "default", Define_bool_opt("--kernel-profiling", g_flag_kernel_profiling, true, "enable kernel profiling and print profiling info"); +Define_bool_opt("--enable-cache-prefill", g_flag_enable_cache_prefill, + false, "enable cache prefill flash attention"); + +Define_bool_opt("--disable-decoding-shm-mha", g_flag_disable_decoding_shm_mha, + false, "disable shared memory decoding attention algorithm"); +Define_bool_opt("--disable-decoding-inf-mha", g_flag_disable_decoding_inf_mha, + false, "disable infinity decoding attention algorithm"); +Define_bool_opt("--disable-decoding-inf-gqa", g_flag_disable_decoding_inf_gqa, + false, "disable infinity grouped query decoding attention algorithm"); +Define_int32_opt("--configure-decoding-attn-split-k", g_flag_configure_decoding_attn_split_k, 1, + "configuring split-k decoding attention algorithm, " + "accepted values: always-on(2)/heuristic(1)/off(0)," + "default is heuristic(1)"); +Define_int32_opt("--specify-decoding-attn-tpb", g_flag_specify_decoding_attn_tpb, 0, + "specify decoding attention kernel threads per block, " + "accepted values: 512/256/heuristic(0)," + "default is heuristic(0)"); + +Define_bool_opt("--disable-graph-fusion", g_flag_disable_graph_fusion, false, "disable graph kernel fusion rules"); + static bool WriteOutputToFile(const std::string& output_file, std::queue &responses) { std::ofstream fout(output_file, std::ios::out); @@ -103,7 +123,16 @@ int main(int argc, char* argv[]) { LOG(INFO) << "================== Init TextGenerator =================="; // TODO: Move new CudaTextGenerator into #ifdef USE_LLM_CUDA std::unique_ptr generator(new CudaTextGenerator( - {g_cublas_layout_hint} + { + g_cublas_layout_hint, + g_flag_disable_graph_fusion, + g_flag_enable_cache_prefill, + g_flag_disable_decoding_shm_mha, + g_flag_disable_decoding_inf_mha, + g_flag_disable_decoding_inf_gqa, + g_flag_configure_decoding_attn_split_k, + g_flag_specify_decoding_attn_tpb, + } )); auto rc = generator->InitModel( "decoder_only", diff --git a/tools/benchmark_llama/cuda/cuda_text_generator.cc b/tools/benchmark_llama/cuda/cuda_text_generator.cc index c90c17fd1..c55c2b968 100644 --- a/tools/benchmark_llama/cuda/cuda_text_generator.cc +++ b/tools/benchmark_llama/cuda/cuda_text_generator.cc @@ -124,6 +124,8 @@ ppl::nn::Engine* CudaTextGenerator::ThreadCreateEngine(const int32_t tid) { options.quant_method = ppl::nn::llm::cuda::QUANT_METHOD_NONE; } else if (model_config_.quant_method == "online_i8i8") { options.quant_method = ppl::nn::llm::cuda::QUANT_METHOD_ONLINE_I8I8; + } else if 
(model_config_.quant_method == "online_f8f8") { + options.quant_method = ppl::nn::llm::cuda::QUANT_METHOD_ONLINE_F8F8; } else if (model_config_.quant_method == "online_i4f16") { options.quant_method = ppl::nn::llm::cuda::QUANT_METHOD_ONLINE_I4F16; } else { @@ -131,12 +133,12 @@ ppl::nn::Engine* CudaTextGenerator::ThreadCreateEngine(const int32_t tid) { return nullptr; } - if (cublas_layout_hint_ == "default") { + if (construct_options_.cublas_layout_hint == "default") { options.cublas_layout_hint = ppl::nn::llm::cuda::CUBLAS_LAYOUT_DEFAULT; - } else if (cublas_layout_hint_ == "ampere") { + } else if (construct_options_.cublas_layout_hint == "ampere") { options.cublas_layout_hint = ppl::nn::llm::cuda::CUBLAS_LAYOUT_AMPERE; } else { - LOG(ERROR) << "unknown/unsupported --cublas-layout-hint option: " << cublas_layout_hint_; + LOG(ERROR) << "unknown/unsupported --cublas-layout-hint option: " << construct_options_.cublas_layout_hint; return nullptr; } @@ -146,10 +148,49 @@ ppl::nn::Engine* CudaTextGenerator::ThreadCreateEngine(const int32_t tid) { return nullptr; } + ppl::common::RetCode rc; + rc = engine->Configure(ppl::nn::llm::cuda::ENGINE_CONF_CACHE_PREFILL, construct_options_.enable_cache_prefill ? 1 : 0); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_CACHE_PREFILL failed: " << ppl::common::GetRetCodeStr(rc); + return nullptr; + } + + rc = engine->Configure(ppl::nn::llm::cuda::ENGINE_CONF_DECODING_SHM_MHA, construct_options_.disable_decoding_shm_mha ? 0 : 1); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_DECODING_SHM_MHA failed: " << ppl::common::GetRetCodeStr(rc); + return nullptr; + } + rc = engine->Configure(ppl::nn::llm::cuda::ENGINE_CONF_DECODING_INF_MHA, construct_options_.disable_decoding_inf_mha ? 0 : 1); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_DECODING_INF_MHA failed: " << ppl::common::GetRetCodeStr(rc); + return nullptr; + } + rc = engine->Configure(ppl::nn::llm::cuda::ENGINE_CONF_DECODING_INF_GQA, construct_options_.disable_decoding_inf_gqa ? 0 : 1); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_DECODING_INF_GQA failed: " << ppl::common::GetRetCodeStr(rc); + return nullptr; + } + rc = engine->Configure(ppl::nn::llm::cuda::ENGINE_CONF_DECODING_ATTN_SPLIT_K, construct_options_.configure_decoding_attn_split_k); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_DECODING_ATTN_SPLIT_K failed: " << ppl::common::GetRetCodeStr(rc); + return nullptr; + } + rc = engine->Configure(ppl::nn::llm::cuda::ENGINE_CONF_DECODING_ATTN_TPB, construct_options_.specify_decoding_attn_tpb); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_DECODING_ATTN_TPB failed: " << ppl::common::GetRetCodeStr(rc); + return nullptr; + } + + rc = engine->Configure(ppl::nn::llm::cuda::ENGINE_CONF_GRAPH_FUSION, construct_options_.disable_graph_fusion ? 
0 : 1); + if (ppl::common::RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_GRAPH_FUSION failed: " << ppl::common::GetRetCodeStr(rc); + return nullptr; + } + #ifdef PPLNN_CUDA_ENABLE_NCCL - auto rc = engine->Configure(ppl::nn::llm::cuda::ENGINE_CONF_SET_TP_NCCL_COMM, nccl_comm_list_[tid]); + rc = engine->Configure(ppl::nn::llm::cuda::ENGINE_CONF_SET_TP_NCCL_COMM, nccl_comm_list_[tid]); if (rc != ppl::common::RC_SUCCESS) { - LOG(ERROR) << "engine configure failed"; + LOG(ERROR) << "engine configure failed: " << ppl::common::GetRetCodeStr(rc); return nullptr; } #endif diff --git a/tools/benchmark_llama/cuda/cuda_text_generator.h b/tools/benchmark_llama/cuda/cuda_text_generator.h index 895045fac..16d9b7d59 100644 --- a/tools/benchmark_llama/cuda/cuda_text_generator.h +++ b/tools/benchmark_llama/cuda/cuda_text_generator.h @@ -12,11 +12,18 @@ typedef void* ncclComm_t; class CudaTextGenerator final : public TextGenerator { public: struct ConstructOptions { - std::string cublas_layout_hint; + std::string cublas_layout_hint = "default"; + bool disable_graph_fusion = false; + bool enable_cache_prefill = false; + bool disable_decoding_shm_mha = false; + bool disable_decoding_inf_mha = false; + bool disable_decoding_inf_gqa = false; + uint32_t configure_decoding_attn_split_k = 1; + uint32_t specify_decoding_attn_tpb = 0; }; CudaTextGenerator(const ConstructOptions& options) { - cublas_layout_hint_ = options.cublas_layout_hint; + construct_options_ = options; } virtual bool CheckParameters() override; @@ -42,9 +49,10 @@ class CudaTextGenerator final : public TextGenerator { private: std::vector nccl_comm_list_; std::vector> host_device_list_; - std::string cublas_layout_hint_; std::unique_ptr sampler_; + ConstructOptions construct_options_; + void InitCudaThread(); void FinalizeCudaThread(); diff --git a/tools/benchmark_llama/text_generator.cc b/tools/benchmark_llama/text_generator.cc index 59e407e51..9c10b3c8a 100644 --- a/tools/benchmark_llama/text_generator.cc +++ b/tools/benchmark_llama/text_generator.cc @@ -309,11 +309,15 @@ ppl::common::RetCode TextGenerator::Generate( do { RuntimePoolRun([&](uint32_t nthr, uint32_t tid) { auto arg = &runtime_thread_args_[tid]; + +#ifdef PPLNN_ENABLE_KERNEL_PROFILING if (tid == 0 && profiler && profiler->collect_statistics && (state_.current_step == 0 || state_.current_step == 1)) { auto rc = arg->runtime->Configure(ppl::nn::RUNTIME_CONF_SET_KERNEL_PROFILING_FLAG, true); if (rc != ppl::common::RC_SUCCESS) LOG(WARNING) << "enable kernel profiling failed: " << ppl::common::GetRetCodeStr(rc); } +#endif + bool ret = ThreadSetInputTensors(tid); if (!ret) { LOG(ERROR) << "SetInputTensor failed"; @@ -326,6 +330,7 @@ ppl::common::RetCode TextGenerator::Generate( return ppl::common::RC_OTHER_ERROR; } +#ifdef PPLNN_ENABLE_KERNEL_PROFILING if (tid == 0 && profiler && profiler->collect_statistics && (state_.current_step == 0 || state_.current_step == state_.max_steps - 1)) { if (state_.current_step == 0) { auto rc = arg->runtime->GetProfilingStatistics(&profiler->prefill_statistics); @@ -340,7 +345,8 @@ ppl::common::RetCode TextGenerator::Generate( if (rc != ppl::common::RC_SUCCESS) LOG(WARNING) << "enable profiling failed: " << ppl::common::GetRetCodeStr(rc); } - +#endif + auto logits = arg->logits; rc = ThreadSampleArgMax( tid, diff --git a/tools/benchmark_llama_deprecated/benchmark_llama.cc b/tools/benchmark_llama_deprecated/benchmark_llama.cc new file mode 100644 index 000000000..20cbf8410 --- /dev/null +++ b/tools/benchmark_llama_deprecated/benchmark_llama.cc 
@@ -0,0 +1,1263 @@ +#include "sampler.h" +#include "../simple_flags.h" + +#include "rapidjson/document.h" +#include "rapidjson/istreamwrapper.h" +#include "ppl/common/log.h" +#include "ppl/common/cuda/cuda_env.h" +#include "ppl/common/threadpool.h" + +#include "ppl/nn/runtime/options.h" +#include "ppl/nn/runtime/runtime.h" +#include "ppl/nn/engines/engine.h" +#include "ppl/nn/engines/llm_cuda/engine_factory.h" +#include "ppl/nn/models/onnx/runtime_builder_factory.h" + +#ifdef PPLNN_ENABLE_PMX_MODEL +#include "ppl/nn/models/pmx/runtime_builder_factory.h" +#include "ppl/nn/models/pmx/load_model_options.h" +#include "ppl/nn/models/pmx/save_model_options.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef PPLNN_CUDA_ENABLE_NCCL +#include +#else +typedef void* ncclComm_t; +#endif + +struct Profiler { + double prefill_latency = 0; + // std::vector decode_latency = 0; // size = gen len + std::vector step_latency; + double total_latency = 0; + double set_intput_latency = 0; + double mem_usage = 0; // GB + void Reset() { + this->prefill_latency = 0; + this->set_intput_latency = 0; + this->step_latency.assign(this->step_latency.size(), 0); + this->mem_usage = 0; + } +}; + +static Profiler profiling; + +class ThreadPool { +private: + ppl::common::StaticThreadPool pool_; + std::vector retcode_; + +public: + void Init(int nthr) { + pool_.Init(nthr); + retcode_.resize(nthr); + } + + void Run(const std::function& f) { + pool_.Run([&] (uint32_t nthr, uint32_t tid) { + retcode_[tid] = f(nthr, tid); + }); + for (auto ret : retcode_) { + if (ret != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "exit with thread error"; + exit(-1); + } + } + } +}; + +static ThreadPool gpu_thread_pool; + +Define_bool_opt("--help", g_flag_help, false, "show these help information"); +Define_string_opt("--model-type", g_flag_model_type, "", "model type"); +Define_string_opt("--model-dir", g_flag_model_dir, "", "model directory"); +Define_string_opt("--model-param-path", g_flag_model_param_path, "", "path of model params"); + +#ifdef PPLNN_ENABLE_PMX_MODEL +Define_bool_opt("--use-pmx", g_flag_use_pmx, false, "use pmx model"); +#endif + +Define_uint32_opt("--tensor-parallel-size", g_flag_tensor_parallel_size, 1, "tensor parallel size"); +Define_float_opt("--top-p", g_flag_top_p, 0.0, "top p"); +Define_uint32_opt("--top-k", g_flag_top_k, 1, "top k"); +Define_float_opt("--temperature", g_flag_temperature, 1.0, "temperature"); +Define_uint32_opt("--generation-len", g_flag_generation_len, 32, "generation length"); +Define_uint32_opt("--warmup-loops", g_flag_warmup_loops, 2, "warm loops"); +Define_uint32_opt("--benchmark-loops", g_flag_benchmark_loops, 4, "benchmark loops"); +Define_string_opt("--input-file", g_flag_input_file, "", "input file of request's token ids. no effect if --input-len is non-zero"); +Define_uint32_opt("--input-len", g_flag_input_len, 0, "input length of request. 
default: 0(get length from input file)"); +Define_uint32_opt("--batch-size", g_flag_batch_size, UINT32_MAX, "batch size"); +Define_uint32_opt("--micro-batch", g_flag_micro_batch, UINT32_MAX, "dummy"); +Define_string_opt("--output-file", g_flag_output_file, "", "output file of output token ids.") +Define_string_opt("--quant-method", g_flag_quant_method, "none", + "llm cuda quantization mehtod, only accept " + "\"none\", \"online_i8i8\" and \"online_i4f16\", " + "default: \"none\""); +Define_string_opt("--cublas-layout-hint", g_cublas_layout_hint, "default", + "matrix layout hint for cublas(currently only effect int8 gemm), only accept " + "\"default\", \"ampere\". " + "default: \"default\""); + +Define_bool_opt("--kernel-profiling", g_flag_kernel_profiling, true, "enable kernel profiling and print profiling info"); + +static int64_t random_input[1024] = {4854, 28445, 26882, 19570, 28904, 7224, 11204, 12608, 23093, 5763, 17481, 3637, 4989, 8263, 18072, 7607, 10287, 6389, 30521, 19284, 1001, 30170, 16117, 11688, 3189, 4694, 18740, 6585, 3299, 289, 14008, 22789, 12043, 29885, 19050, 24321, 11134, 6291, 26101, 21448, 9998, 11708, 13471, 4035, 6285, 15050, 3445, 30546, 3335, 9024, 20135, 462, 27882, 29628, 2573, 29186, 24879, 16327, 13250, 2196, 4584, 14253, 24544, 14142, 21916, 26777, 22673, 23681, 29726, 4875, 15073, 25115, 29674, 19967, 14119, 18069, 23952, 4903, 14050, 7884, 25496, 25353, 8206, 17718, 24951, 22931, 25282, 27350, 7459, 15428, 13848, 17086, 30838, 6330, 19846, 21990, 12750, 18192, 23364, 31189, 2049, 5170, 18875, 1550, 24837, 20623, 5968, 21205, 12275, 11288, 31214, 17545, 25403, 22595, 26832, 27094, 4287, 2088, 14693, 30114, 11775, 16566, 1128, 9841, 6723, 4064, 19010, 10563, 16391, 22630, 25224, 4214, 10438, 4197, 20711, 25095, 8637, 1249, 21827, 15920, 1269, 24989, 18823, 10217, 4197, 18277, 3692, 3326, 16183, 12565, 11703, 20781, 26531, 9290, 11666, 18146, 20460, 3866, 30325, 23696, 14540, 15313, 17313, 11808, 24707, 7762, 7928, 31121, 188, 27724, 20011, 21316, 26679, 8934, 25191, 7640, 12644, 2745, 28379, 2915, 30257, 11475, 23502, 18365, 16392, 6913, 26862, 12704, 18085, 28552, 7072, 23477, 30879, 26014, 10777, 22887, 25528, 13986, 16807, 7838, 1914, 29227, 13069, 9977, 15107, 22174, 2453, 4482, 25644, 20425, 23556, 22172, 15768, 15790, 29825, 14381, 30648, 9594, 22624, 11919, 4756, 8095, 3566, 25349, 7798, 1451, 16108, 1740, 20877, 8163, 30604, 31876, 24077, 18241, 7281, 6266, 4243, 7069, 19769, 22766, 18629, 11727, 19192, 26391, 26689, 25834, 19592, 7891, 21956, 14238, 27197, 12860, 31620, 25199, 30635, 20908, 10656, 12847, 2502, 12412, 4969, 12149, 13885, 19198, 2346, 23433, 8594, 26669, 25496, 3386, 15291, 7447, 27139, 14139, 9704, 7289, 2297, 18465, 15065, 29629, 29297, 18111, 16321, 23181, 4635, 5194, 5680, 20010, 22590, 2653, 3869, 24767, 1965, 24028, 30772, 23175, 29866, 2205, 18108, 15062, 3118, 9045, 5723, 6415, 31082, 2188, 7311, 20256, 19578, 21254, 16531, 16726, 3079, 10648, 10834, 11582, 19042, 4120, 21394, 18674, 23845, 1607, 16299, 22337, 22147, 4969, 25872, 24250, 29371, 23383, 13664, 9146, 23049, 17562, 3404, 1871, 27293, 1761, 16423, 13860, 10916, 2501, 18750, 31245, 9438, 7113, 27553, 19404, 3935, 19308, 19074, 10950, 2523, 10560, 8343, 9880, 27166, 15279, 14267, 20852, 14966, 24011, 22818, 15692, 1707, 5708, 9276, 24446, 27951, 4064, 3860, 11723, 14799, 14288, 14789, 24125, 30444, 29224, 9204, 17018, 13849, 21455, 17831, 8628, 1219, 6999, 22257, 7093, 21735, 9971, 17377, 12209, 17336, 13298, 25329, 13935, 31161, 22448, 23774, 748, 20329, 534, 30021, 
14973, 6819, 20014, 22457, 29490, 21, 16223, 5492, 12695, 17176, 11757, 21868, 9953, 11467, 19631, 8310, 22225, 21181, 2503, 31558, 3028, 16996, 22232, 3690, 21498, 3742, 5285, 7486, 30377, 28383, 24183, 25623, 19988, 15639, 30002, 31411, 10780, 17521, 20937, 15612, 20057, 8355, 8916, 974, 30669, 18007, 164, 24930, 5119, 31156, 5946, 7294, 12805, 8349, 24333, 25220, 22156, 17136, 30967, 22668, 18047, 23242, 31038, 16002, 6195, 7639, 3549, 26399, 24178, 2848, 5888, 12496, 7480, 23608, 479, 31809, 30003, 26686, 19203, 22386, 7131, 4202, 3938, 4982, 31438, 3689, 29917, 19597, 28127, 4193, 18764, 2921, 4958, 22711, 93, 9594, 2494, 25492, 29359, 1596, 19777, 16806, 31869, 30211, 18345, 25026, 7879, 31933, 3583, 24569, 13110, 26598, 28383, 18403, 31994, 26340, 16875, 7114, 7372, 21954, 27227, 9279, 9757, 29061, 8525, 13101, 7744, 14296, 3679, 20769, 681, 12047, 3626, 14519, 1882, 3318, 17983, 19078, 10225, 11902, 22704, 448, 17143, 4973, 4354, 8100, 16630, 21754, 17219, 21381, 17471, 15750, 21204, 16511, 13165, 15525, 21326, 30660, 17947, 13702, 3995, 4059, 20, 30822, 22434, 19823, 7723, 13703, 20727, 11601, 17352, 13278, 31426, 20254, 6780, 8720, 17786, 15357, 5186, 11210, 23357, 6095, 21162, 640, 17668, 26775, 15785, 24912, 3374, 16072, 1838, 10180, 10731, 21572, 29611, 19191, 515, 10627, 12119, 6484, 9732, 8013, 22587, 1849, 3148, 18262, 15175, 13366, 20509, 5587, 30812, 2584, 31511, 11407, 6734, 18259, 13605, 9521, 25685, 30029, 31019, 6722, 3166, 15975, 12804, 17449, 29155, 26789, 23069, 19316, 26635, 29030, 21767, 24352, 12835, 5827, 21404, 15769, 15340, 31644, 6557, 4483, 15009, 5492, 30064, 29790, 30548, 22490, 30943, 12428, 29600, 5910, 12041, 26366, 28920, 3731, 5983, 1577, 3275, 15440, 4307, 10031, 20999, 8512, 766, 8616, 23190, 2754, 17507, 8830, 28490, 19489, 30404, 18750, 19824, 9129, 13398, 28868, 9680, 14908, 1086, 25230, 3432, 18402, 21096, 26573, 13830, 10086, 30708, 29992, 2173, 22163, 1572, 7598, 26022, 20475, 29632, 13133, 21975, 13792, 29371, 18452, 17421, 27734, 5914, 7317, 21842, 10833, 9780, 19507, 456, 15224, 20667, 45, 25414, 17738, 527, 31635, 31812, 8268, 23148, 24295, 1167, 2536, 14759, 10377, 2069, 13663, 12073, 16907, 29637, 5153, 4634, 25994, 397, 31527, 1150, 18942, 28864, 25195, 20448, 6497, 16291, 25399, 6059, 20762, 10191, 9196, 5438, 30897, 9234, 21348, 15318, 10919, 8330, 1781, 4175, 22058, 12618, 23993, 27484, 19815, 13835, 14605, 30530, 12528, 15855, 15094, 25708, 16082, 14820, 19526, 7676, 9215, 19222, 21365, 20375, 8183, 7369, 7940, 17555, 24506, 8138, 3027, 10721, 17146, 18460, 12332, 5174, 12780, 25184, 2895, 19014, 7408, 19011, 1396, 4581, 23738, 18612, 18277, 2646, 27617, 17913, 14895, 11038, 22787, 23271, 4618, 29633, 28035, 25643, 6758, 29526, 2681, 2217, 22770, 1632, 20076, 30737, 4613, 6318, 19603, 24994, 2587, 24149, 7230, 29733, 21695, 12255, 22514, 26849, 5111, 17797, 24847, 16833, 12742, 12003, 5286, 17873, 10942, 23972, 21230, 6546, 14866, 18500, 15393, 22536, 8133, 25296, 22484, 19982, 13087, 29776, 23359, 10425, 22028, 11190, 16693, 2118, 23351, 27817, 21382, 1189, 25925, 19520, 27026, 2639, 15749, 18384, 29283, 29672, 21813, 19320, 31083, 23918, 26421, 11032, 25719, 19729, 30445, 14226, 8696, 29600, 9000, 15486, 29377, 1422, 12197, 6116, 3543, 21149, 28361, 6570, 26061, 3658, 21072, 2339, 19848, 17606, 2944, 24911, 6300, 13493, 16401, 19117, 31785, 22760, 24634, 26375, 7856, 20481, 25122, 14345, 16559, 6296, 27652, 13643, 15577, 21088, 1292, 6931, 31824, 15488, 25473, 19310, 20581, 21956, 9402, 4613, 1639, 840, 26369, 28685, 30877, 
17166, 659, 28898, 11557, 19939, 31031, 18452, 29644, 19566, 12301, 472, 20018, 19573, 8257, 25520, 3814, 10656, 13039, 14661, 2207, 26849, 21633, 23418, 16230, 13791, 6774, 27429, 9088, 3167, 15050, 7711, 20597, 24940, 26294, 16510, 4960, 1806, 25994, 20792, 5446, 10808, 6183, 17514, 3541, 28826, 22857, 23680, 16870, 20164, 3110, 5153, 19392, 26894, 9187, 721, 27523, 7362, 1268, 15641, 15800, 11869, 10599, 12818, 13302, 19468, 26556, 29696, 30405, 9210, 1918, 13974, 17268, 19746, 13401, 17902, 9654, 26288, 17900, 23369, 3759, 2450, 30977, 30906, 17485, 17301, 26017, 14638}; + +#ifdef PPLNN_ENABLE_KERNEL_PROFILING + +static ppl::nn::ProfilingStatistics prefill_kernel_stat; +static ppl::nn::ProfilingStatistics decode_kernel_stat; + +static void PrintProfilingStatistics(const ppl::nn::ProfilingStatistics& stat, int32_t run_count) { + std::map> type_stat; + std::map type_count; + char float_buf_0[128]; + char float_buf_1[128]; + // LOG(INFO) << "----- Op statistics by Node -----"; + for (auto x = stat.prof_info.begin(); x != stat.prof_info.end(); ++x) { + auto ext_type = (x->domain == "" ? "" : x->domain + ".") + x->type; + double time = (double)x->exec_microseconds / 1000; + double avg_time = time / x->exec_count; + if (type_stat.find(ext_type) == type_stat.end()) { + type_stat[ext_type] = std::make_pair(avg_time, time); + type_count[ext_type] = 1; + } else { + std::pair& time_pair = type_stat[ext_type]; + time_pair.first += avg_time; + time_pair.second += time; + type_count[ext_type]++; + } + // sprintf(float_buf_0, "%8.4f", avg_time); + // string temp = x->name; + // temp.insert(temp.length(), temp.length() > 50 ? 0 : 50 - temp.length(), ' '); + // LOG(INFO) << "Name: [" << temp << "], " + // << "Avg time: [" << float_buf_0 << "], " + // << "Exec count: [" << x->exec_count << "]"; + } + // LOG(INFO) << "----- Op statistics by OpType -----"; + double tot_kernel_time = 0; + for (auto it = type_stat.begin(); it != type_stat.end(); ++it) { + tot_kernel_time += it->second.second; + } + for (auto it = type_stat.begin(); it != type_stat.end(); ++it) { + sprintf(float_buf_0, "%8.4f", it->second.first); + sprintf(float_buf_1, "%8.4f", it->second.second / tot_kernel_time * 100); + std::string temp = it->first; + temp.insert(temp.length(), temp.length() > 20 ? 
0 : 20 - temp.length(), ' ');
+        LOG(INFO) << "Type: [" << temp << "], Avg time: [" << float_buf_0 << "], Percentage: [" << float_buf_1
+                  << "], Exec count [" << type_count[it->first] << "]";
+    }
+
+    // LOG(INFO) << "----- Total statistics -----";
+    sprintf(float_buf_0, "%8.4f", tot_kernel_time / run_count);
+    LOG(INFO) << "Run count: [" << run_count << "]";
+    LOG(INFO) << "Avg kernel time: [" << float_buf_0 << "]";
+    sprintf(float_buf_0, "%8.4f", tot_kernel_time);
+    LOG(INFO) << "Total kernel time: [" << float_buf_0 << "]";
+}
+#endif
+
+class TimingGuard final {
+public:
+    TimingGuard(double* res) {
+        diff_millisec_ = res;
+        begin_ = std::chrono::high_resolution_clock::now();
+    }
+    ~TimingGuard() {
+        auto end = std::chrono::high_resolution_clock::now();
+        *diff_millisec_ = double(std::chrono::duration_cast<std::chrono::microseconds>(end - begin_).count()) / 1000.0;
+    }
+
+private:
+    double* diff_millisec_;
+    std::chrono::time_point<std::chrono::high_resolution_clock> begin_;
+};
+
+struct Config {
+    std::string model_type;
+    std::string model_dir;
+    std::string model_param_path;
+
+    int tensor_parallel_size = 0;
+
+    float top_p = 0;
+    float top_k = 1;
+    float temperature = 1;
+    int generation_len = 0;
+
+    int benchmark_loops = 0;
+
+    std::string quant_method;
+};
+
+struct ModelConfig final {
+    int hidden_dim;
+    int intermediate_dim;
+    int num_layers;
+    int num_heads;
+    int num_kv_heads;
+    int vocab_size;
+
+    float norm_eps; // not used
+
+    int cache_quant_bit;
+    int cache_quant_group;
+
+    int cache_layout;
+    int cache_mode;
+
+    bool dynamic_batching;
+    bool auto_causal;
+
+    std::string quant_method;
+};
+
+struct ModelInput {
+    std::vector<int64_t> token_ids;
+    std::vector<int64_t> seq_starts;
+    std::vector<int64_t> kv_starts;
+    std::vector<int64_t> cache_indices;
+    int64_t decoding_batches = 0;
+    std::vector<int64_t> start_pos;
+    int64_t max_seq_len = 0;
+    int64_t max_kv_len = 0;
+
+    void* kv_cache;
+    void* kv_scale;
+
+    std::vector<int64_t> first_fill_len;
+};
+
+struct WorkerThreadArg final {
+    std::unique_ptr<ppl::nn::DeviceContext> host_device;
+    std::unique_ptr<ppl::nn::Runtime> runtime;
+
+    void* kv_cache_mem = nullptr;
+    void* kv_scale_mem = nullptr;
+
+    ppl::nn::Tensor* token_ids;
+    ppl::nn::Tensor* attn_mask;
+    ppl::nn::Tensor* seq_starts;
+    ppl::nn::Tensor* kv_starts;
+    ppl::nn::Tensor* cache_indices;
+    ppl::nn::Tensor* decoding_batches;
+    ppl::nn::Tensor* start_pos;
+    ppl::nn::Tensor* max_seq_len;
+    ppl::nn::Tensor* max_kv_len;
+    ppl::nn::Tensor* kv_cache;
+    ppl::nn::Tensor* kv_scale;
+
+    ppl::nn::Tensor* logits;
+};
+
+static void InitCudaThread() {
+    gpu_thread_pool.Run([&](uint32_t nthr, uint32_t tid) {
+        auto cu_ret = cudaSetDevice(tid);
+        if (cu_ret != cudaSuccess) {
+            LOG(ERROR) << "cudaSetDevice(" << tid << ") failed: " << cudaGetErrorString(cu_ret);
+            return ppl::common::RC_OTHER_ERROR;
+        }
+        auto rc = ppl::common::InitCudaEnv(tid);
+        if (rc != ppl::common::RC_SUCCESS) {
+            LOG(ERROR) << "InitCudaEnv(" << tid << ") failed: " << ppl::common::GetRetCodeStr(rc);
+            return ppl::common::RC_OTHER_ERROR;
+        }
+        return ppl::common::RC_SUCCESS;
+    });
+}
+
+static void FinalizeCudaThread() {
+    gpu_thread_pool.Run([&](uint32_t nthr, uint32_t tid) {
+        auto rc = ppl::common::DestroyCudaEnv(tid);
+        if (rc != ppl::common::RC_SUCCESS) {
+            LOG(ERROR) << "DestroyCudaEnv(" << tid << ") failed: " << ppl::common::GetRetCodeStr(rc);
+            return ppl::common::RC_OTHER_ERROR;
+        }
+        return ppl::common::RC_SUCCESS;
+    });
+}
+
+#ifdef PPLNN_CUDA_ENABLE_NCCL
+#define NCCL_CHECK(cmd, emsg) \
+    do { \
+        ncclResult_t e = (cmd); \
+        if (e != ncclSuccess) { \
+            LOG(ERROR) << "NCCL error(code:" << (int)e << ") on " << (emsg); \
+            return ppl::common::RC_OTHER_ERROR; \
+ } \ + } while (0); + +static bool InitNccl(uint32_t tensor_parallel_size, std::vector* nccl_comm_list) { + nccl_comm_list->resize(tensor_parallel_size); + ncclUniqueId uuid; + ncclGetUniqueId(&uuid); + gpu_thread_pool.Run([&](uint32_t nthr, uint32_t tid) { + NCCL_CHECK(ncclCommInitRank(&nccl_comm_list->at(tid), tensor_parallel_size, uuid, tid), "ncclCommInitRank"); + return ppl::common::RC_SUCCESS; + }); + return true; +} + +static void FinalizeNccl(uint32_t tensor_parallel_size, const std::vector& nccl_comm_list) { + gpu_thread_pool.Run([&](uint32_t nthr, uint32_t tid) { + NCCL_CHECK(ncclCommDestroy(nccl_comm_list[tid]), "ncclCommDestroy"); + return ppl::common::RC_SUCCESS; + }); +} + +#endif + +static ppl::nn::Engine* CreateCudaEngine(ncclComm_t nccl_comm, int device_id, const std::string& quant_method) { + ppl::nn::llm::cuda::EngineOptions options; + options.device_id = device_id; + options.mm_policy = ppl::nn::llm::cuda::MM_COMPACT; + + if (quant_method == "none") { + options.quant_method = ppl::nn::llm::cuda::QUANT_METHOD_NONE; + } else if (quant_method == "online_i8i8") { + options.quant_method = ppl::nn::llm::cuda::QUANT_METHOD_ONLINE_I8I8; + } else if (quant_method == "online_i4f16") { + options.quant_method = ppl::nn::llm::cuda::QUANT_METHOD_ONLINE_I4F16; + } else { + LOG(ERROR) << "unknown/unsupported --quant-method option: " << quant_method; + return nullptr; + } + + if (g_cublas_layout_hint == "default") { + options.cublas_layout_hint = ppl::nn::llm::cuda::CUBLAS_LAYOUT_DEFAULT; + } else if (g_cublas_layout_hint == "ampere") { + options.cublas_layout_hint = ppl::nn::llm::cuda::CUBLAS_LAYOUT_AMPERE; + } else { + LOG(ERROR) << "unknown/unsupported --cublas-layout-hint option: " << g_cublas_layout_hint; + return nullptr; + } + + auto engine = std::unique_ptr(ppl::nn::llm::cuda::EngineFactory::Create(options)); + if (!engine) { + LOG(ERROR) << "create cuda engine failed."; + return nullptr; + } + +#ifdef PPLNN_CUDA_ENABLE_NCCL + auto rc = engine->Configure(ppl::nn::llm::cuda::ENGINE_CONF_SET_TP_NCCL_COMM, nccl_comm); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "engine configure failed"; + return nullptr; + } +#endif + + return engine.release(); +} + +static ppl::nn::Runtime* CreatePPLRuntime(ppl::nn::Engine* cuda_engine, const std::string& model_file) { + auto builder = std::unique_ptr(ppl::nn::onnx::RuntimeBuilderFactory::Create()); + if (!builder) { + LOG(ERROR) << "create onnx builder failed."; + return nullptr; + } + + auto rc = builder->LoadModel(model_file.c_str()); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "load model [" << model_file << "] failed: " << ppl::common::GetRetCodeStr(rc); + return nullptr; + } + + ppl::nn::onnx::RuntimeBuilder::Resources resources; + resources.engines = &cuda_engine; + resources.engine_num = 1; + + rc = builder->SetResources(resources); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "set resources for builder failed: " << ppl::common::GetRetCodeStr(rc); + return nullptr; + } + + rc = builder->Preprocess(); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "builder preprocess failed: " << ppl::common::GetRetCodeStr(rc); + return nullptr; + } + + return builder->CreateRuntime(); +} + +#ifdef PPLNN_ENABLE_PMX_MODEL +static ppl::nn::Runtime* CreatePMXPPLRuntime(ppl::nn::Engine* cuda_engine, const std::string& model_file) { + auto builder = std::unique_ptr(ppl::nn::pmx::RuntimeBuilderFactory::Create()); + if (!builder) { + LOG(ERROR) << "create PmxRuntimeBuilder failed."; + return nullptr; + } + + 
ppl::nn::pmx::RuntimeBuilder::Resources resources; + resources.engines = &cuda_engine; + resources.engine_num = 1; + + std::string external_data_dir_fix; + ppl::nn::pmx::LoadModelOptions opt; + auto status = builder->LoadModel(model_file.c_str(), resources, opt); + if (status != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "PmxRuntimeBuilder LoadModel failed: " << ppl::common::GetRetCodeStr(status); + return nullptr; + } + + status = builder->Preprocess(); + if (status != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "pmx preprocess failed: " << ppl::common::GetRetCodeStr(status); + return nullptr; + } + + return builder->CreateRuntime(); +} +#endif //PPLNN_ENABLE_PMX_MODEL + +static void UpdateInputPrefill(int gen_len, ModelInput* model_input) { + int batch_size = model_input->first_fill_len.size(); + model_input->decoding_batches = 0; + + model_input->seq_starts.reserve(batch_size + 1); + model_input->seq_starts.push_back(0); + + model_input->kv_starts.reserve(batch_size + 1); + model_input->kv_starts.push_back(0); + + model_input->start_pos.reserve(batch_size); + + model_input->cache_indices.reserve(batch_size); + model_input->cache_indices.push_back(0); + + for (int i = 0; i < batch_size; ++i) { + model_input->start_pos.push_back(0); + model_input->seq_starts.push_back(model_input->seq_starts[i] + model_input->first_fill_len[i]); + model_input->kv_starts.push_back(model_input->kv_starts[i] + model_input->first_fill_len[i]); + model_input->max_seq_len = std::max(model_input->first_fill_len[i], model_input->max_seq_len); + model_input->max_kv_len = std::max(model_input->first_fill_len[i], model_input->max_kv_len); + + if (i > 0) { + model_input->cache_indices.push_back(model_input->cache_indices[i - 1] + + model_input->first_fill_len[i - 1] + gen_len - 1); + } + } +} + +static void UpdateInputDecode(int step, const std::vector& gen_tokens, ModelInput* model_input) { + int batch_size = model_input->first_fill_len.size(); + model_input->decoding_batches = batch_size; + model_input->max_seq_len = 1; + model_input->max_kv_len = model_input->max_kv_len + 1; + + model_input->token_ids.resize(batch_size); + + for (int i = 0; i < batch_size; ++i) { + model_input->token_ids[i] = gen_tokens.at(i); + model_input->seq_starts[i + 1] = model_input->seq_starts[i] + 1; + model_input->kv_starts[i + 1] = model_input->kv_starts[i] + model_input->first_fill_len[i] + step; + if (step == 1) { + model_input->start_pos[i] = model_input->first_fill_len[i]; + } else { + model_input->start_pos[i]++; + } + } +} + +static std::shared_ptr CreateCudaSampler(ppl::nn::Runtime* runtime) { + ppl::nn::DeviceContext::Type needed_type; + *((int64_t*)needed_type.str) = 0; + needed_type.str[0] = 'c'; + needed_type.str[1] = 'u'; + needed_type.str[2] = 'd'; + needed_type.str[3] = 'a'; + + ppl::nn::DeviceContext* dev = nullptr; + for (uint32_t i = 0; i < runtime->GetDeviceContextCount(); ++i) { + if (runtime->GetDeviceContext(i)->GetType() == needed_type) { + dev = runtime->GetDeviceContext(i); + break; + } + } + + if (!dev) { + LOG(ERROR) << "cannot find cuda device in runtime."; + return std::shared_ptr(); + } + + cudaStream_t stream; + auto rc = dev->Configure(ppl::nn::llm::cuda::DEV_CONF_GET_STREAM, &stream); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "Configure ppl::nn::llm::cuda::DEV_CONF_GET_STREAM failed: " << ppl::common::GetRetCodeStr(rc); + return std::shared_ptr(); + } + + return std::make_shared(stream); +} + +static bool CheckParameters(const ModelConfig& model_config) { + if (model_config.auto_causal != 
true) { + LOG(ERROR) << "only support auto_causal == true"; + return false; + } + + if (model_config.cache_mode != 0) { + LOG(ERROR) << "only support cache_mode == 0"; + return false; + } + + if (model_config.cache_layout != 0 && model_config.cache_layout != 3) { + LOG(ERROR) << "only support cache_layout == 0 || cache_layout == 3"; + return false; + } + + if (model_config.cache_quant_bit != 8 && model_config.cache_quant_group != 8) { + LOG(ERROR) << "only support cache_quant_bit == 8 and cache_quant_group == 8"; + return false; + } + + if (model_config.dynamic_batching != true) { + LOG(ERROR) << "only support dynamic_batching == true"; + return false; + } + + return true; +} + + +class LLM { +public: + LLM(const Config& config) + : tensor_parallel_size_(config.tensor_parallel_size) + , top_p_(config.top_p) + , top_k_(config.top_k) + , temperature_(config.temperature) + , generation_len_(config.generation_len) + , nccl_comm_list_(config.tensor_parallel_size) + , engine_list_(config.tensor_parallel_size) + , worker_thread_args_(config.tensor_parallel_size) {} + + ~LLM() {} + +#ifdef PPLNN_ENABLE_KERNEL_PROFILING + void SetKernelProfiling(bool flag) { + kernel_profiling_ = flag; + } +#endif + + void Finalize() { + gpu_thread_pool.Run([&](uint32_t nthr, uint32_t tid) { + sampler_.reset(); + worker_thread_args_[tid].runtime.reset(); + engine_list_[tid].reset(); + cudaFree(worker_thread_args_[tid].kv_cache_mem); + cudaFree(worker_thread_args_[tid].kv_scale_mem); + worker_thread_args_[tid].host_device.reset(); + return ppl::common::RC_SUCCESS; + }); +#ifdef PPLNN_CUDA_ENABLE_NCCL + FinalizeNccl(tensor_parallel_size_, nccl_comm_list_); +#endif + } + + bool Init(const ModelConfig& model_config, const std::string& model_dir) { + bool rc = CheckParameters(model_config); + if (!rc) { + LOG(ERROR) << "CheckParameters failed."; + return false; + } + + vocab_size_ = model_config.vocab_size; + kv_cache_block_bytes_ = model_config.num_layers * 2 * model_config.num_kv_heads / tensor_parallel_size_ * + model_config.hidden_dim / model_config.num_heads * sizeof(int8_t); + kv_scale_block_bytes_ = model_config.num_layers * 2 * model_config.num_kv_heads / tensor_parallel_size_ * + model_config.hidden_dim / model_config.num_heads / model_config.cache_quant_group * sizeof(int16_t); + +#ifdef PPLNN_CUDA_ENABLE_NCCL + rc = InitNccl(tensor_parallel_size_, &nccl_comm_list_); + if (!rc) { + LOG(ERROR) << "NCCL init failed."; + return false; + } + LOG(INFO) << "Init Nccl successed"; +#endif + + gpu_thread_pool.Run([&](uint32_t nthr, uint32_t tid) { + engine_list_[tid] = std::unique_ptr(CreateCudaEngine(nccl_comm_list_[tid], tid, model_config.quant_method)); + if (!engine_list_[tid]) { + LOG(ERROR) << "create cuda engine [" << tid << "] failed."; + return ppl::common::RC_OTHER_ERROR; + } + LOG(INFO) << "Create cuda engine [" << tid << "] success"; + +#ifdef PPLNN_ENABLE_PMX_MODEL + if (g_flag_use_pmx) + { + const std::string model_path = model_dir + "/model_slice_" + std::to_string(tid) + "/model.pmx"; + worker_thread_args_[tid].host_device.reset(ppl::nn::llm::cuda::EngineFactory::CreateHostDeviceContext( + ppl::nn::llm::cuda::HostDeviceOptions())); + worker_thread_args_[tid].runtime = std::unique_ptr(CreatePMXPPLRuntime(engine_list_[tid].get(), model_path)); + } + else +#endif + { + const std::string model_path = model_dir + "/model_slice_" + std::to_string(tid) + "/model.onnx"; + worker_thread_args_[tid].host_device.reset(ppl::nn::llm::cuda::EngineFactory::CreateHostDeviceContext( + 
ppl::nn::llm::cuda::HostDeviceOptions())); + worker_thread_args_[tid].runtime = std::unique_ptr(CreatePPLRuntime(engine_list_[tid].get(), model_path)); + } + if (!worker_thread_args_[tid].runtime) { + LOG(ERROR) << "create runtime [" << tid << "] failed."; + return ppl::common::RC_OTHER_ERROR; + } + LOG(INFO) << "Create runtime [" << tid << "] success"; + + if (tid == 0) { + sampler_ = CreateCudaSampler(worker_thread_args_[0].runtime.get()); + if (!sampler_) { + LOG(ERROR) << "CreateCudaSampler failed"; + return ppl::common::RC_OTHER_ERROR; + } + LOG(INFO) << "Create cuda sampler success"; + } + + return ppl::common::RC_SUCCESS; + }); + + return true; + } + + bool PrepareInput(int batch_size, int kv_cache_tokens, const ModelConfig& model_config) { + temperature_list_.resize(batch_size); + for (size_t i = 0; i < temperature_list_.size(); ++i) { + temperature_list_[i] = temperature_; + } + + gpu_thread_pool.Run([&](uint32_t nthr, uint32_t tid) { + auto cu_ret = cudaMalloc(&worker_thread_args_[tid].kv_cache_mem, kv_cache_tokens * kv_cache_block_bytes_); + if (cu_ret != cudaSuccess) { + LOG(ERROR) << "alloc kv cache [" << kv_cache_tokens * kv_cache_block_bytes_ + << "] failed: " << cudaGetErrorString(cu_ret); + return ppl::common::RC_OTHER_ERROR; + } + cu_ret = cudaMalloc(&worker_thread_args_[tid].kv_scale_mem, kv_cache_tokens * kv_scale_block_bytes_); + if (cu_ret != cudaSuccess) { + cudaFree(worker_thread_args_[tid].kv_cache_mem); + LOG(ERROR) << "alloc kv scale [" << kv_cache_tokens * kv_scale_block_bytes_ + << "] failed: " << cudaGetErrorString(cu_ret); + return ppl::common::RC_OTHER_ERROR; + } + + // init tensor + auto* arg = &worker_thread_args_[tid]; + arg->token_ids = arg->runtime->GetInputTensor(0); + arg->attn_mask = arg->runtime->GetInputTensor(1); + arg->seq_starts = arg->runtime->GetInputTensor(2); + arg->kv_starts = arg->runtime->GetInputTensor(3); + arg->cache_indices = arg->runtime->GetInputTensor(4); + arg->decoding_batches = arg->runtime->GetInputTensor(5); + arg->start_pos = arg->runtime->GetInputTensor(6); + arg->max_seq_len = arg->runtime->GetInputTensor(7); + arg->max_kv_len = arg->runtime->GetInputTensor(8); + arg->kv_cache = arg->runtime->GetInputTensor(9); + arg->kv_scale = arg->runtime->GetInputTensor(10); + + arg->logits = arg->runtime->GetOutputTensor(0); + + arg->decoding_batches->SetDeviceContext(arg->host_device.get()); + arg->max_seq_len->SetDeviceContext(arg->host_device.get()); + arg->max_kv_len->SetDeviceContext(arg->host_device.get()); + + arg->kv_cache->SetBufferPtr(arg->kv_cache_mem); + arg->kv_scale->SetBufferPtr(arg->kv_scale_mem); + + // set kv cache, kv scale shape + if (model_config.cache_layout == 0) { + arg->kv_cache->GetShape()->Reshape({(int64_t)kv_cache_tokens, model_config.num_layers, 2, + model_config.num_kv_heads / tensor_parallel_size_, + model_config.hidden_dim / model_config.num_heads}); + arg->kv_scale->GetShape()->Reshape( + {(int64_t)kv_cache_tokens, model_config.num_layers, 2, + model_config.num_kv_heads / tensor_parallel_size_, + model_config.hidden_dim / model_config.num_heads / model_config.cache_quant_group}); + } else if (model_config.cache_layout == 3) { + arg->kv_cache->GetShape()->Reshape( + {model_config.num_layers, 2, model_config.num_kv_heads / tensor_parallel_size_, + (int64_t)kv_cache_tokens, model_config.hidden_dim / model_config.num_heads}); + arg->kv_scale->GetShape()->Reshape( + {model_config.num_layers, 2, model_config.num_kv_heads / tensor_parallel_size_, + (int64_t)kv_cache_tokens, + model_config.hidden_dim / 
model_config.num_heads / model_config.cache_quant_group}); + } else { + LOG(ERROR) << "impossible status: cache_layout = [" << model_config.cache_layout << "]"; + return ppl::common::RC_OTHER_ERROR; + } + return ppl::common::RC_SUCCESS; + }); + + return true; + } + + bool SetInputTensor(const ModelInput& model_input, int id, int step) { + ppl::common::RetCode rc; + // token ids + // if (step < 2) { + worker_thread_args_[id].token_ids->GetShape()->Reshape({int64_t(model_input.token_ids.size())}); + rc = worker_thread_args_[id].token_ids->CopyFromHostAsync(model_input.token_ids.data()); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "set token_ids [" << worker_thread_args_[id].token_ids->GetName() + << "] failed: " << ppl::common::GetRetCodeStr(rc); + return false; + } + // } + + // kv_starts + worker_thread_args_[id].kv_starts->GetShape()->Reshape({int64_t(model_input.kv_starts.size())}); + rc = worker_thread_args_[id].kv_starts->CopyFromHostAsync(model_input.kv_starts.data()); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "set kv_starts " << worker_thread_args_[id].kv_starts->GetName() + << " failed: " << ppl::common::GetRetCodeStr(rc); + return false; + } + + // start_pos + worker_thread_args_[id].start_pos->GetShape()->Reshape({int64_t(model_input.start_pos.size())}); + rc = worker_thread_args_[id].start_pos->CopyFromHostAsync(model_input.start_pos.data()); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "set start_pos [" << worker_thread_args_[id].start_pos->GetName() + << "] failed: " << ppl::common::GetRetCodeStr(rc); + return false; + } + + // max_kv_len + rc = worker_thread_args_[id].max_kv_len->CopyFromHostAsync(&model_input.max_kv_len); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "set max_kv_len [" << worker_thread_args_[id].max_kv_len->GetName() + << "] failed: " << ppl::common::GetRetCodeStr(rc); + return false; + } + + // prefill + if (step < 1) { + // cache_indices + worker_thread_args_[id].cache_indices->GetShape()->Reshape({int64_t(model_input.cache_indices.size())}); + rc = worker_thread_args_[id].cache_indices->CopyFromHostAsync(model_input.cache_indices.data()); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "set cache_indices [" << worker_thread_args_[id].cache_indices->GetName() + << "] failed: " << ppl::common::GetRetCodeStr(rc); + return false; + } + } + + if (step < 2) { + // seq_start + // LOG(INFO) << "model_input.seq_starts: "; + // PrintVector(model_input.seq_starts); + worker_thread_args_[id].seq_starts->GetShape()->Reshape({int64_t(model_input.seq_starts.size())}); + rc = worker_thread_args_[id].seq_starts->CopyFromHostAsync(model_input.seq_starts.data()); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "set seq_starts [" << worker_thread_args_[id].seq_starts->GetName() + << "] failed: " << ppl::common::GetRetCodeStr(rc); + return false; + } + + // decoding batches + rc = worker_thread_args_[id].decoding_batches->CopyFromHostAsync(&model_input.decoding_batches); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "set decoding_batches [" << worker_thread_args_[id].decoding_batches->GetName() + << "] failed: " << ppl::common::GetRetCodeStr(rc); + return false; + } + + // max_seq_len + rc = worker_thread_args_[id].max_seq_len->CopyFromHostAsync(&model_input.max_seq_len); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "set max_seq_len [" << worker_thread_args_[id].max_seq_len->GetName() + << "] failed: " << ppl::common::GetRetCodeStr(rc); + return false; + } + } + + // rc = 
worker_thread_args_[id].runtime->Synchronize(); + // if (rc != ppl::common::RC_SUCCESS) { + // LOG(ERROR) << "set input tensor synchronize fail"; + // return false; + // } + return true; + } + + void Generate(ModelInput* model_input, std::vector>* output_tokens) { + int batch_size = model_input->first_fill_len.size(); + + double step_latency = 0; + for (int step = 0; step < generation_len_; ++step) { + { + TimingGuard __timing__(&step_latency); + if (step == 0) { + UpdateInputPrefill(generation_len_, model_input); + } else { + UpdateInputDecode(step, output_tokens->at(step - 1), model_input); + } + + gpu_thread_pool.Run([&](uint32_t nthr, uint32_t tid) { +#ifdef PPLNN_ENABLE_KERNEL_PROFILING + if (tid == 0 && kernel_profiling_ && (step == 0 || step == 1)) { + auto rc = worker_thread_args_[tid].runtime->Configure(ppl::nn::RUNTIME_CONF_SET_KERNEL_PROFILING_FLAG, true); + if (rc != ppl::common::RC_SUCCESS) { + LOG(WARNING) << "enable kernel profiling failed: " << ppl::common::GetRetCodeStr(rc); + } + } +#endif + bool ret = SetInputTensor(*model_input, tid, step); + if (!ret) { + LOG(ERROR) << "SetInputTensor failed"; + return ppl::common::RC_OTHER_ERROR; + } + + auto rc = worker_thread_args_[tid].runtime->Run(); + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "model run failed"; + return ppl::common::RC_OTHER_ERROR; + } + +#ifdef PPLNN_ENABLE_KERNEL_PROFILING + if (tid == 0 && kernel_profiling_ && (step == 0 || step == generation_len_ - 1)) { + if (step == 0) { + auto rc = worker_thread_args_[tid].runtime->GetProfilingStatistics(&prefill_kernel_stat); + if (rc != ppl::common::RC_SUCCESS) { + LOG(WARNING) << "get prefill kernel profiling stats failed: " << ppl::common::GetRetCodeStr(rc); + } + } else { + auto rc = worker_thread_args_[tid].runtime->GetProfilingStatistics(&decode_kernel_stat); + if (rc != ppl::common::RC_SUCCESS) { + LOG(WARNING) << "get decode kernel profiling stats failed: " << ppl::common::GetRetCodeStr(rc); + } + } + auto rc = worker_thread_args_[tid].runtime->Configure(ppl::nn::RUNTIME_CONF_SET_KERNEL_PROFILING_FLAG, false); + if (rc != ppl::common::RC_SUCCESS) { + LOG(WARNING) << "enable profiling failed: " << ppl::common::GetRetCodeStr(rc); + } + } +#endif + if (tid == 0) { + auto logits = worker_thread_args_[tid].logits; + auto rc = sampler_->SampleTopPTopK((float*)logits->GetBufferPtr(), temperature_list_.data(), batch_size, + vocab_size_, logits->GetShape()->GetDim(1), top_p_, top_k_, output_tokens->at(step).data()); + + if (rc != ppl::common::RC_SUCCESS) { + LOG(ERROR) << "SampleTopPTopK failed: " << ppl::common::GetRetCodeStr(rc); + return ppl::common::RC_OTHER_ERROR; + } + } + return ppl::common::RC_SUCCESS; + }); + } + + profiling.step_latency[step] += step_latency; + } + } + +private: + int tensor_parallel_size_ = 0; + + float top_p_ = 0; + float top_k_ = 1; + float temperature_ = 1; + int generation_len_ = 0; + + std::vector nccl_comm_list_; + std::vector> engine_list_; + std::vector worker_thread_args_; + + std::vector temperature_list_; + std::shared_ptr sampler_; + int vocab_size_ = 0; + uint64_t kv_cache_block_bytes_ = 0; + uint64_t kv_scale_block_bytes_ = 0; + +#ifdef PPLNN_ENABLE_KERNEL_PROFILING + bool kernel_profiling_ = false; +#endif +}; + +static void ParseConfig(Config* config) { + config->model_type = g_flag_model_type; + config->model_dir = g_flag_model_dir; + config->model_param_path = g_flag_model_param_path; + config->tensor_parallel_size = g_flag_tensor_parallel_size; + config->top_p = g_flag_top_p; + config->top_k = g_flag_top_k; + 
config->temperature = g_flag_temperature; + config->generation_len = g_flag_generation_len; + config->benchmark_loops = g_flag_benchmark_loops; + config->quant_method = g_flag_quant_method; + + LOG(INFO) << "config.model_type: " << config->model_type; + LOG(INFO) << "config.model_dir: " << config->model_dir; + LOG(INFO) << "config.model_param_path: " << config->model_param_path; + + LOG(INFO) << "config.tensor_parallel_size: " << config->tensor_parallel_size; + + LOG(INFO) << "config.top_k: " << config->top_k; + LOG(INFO) << "config.top_p: " << config->top_p; + LOG(INFO) << "config.temperature: " << config->temperature; + LOG(INFO) << "config.generation_len: " << config->generation_len; + + LOG(INFO) << "config.benchmark_loops: " << config->benchmark_loops; + + LOG(INFO) << "config.quant_method: " << config->quant_method; +} + +static bool ParseModelConfig(const std::string& model_param_path, ModelConfig* model_config) { + std::ifstream ifs(model_param_path); + rapidjson::IStreamWrapper isw(ifs); + rapidjson::Document document; + if (document.ParseStream(isw) == false) { + LOG(ERROR) << "ParseStream failed"; + return false; + } + document.ParseStream(isw); + + auto it = document.FindMember("num_heads"); + if (it == document.MemberEnd()) { + LOG(ERROR) << "find key [num_heads] failed"; + return false; + } + model_config->num_heads = it->value.GetInt(); + + it = document.FindMember("num_kv_heads"); + if (it == document.MemberEnd()) { + model_config->num_kv_heads = model_config->num_heads; + } else { + model_config->num_kv_heads = it->value.GetInt(); + } + + it = document.FindMember("num_layers"); + if (it == document.MemberEnd()) { + LOG(ERROR) << "find key [num_layers] failed"; + return false; + } + model_config->num_layers = it->value.GetInt(); + + it = document.FindMember("hidden_dim"); + if (it == document.MemberEnd()) { + LOG(ERROR) << "find key [hidden_dim] failed"; + return false; + } + model_config->hidden_dim = it->value.GetInt(); + + it = document.FindMember("intermediate_dim"); + if (it == document.MemberEnd()) { + LOG(ERROR) << "find key [intermediate_dim] failed"; + return false; + } + model_config->intermediate_dim = it->value.GetInt(); + + it = document.FindMember("vocab_size"); + if (it == document.MemberEnd()) { + LOG(ERROR) << "find key [vocab_size] failed"; + return false; + } + model_config->vocab_size = it->value.GetInt(); + + it = document.FindMember("cache_quant_bit"); + if (it == document.MemberEnd()) { + LOG(ERROR) << "find key [cache_quant_bit] failed"; + return false; + } + model_config->cache_quant_bit = it->value.GetInt(); + + it = document.FindMember("cache_quant_group"); + if (it == document.MemberEnd()) { + LOG(ERROR) << "find key [cache_quant_group] failed"; + return false; + } + model_config->cache_quant_group = it->value.GetInt(); + + it = document.FindMember("cache_layout"); + if (it == document.MemberEnd()) { + LOG(ERROR) << "find key [cache_layout] failed"; + return false; + } + model_config->cache_layout = it->value.GetInt(); + + it = document.FindMember("cache_mode"); + if (it == document.MemberEnd()) { + LOG(ERROR) << "find key [cache_mode] failed"; + return false; + } + model_config->cache_mode = it->value.GetInt(); + + it = document.FindMember("dynamic_batching"); + if (it == document.MemberEnd()) { + LOG(ERROR) << "find key [dynamic_batching] failed"; + return false; + } + model_config->dynamic_batching = it->value.GetBool(); + + it = document.FindMember("auto_causal"); + if (it == document.MemberEnd()) { + LOG(ERROR) << "find key [auto_causal] 
failed"; + return false; + } + model_config->auto_causal = it->value.GetBool(); + + LOG(INFO) << "model_config.num_layers: " << model_config->num_layers; + LOG(INFO) << "model_config.num_heads: " << model_config->num_heads; + LOG(INFO) << "model_config.num_kv_heads: " << model_config->num_kv_heads; + LOG(INFO) << "model_config.hidden_dim: " << model_config->hidden_dim; + LOG(INFO) << "model_config.intermediate_dim: " << model_config->intermediate_dim; + LOG(INFO) << "model_config.vocab_size: " << model_config->vocab_size; + + LOG(INFO) << "model_config.cache_quant_bit: " << model_config->cache_quant_bit; + LOG(INFO) << "model_config.cache_quant_group: " << model_config->cache_quant_group; + LOG(INFO) << "model_config.cache_layout: " << model_config->cache_layout; + LOG(INFO) << "model_config.cache_mode: " << model_config->cache_mode; + + LOG(INFO) << "model_config.dynamic_batching: " << model_config->dynamic_batching; + LOG(INFO) << "model_config.auto_causal: " << model_config->auto_causal; + + return true; +} + +static bool WriteOutput(const std::string& token_file, const std::vector> &output_tokens) { + std::ofstream fout(token_file, std::ios::out); + if (!fout.is_open()) { + LOG(ERROR) << "Error Openning " << token_file; + return false; + } + + for (size_t b = 0; b < output_tokens[0].size(); ++b) { + for (size_t l = 0; l < output_tokens.size(); ++l) { + fout << output_tokens[l][b]; + if (l + 1 < output_tokens.size()) + fout << ", "; + } + fout << std::endl; + } + return true; +} + +static bool ParseInput(const std::string& token_file, ModelInput* model_input) { + std::ifstream fin(token_file, std::ios::in); + if (!fin.is_open()) { + LOG(ERROR) << "Error Openning " << token_file; + return false; + } + + std::string line; + uint32_t line_cnt = 0; + while (std::getline(fin, line) && line_cnt < g_flag_batch_size) { + std::stringstream line_stream(line); + if (line.empty()) { + continue; + } + std::string vals; + model_input->first_fill_len.push_back(0); + // each request + while (std::getline(line_stream, vals, ',')) { + model_input->token_ids.push_back(std::stoi(vals)); + ++(model_input->first_fill_len.back()); + } + line_cnt++; + } + return true; +} + +static void GenInput(int vocab_size, ModelInput* model_input) { + model_input->first_fill_len.assign(g_flag_batch_size, g_flag_input_len); + model_input->token_ids.resize(g_flag_batch_size * g_flag_input_len); + for(uint32_t i = 0; i < model_input->token_ids.size(); ++i) { + model_input->token_ids[i] = random_input[i % 1024] % vocab_size; + } +} + +int main(int argc, char* argv[]) { + simple_flags::parse_args(argc, argv); + + if (g_flag_help) { + simple_flags::print_args_info(); + return 0; + } + + if (!simple_flags::get_unknown_flags().empty()) { + std::string content; + for (auto it : simple_flags::get_unknown_flags()) { + content += "'" + it + "', "; + } + content.resize(content.size() - 2); // remove last ', ' + content.append("."); + LOG(ERROR) << "unknown option(s): " << content.c_str(); + simple_flags::print_args_info(); + return -1; + } + + LOG(INFO) << "ppl.nn version: [" << PPLNN_VERSION_MAJOR << "." << PPLNN_VERSION_MINOR << "." 
<< PPLNN_VERSION_PATCH
+              << "], commit: [" << PPLNN_COMMIT_STR << "]";
+
+    Config config;
+    ParseConfig(&config);
+
+    ModelConfig model_config;
+    if (!ParseModelConfig(config.model_param_path, &model_config)) {
+        LOG(ERROR) << "ParseModelConfig failed, model_param_path: " << config.model_param_path;
+        return -1;
+    }
+    model_config.quant_method = config.quant_method;
+
+    gpu_thread_pool.Init(config.tensor_parallel_size);
+    InitCudaThread();
+
+    LOG(INFO) << "input_file: " << g_flag_input_file;
+    LOG(INFO) << "input_length: " << g_flag_input_len;
+    LOG(INFO) << "batch_size: " << g_flag_batch_size;
+
+    ModelInput raw_model_input;
+    if (g_flag_input_len == 0) {
+        if (!ParseInput(g_flag_input_file, &raw_model_input)) {
+            LOG(ERROR) << "ParseInput failed, input file: " << g_flag_input_file;
+            return -1;
+        }
+    } else {
+        GenInput(model_config.vocab_size, &raw_model_input);
+    }
+    int64_t batch_size = raw_model_input.first_fill_len.size();
+    int64_t total_input_length = 0;
+    int64_t kv_cache_length = 0;
+    for (auto input_length : raw_model_input.first_fill_len) {
+        total_input_length += input_length;
+        kv_cache_length += input_length + config.generation_len - 1;
+    }
+
+    profiling.step_latency.resize(config.generation_len);
+
+    LLM llm(config);
+    bool ret = llm.Init(model_config, config.model_dir);
+    if (!ret) {
+        LOG(ERROR) << "Init failed";
+        return -1;
+    }
+
+    ret = llm.PrepareInput(batch_size, kv_cache_length, model_config);
+    if (!ret) {
+        LOG(ERROR) << "PrepareInput failed";
+        return -1;
+    }
+    LOG(INFO) << "PrepareInput success";
+
+    std::vector<std::vector<int32_t>> output_tokens(config.generation_len, std::vector<int32_t>(batch_size));
+
+    LOG(INFO) << "Request batch size: " << batch_size;
+    LOG(INFO) << "Total input length: " << total_input_length;
+    LOG(INFO) << "KV cache length: " << kv_cache_length;
+
+    // warmup
+    for (uint32_t i = 0; i < g_flag_warmup_loops; ++i) {
+        LOG(INFO) << "Warmup " << i;
+        ModelInput model_input = raw_model_input;
+        double latency = 0;
+        {
+            TimingGuard __timing__(&latency);
+            llm.Generate(&model_input, &output_tokens);
+        }
+        LOG(INFO) << "Time " << latency << " ms";
+    }
+
+    profiling.Reset();
+    for (int i = 0; i < config.benchmark_loops; ++i) {
+        LOG(INFO) << "Benchmark " << i;
+        ModelInput model_input = raw_model_input;
+        double latency = 0;
+        {
+            TimingGuard __timing__(&latency);
+            llm.Generate(&model_input, &output_tokens);
+        }
+        profiling.total_latency += latency;
+        LOG(INFO) << "Time " << latency << " ms";
+    }
+    size_t avail_bytes = 0, total = 0;
+    gpu_thread_pool.Run([&](uint32_t nthr, uint32_t tid) {
+        if (tid == 0)
+            cudaMemGetInfo(&avail_bytes, &total);
+        return ppl::common::RC_SUCCESS;
+    });
+    profiling.mem_usage = double(total - avail_bytes) / 1024 / 1024 / 1024;
+
+
+    // profiling results
+    double avg_prefill_latency = 0;
+    double max_decode_latency = 0;
+    double min_decode_latency = DBL_MAX;
+    double avg_decode_latency = 0;
+    double avg_step_latency = 0;
+    for (size_t step = 0; step < profiling.step_latency.size(); ++step) {
+        if (step > 0) {
+            avg_decode_latency += profiling.step_latency[step];
+            max_decode_latency = std::max(max_decode_latency, profiling.step_latency[step]);
+            min_decode_latency = std::min(min_decode_latency, profiling.step_latency[step]);
+        }
+        avg_step_latency += profiling.step_latency[step];
+    }
+
+    int max_latency_step = std::max_element(profiling.step_latency.begin() + 1, profiling.step_latency.end()) -
+        profiling.step_latency.begin();
+    int min_latency_step = std::min_element(profiling.step_latency.begin() + 1, profiling.step_latency.end()) -
profiling.step_latency.begin();
+
+    avg_prefill_latency = profiling.step_latency[0] / config.benchmark_loops;
+    avg_decode_latency = avg_decode_latency / (config.benchmark_loops * (config.generation_len - 1));
+    min_decode_latency = min_decode_latency / config.benchmark_loops;
+    max_decode_latency = max_decode_latency / config.benchmark_loops;
+    avg_step_latency = avg_step_latency / (config.benchmark_loops * config.generation_len);
+    double tokens_per_second = 1000 / avg_step_latency * batch_size;
+
+    LOG(INFO) << "Memory usage(GB): " << profiling.mem_usage;
+    LOG(INFO) << "Prefill latency(ms): " << avg_prefill_latency;
+    LOG(INFO) << "Min decode latency(ms)[" << min_latency_step << "]: " << min_decode_latency;
+    LOG(INFO) << "Max decode latency(ms)[" << max_latency_step << "]: " << max_decode_latency;
+    LOG(INFO) << "Avg decode latency(ms): " << avg_decode_latency;
+    LOG(INFO) << "Avg step latency(ms): " << avg_step_latency;
+    LOG(INFO) << "Tokens per second: " << tokens_per_second;
+
+    LOG(INFO) << "CSV format header:prefill(ms),decode(ms),avg(ms),tps(token/s),mem(gib)";
+    LOG(INFO) << "CSV format output:" << avg_prefill_latency << ","
+              << avg_decode_latency << ","
+              << avg_step_latency << ","
+              << tokens_per_second << ","
+              << profiling.mem_usage;
+
+#ifdef PPLNN_ENABLE_KERNEL_PROFILING
+    if (g_flag_kernel_profiling) {
+        LOG(INFO) << "Kernel profiling";
+        ModelInput model_input = raw_model_input;
+        llm.SetKernelProfiling(true);
+        double latency = 0;
+        {
+            TimingGuard __timing__(&latency);
+            llm.Generate(&model_input, &output_tokens);
+        }
+        LOG(INFO) << "Time " << latency << " ms";
+        LOG(INFO) << "----- Prefill statistics -----";
+        PrintProfilingStatistics(prefill_kernel_stat, 1);
+        LOG(INFO) << "----- Decode statistics -----";
+        PrintProfilingStatistics(decode_kernel_stat, config.generation_len - 1);
+    }
+#endif
+
+    if (!g_flag_output_file.empty()) {
+        WriteOutput(g_flag_output_file, output_tokens);
+    }
+
+    llm.Finalize();
+
+    FinalizeCudaThread();
+
+    return 0;
+}
\ No newline at end of file
diff --git a/tools/benchmark_llama_deprecated/sampler.cc b/tools/benchmark_llama_deprecated/sampler.cc
new file mode 100644
index 000000000..d8266a428
--- /dev/null
+++ b/tools/benchmark_llama_deprecated/sampler.cc
@@ -0,0 +1,93 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
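+//
+// sampler.cc: sampler used by the deprecated llama benchmark. Only greedy (argmax) sampling
+// is implemented, so SampleTopPTopK() requires top_k == 1 and top_p == 0; the argmax result
+// is computed on the caller-provided CUDA stream and then copied back to host memory.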
+ +#include "sampler.h" + +#include "ppl/nn/engines/llm_cuda/options.h" +#include "ppl/common/log.h" +#include "ppl/kernel/llm/cuda/pmx/sample.h" + +using namespace ppl::common; + +namespace ppl { namespace llm { namespace cuda { + +void Sampler::Clear() { + if (cu_output_) { + auto err = cudaFreeAsync(cu_output_, stream_); + if (err != cudaSuccess) { + LOG(ERROR) << "cudaFreeAsync failed: " << cudaGetErrorString(err); + } + err = cudaStreamSynchronize(stream_); + if (err != cudaSuccess) { + LOG(ERROR) << "cudaStreamSynchronize failed: " << cudaGetErrorString(err); + } + cu_output_ = nullptr; + } +} + +RetCode Sampler::SampleTopPTopK(const float* logits_device, const float* temperatures_host, int32_t batch, + int32_t vocab_size, int32_t batch_stride, float top_p, float top_k, int32_t* output_host) { + const int64_t output_size = batch * sizeof(int32_t); + int32_t output_offset = 0; + int64_t needed_output_size = output_size; + cudaError_t err; + + if (top_k != 1 || top_p != 0.0) { + LOG(ERROR) << "currently only support top_k == 1, top_p == 0"; + return RC_UNSUPPORTED; + } + + if (needed_output_size > cu_output_size_) { + if (cu_output_) { + err = cudaFreeAsync(cu_output_, stream_); + if (err != cudaSuccess) { + LOG(ERROR) << "cudaFreeAsync failed: " << cudaGetErrorString(err); + return RC_DEVICE_MEMORY_ERROR; + } + } + err = cudaMallocAsync(&cu_output_, needed_output_size, stream_); + if (err != cudaSuccess) { + LOG(ERROR) << "cudaMallocAsync failed: " << cudaGetErrorString(err); + return RC_OUT_OF_MEMORY; + } + cu_output_size_ = needed_output_size; + } + + RetCode rc; + rc = ppl::kernel::llm::cuda::pmx::sample_argmax(stream_, logits_device, batch, vocab_size, batch_stride, cu_output_); + + if (rc != RC_SUCCESS) { + LOG(ERROR) << "sampling kernel failed: " << GetRetCodeStr(rc); + return rc; + } + + err = cudaMemcpyAsync(output_host, cu_output_ + output_offset, output_size, cudaMemcpyDeviceToHost, stream_); + if (err != cudaSuccess) { + LOG(ERROR) << "cudaMemcpyAsync output failed: " << cudaGetErrorString(err); + return RC_DEVICE_MEMORY_ERROR; + } + + err = cudaStreamSynchronize(stream_); + if (err != cudaSuccess) { + LOG(ERROR) << "cudaStreamSynchronize failed: " << cudaGetErrorString(err); + return RC_DEVICE_RUNTIME_ERROR; + } + + return RC_SUCCESS; +} + +}}} // namespace ppl::llm::cuda diff --git a/tools/benchmark_llama_deprecated/sampler.h b/tools/benchmark_llama_deprecated/sampler.h new file mode 100644 index 000000000..814f3000d --- /dev/null +++ b/tools/benchmark_llama_deprecated/sampler.h @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
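+//
+// sampler.h: declares the Sampler helper for the deprecated llama benchmark. A Sampler wraps
+// a CUDA stream and a device-side output buffer that SampleTopPTopK() grows on demand
+// (see sampler.cc for the implementation).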
+ +#ifndef __PPL_LLM_CUDA_SAMPLER_H__ +#define __PPL_LLM_CUDA_SAMPLER_H__ + +#include "ppl/common/retcode.h" + +#include + +namespace ppl { namespace llm { namespace cuda { + +class Sampler final { +public: + Sampler(cudaStream_t stream) : stream_(stream) {} + virtual ~Sampler() { + Clear(); + } + + ppl::common::RetCode SampleTopPTopK(const float* logits_device, const float* temperatures_host, int32_t batch, + int32_t vocab_size, int32_t batch_stride, float top_p, float top_k, int32_t* output_host); + +private: + void Clear(); + +private: + cudaStream_t stream_ = 0; + int32_t* cu_output_ = nullptr; + int64_t cu_output_size_ = 0; +}; + +}}}; // namespace ppl::llm::cuda + +#endif diff --git a/tools/pplnn_llm.cc b/tools/pplnn_llm.cc index 923f3a4a8..1a73b3d6a 100644 --- a/tools/pplnn_llm.cc +++ b/tools/pplnn_llm.cc @@ -239,7 +239,7 @@ Define_string_opt("--in-devices", g_flag_in_devices, "", Define_string_opt("--quant-method", g_flag_quant_method, "none", "llm cuda quantization mehtod, only accept " - "\"none\", \"online_i8i8\" and \"online_i4f16\", " + "\"none\", \"online_i8i8\", \"online_f8f8\" and \"online_i4f16\", " "default: \"none\""); Define_string_opt("--cublas-layout-hint", g_cublas_layout_hint, "default", @@ -247,6 +247,24 @@ Define_string_opt("--cublas-layout-hint", g_cublas_layout_hint, "default", "\"default\", \"ampere\". " "default: \"default\""); +Define_bool_opt("--enable-cache-prefill", g_flag_enable_cache_prefill, + false, "enable cache prefill flash attention"); + +Define_bool_opt("--disable-decoding-shm-mha", g_flag_disable_decoding_shm_mha, + false, "disable shared memory decoding attention algorithm"); +Define_bool_opt("--disable-decoding-inf-mha", g_flag_disable_decoding_inf_mha, + false, "disable infinity decoding attention algorithm"); +Define_bool_opt("--disable-decoding-inf-gqa", g_flag_disable_decoding_inf_gqa, + false, "disable infinity grouped query decoding attention algorithm"); +Define_int32_opt("--configure-decoding-attn-split-k", g_flag_configure_decoding_attn_split_k, 1, + "configuring split-k decoding attention algorithm, " + "accepted values: always-on(2)/heuristic(1)/off(0)," + "default is heuristic(1)"); +Define_int32_opt("--specify-decoding-attn-tpb", g_flag_specify_decoding_attn_tpb, 0, + "specify decoding attention kernel threads per block, " + "accepted values: 512/256/heuristic(0)," + "default is heuristic(0)"); + Define_bool_opt("--disable-graph-fusion", g_flag_disable_graph_fusion, false, "disable graph kernel fusion rules"); Define_bool_opt("--enable-tensor-debug", g_flag_enable_tensor_debug, false, "dump tensors' data"); Define_string_opt("--debug-data-dir", g_flag_debug_data_dir, ".", "directory to save dumped tensors' data"); @@ -331,6 +349,8 @@ static bool RegisterLlmCudaEngine(vector>* engines) { options.quant_method = llm::cuda::QUANT_METHOD_NONE; } else if (g_flag_quant_method == "online_i8i8") { options.quant_method = llm::cuda::QUANT_METHOD_ONLINE_I8I8; + } else if (g_flag_quant_method == "online_f8f8") { + options.quant_method = llm::cuda::QUANT_METHOD_ONLINE_F8F8; } else if (g_flag_quant_method == "online_i4f16") { options.quant_method = llm::cuda::QUANT_METHOD_ONLINE_I4F16; } else { @@ -363,7 +383,39 @@ static bool RegisterLlmCudaEngine(vector>* engines) { } #endif - auto rc = llm_cuda_engine->Configure(llm::cuda::ENGINE_CONF_GRAPH_FUSION, g_flag_disable_graph_fusion ? 0 : 1); + auto rc = llm_cuda_engine->Configure(llm::cuda::ENGINE_CONF_CACHE_PREFILL, g_flag_enable_cache_prefill ? 
1 : 0); + if (RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_CACHE_PREFILL failed: " << GetRetCodeStr(rc); + return false; + } + + rc = llm_cuda_engine->Configure(llm::cuda::ENGINE_CONF_DECODING_SHM_MHA, g_flag_disable_decoding_shm_mha ? 0 : 1); + if (RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_DECODING_SHM_MHA failed: " << GetRetCodeStr(rc); + return false; + } + rc = llm_cuda_engine->Configure(llm::cuda::ENGINE_CONF_DECODING_INF_MHA, g_flag_disable_decoding_inf_mha ? 0 : 1); + if (RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_DECODING_INF_MHA failed: " << GetRetCodeStr(rc); + return false; + } + rc = llm_cuda_engine->Configure(llm::cuda::ENGINE_CONF_DECODING_INF_GQA, g_flag_disable_decoding_inf_gqa ? 0 : 1); + if (RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_DECODING_INF_GQA failed: " << GetRetCodeStr(rc); + return false; + } + rc = llm_cuda_engine->Configure(llm::cuda::ENGINE_CONF_DECODING_ATTN_SPLIT_K, g_flag_configure_decoding_attn_split_k); + if (RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_DECODING_ATTN_SPLIT_K failed: " << GetRetCodeStr(rc); + return false; + } + rc = llm_cuda_engine->Configure(llm::cuda::ENGINE_CONF_DECODING_ATTN_TPB, g_flag_specify_decoding_attn_tpb); + if (RC_SUCCESS != rc) { + LOG(ERROR) << "configure ENGINE_CONF_DECODING_ATTN_TPB failed: " << GetRetCodeStr(rc); + return false; + } + + rc = llm_cuda_engine->Configure(llm::cuda::ENGINE_CONF_GRAPH_FUSION, g_flag_disable_graph_fusion ? 0 : 1); if (RC_SUCCESS != rc) { LOG(ERROR) << "configure ENGINE_CONF_GRAPH_FUSION failed: " << GetRetCodeStr(rc); return false; @@ -1151,13 +1203,13 @@ int main(int argc, char* argv[]) { if (!g_flag_pmx_external_data_dir.empty()) { opt.external_data_dir = g_flag_pmx_external_data_dir.c_str(); } - + auto status = builder->LoadModel(g_flag_pmx_model.c_str(), resources, opt); if (status != RC_SUCCESS) { LOG(ERROR) << "PmxRuntimeBuilder LoadModel failed: " << GetRetCodeStr(status); return -1; } - + status = builder->Preprocess(); if (status != RC_SUCCESS) { LOG(ERROR) << "pmx preprocess failed: " << GetRetCodeStr(status);