Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/infiniop/ops/softplus/kunlun/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#ifndef __SOFTPLUS_KUNLUN_KERNEL_H__
#define __SOFTPLUS_KUNLUN_KERNEL_H__

namespace op::softplus::kunlun {

typedef struct SoftplusOp {
public:
static constexpr int num_inputs = 1;
template <typename T>
inline __device__ T operator()(const T *inputs) const {
if constexpr (std::is_same_v<T, half>) {
float xf = __half2float(inputs[0]);
float out = (xf > 20.0f) ? xf : log(1 + exp(xf));
return __float2half(out);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
float xf = __bfloat162float(inputs[0]);
float out = (xf > 20.0f) ? xf : log(1 + exp(xf));
return __float2bfloat16(out);
} else {
float xf = inputs[0];
return (xf > 20.0f) ? xf : log(1 + exp(xf));
}
}
} SoftplusOp;

} // namespace op::softplus::kunlun

#endif // __SOFTPLUS_KUNLUN_KERNEL_H__
8 changes: 8 additions & 0 deletions src/infiniop/ops/softplus/kunlun/softplus_kunlun.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __SOFTPLUS_KUNLUN_H__
#define __SOFTPLUS_KUNLUN_H__

#include "../../../elementwise/kunlun/elementwise_kunlun_api.h"

ELEMENTWISE_DESCRIPTOR(softplus, kunlun)

#endif
56 changes: 56 additions & 0 deletions src/infiniop/ops/softplus/kunlun/softplus_kunlun.xpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
#include "kernel.h"
#include "softplus_kunlun.h"

namespace op::softplus::kunlun {

Descriptor::~Descriptor() = default;

infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
auto dtype = out_desc->dtype();

const auto &x_desc = input_desc_vec.at(0);
const auto &y_shape = out_desc->shape();
const auto &x_shape = x_desc->shape();

CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);

CHECK_SAME_SHAPE(y_shape, x_shape);

// create KUNLUN elementwise descriptor
CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {

if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}

switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<8, kunlun::SoftplusOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<8, kunlun::SoftplusOp, bfloat16_t>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<8, kunlun::SoftplusOp, float>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

return INFINI_STATUS_SUCCESS;
}
} // namespace op::softplus::kunlun
19 changes: 15 additions & 4 deletions src/infiniop/ops/softplus/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
#ifdef ENABLE_METAX_API
#include "metax/softplus_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/softplus_kunlun.h"
#endif

__C infiniStatus_t infiniopCreateSoftplusDescriptor(
infiniopHandle_t handle,
Expand Down Expand Up @@ -43,7 +46,9 @@ __C infiniStatus_t infiniopCreateSoftplusDescriptor(
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif

#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand Down Expand Up @@ -74,7 +79,9 @@ __C infiniStatus_t infiniopGetSoftplusWorkspaceSize(infiniopSoftplusDescriptor_t
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif

#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand Down Expand Up @@ -113,7 +120,9 @@ __C infiniStatus_t infiniopSoftplus(
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif

#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand Down Expand Up @@ -146,7 +155,9 @@ infiniopDestroySoftplusDescriptor(infiniopSoftplusDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif

#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand Down