From f817c394cb86fdb08b525e2acf246b0a496231ce Mon Sep 17 00:00:00 2001
From: zhangyue
Date: Tue, 9 Dec 2025 15:12:39 +0800
Subject: [PATCH] issue/744: kunlun softplus

issue/744: softplus

issue/744: kunlun softplus

issue/744: delete F64
---
Notes (free-form area; ignored by git am):

What this patch does
  * kernel.h: adds the SoftplusOp functor (num_inputs = 1). It evaluates
    log(1 + exp(x)) in float and returns x unchanged when x > 20.0f. For
    half and bfloat16_t the input is converted to float first and the
    result is converted back.
  * softplus_kunlun.h: declares the descriptor through
    ELEMENTWISE_DESCRIPTOR(softplus, kunlun).
  * softplus_kunlun.xpu: Descriptor::create accepts F16, F32 and BF16 and
    requires the output and input shapes to match. Descriptor::calculate
    returns INFINI_STATUS_INSUFFICIENT_WORKSPACE when workspace_size is
    smaller than _workspace_size, then dispatches on _dtype.
  * operator.cc: wires the kunlun backend into the create / workspace-size /
    calculate / destroy switches under ENABLE_KUNLUN_API. Each hunk also
    drops the blank line that preceded `default:`.

NOTE(review): this copy of the patch had its line breaks collapsed and its
angle-bracket spans stripped. The items below were reconstructed and should
be confirmed against the upstream commit (the blob hashes on the index
lines identify the exact contents):
  * `template <typename T>` and the two `std::is_same_v<T, ...>` tests in
    kernel.h (types taken from the calculate<> instantiations).
  * `std::vector<infiniopTensorDescriptor_t>`, `std::vector<const void *>`
    and `reinterpret_cast<device::kunlun::Handle *>` in softplus_kunlun.xpu.
  * Leading indentation on every line, including operator.cc context lines.
  * The author e-mail address on the From: header is still missing.

NOTE(review): in Descriptor::calculate every switch label returns, so the
trailing `return INFINI_STATUS_SUCCESS;` is unreachable.

 src/infiniop/ops/softplus/kunlun/kernel.h     | 28 ++++++++++
 .../ops/softplus/kunlun/softplus_kunlun.h     |  8 +++
 .../ops/softplus/kunlun/softplus_kunlun.xpu   | 56 +++++++++++++++++++
 src/infiniop/ops/softplus/operator.cc         | 19 +++++--
 4 files changed, 107 insertions(+), 4 deletions(-)
 create mode 100644 src/infiniop/ops/softplus/kunlun/kernel.h
 create mode 100644 src/infiniop/ops/softplus/kunlun/softplus_kunlun.h
 create mode 100644 src/infiniop/ops/softplus/kunlun/softplus_kunlun.xpu

diff --git a/src/infiniop/ops/softplus/kunlun/kernel.h b/src/infiniop/ops/softplus/kunlun/kernel.h
new file mode 100644
index 000000000..ba9dcf196
--- /dev/null
+++ b/src/infiniop/ops/softplus/kunlun/kernel.h
@@ -0,0 +1,28 @@
+#ifndef __SOFTPLUS_KUNLUN_KERNEL_H__
+#define __SOFTPLUS_KUNLUN_KERNEL_H__
+
+namespace op::softplus::kunlun {
+
+typedef struct SoftplusOp {
+public:
+    static constexpr int num_inputs = 1;
+    template <typename T>
+    inline __device__ T operator()(const T *inputs) const {
+        if constexpr (std::is_same_v<T, half>) {
+            float xf = __half2float(inputs[0]);
+            float out = (xf > 20.0f) ? xf : log(1 + exp(xf));
+            return __float2half(out);
+        } else if constexpr (std::is_same_v<T, bfloat16_t>) {
+            float xf = __bfloat162float(inputs[0]);
+            float out = (xf > 20.0f) ? xf : log(1 + exp(xf));
+            return __float2bfloat16(out);
+        } else {
+            float xf = inputs[0];
+            return (xf > 20.0f) ? xf : log(1 + exp(xf));
+        }
+    }
+} SoftplusOp;
+
+} // namespace op::softplus::kunlun
+
+#endif // __SOFTPLUS_KUNLUN_KERNEL_H__
diff --git a/src/infiniop/ops/softplus/kunlun/softplus_kunlun.h b/src/infiniop/ops/softplus/kunlun/softplus_kunlun.h
new file mode 100644
index 000000000..1742519e4
--- /dev/null
+++ b/src/infiniop/ops/softplus/kunlun/softplus_kunlun.h
@@ -0,0 +1,8 @@
+#ifndef __SOFTPLUS_KUNLUN_H__
+#define __SOFTPLUS_KUNLUN_H__
+
+#include "../../../elementwise/kunlun/elementwise_kunlun_api.h"
+
+ELEMENTWISE_DESCRIPTOR(softplus, kunlun)
+
+#endif
diff --git a/src/infiniop/ops/softplus/kunlun/softplus_kunlun.xpu b/src/infiniop/ops/softplus/kunlun/softplus_kunlun.xpu
new file mode 100644
index 000000000..9a0f5618b
--- /dev/null
+++ b/src/infiniop/ops/softplus/kunlun/softplus_kunlun.xpu
@@ -0,0 +1,56 @@
+#include "../../../elementwise/kunlun/elementwise_kunlun.h"
+#include "kernel.h"
+#include "softplus_kunlun.h"
+
+namespace op::softplus::kunlun {
+
+Descriptor::~Descriptor() = default;
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
+
+    auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
+    auto dtype = out_desc->dtype();
+
+    const auto &x_desc = input_desc_vec.at(0);
+    const auto &y_shape = out_desc->shape();
+    const auto &x_shape = x_desc->shape();
+
+    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
+
+    CHECK_SAME_SHAPE(y_shape, x_shape);
+
+    // create KUNLUN elementwise descriptor
+    CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void *output,
+    std::vector<const void *> inputs,
+    void *stream) const {
+
+    if (workspace_size < _workspace_size) {
+        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
+    }
+
+    switch (_dtype) {
+    case INFINI_DTYPE_F16:
+        return _device_info->calculate<8, kunlun::SoftplusOp, half>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_BF16:
+        return _device_info->calculate<8, kunlun::SoftplusOp, bfloat16_t>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_F32:
+        return _device_info->calculate<8, kunlun::SoftplusOp, float>(_info, workspace, output, inputs, stream);
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    return INFINI_STATUS_SUCCESS;
+}
+} // namespace op::softplus::kunlun
diff --git a/src/infiniop/ops/softplus/operator.cc b/src/infiniop/ops/softplus/operator.cc
index e0fc980c6..7d09bad6c 100644
--- a/src/infiniop/ops/softplus/operator.cc
+++ b/src/infiniop/ops/softplus/operator.cc
@@ -11,6 +11,9 @@
 #ifdef ENABLE_METAX_API
 #include "metax/softplus_metax.h"
 #endif
+#ifdef ENABLE_KUNLUN_API
+#include "kunlun/softplus_kunlun.h"
+#endif
 
 __C infiniStatus_t infiniopCreateSoftplusDescriptor(
     infiniopHandle_t handle,
@@ -43,7 +46,9 @@ __C infiniStatus_t infiniopCreateSoftplusDescriptor(
 #ifdef ENABLE_METAX_API
         CREATE(INFINI_DEVICE_METAX, metax);
 #endif
-
+#ifdef ENABLE_KUNLUN_API
+        CREATE(INFINI_DEVICE_KUNLUN, kunlun);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -74,7 +79,9 @@ __C infiniStatus_t infiniopGetSoftplusWorkspaceSize(infiniopSoftplusDescriptor_t
 #ifdef ENABLE_METAX_API
         GET(INFINI_DEVICE_METAX, metax);
 #endif
-
+#ifdef ENABLE_KUNLUN_API
+        GET(INFINI_DEVICE_KUNLUN, kunlun);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -113,7 +120,9 @@ __C infiniStatus_t infiniopSoftplus(
 #ifdef ENABLE_METAX_API
         CALCULATE(INFINI_DEVICE_METAX, metax);
 #endif
-
+#ifdef ENABLE_KUNLUN_API
+        CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -146,7 +155,9 @@ infiniopDestroySoftplusDescriptor(infiniopSoftplusDescriptor_t desc) {
 #ifdef ENABLE_METAX_API
         DELETE(INFINI_DEVICE_METAX, metax);
 #endif
-
+#ifdef ENABLE_KUNLUN_API
+        DELETE(INFINI_DEVICE_KUNLUN, kunlun);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }