Skip to content

Commit f817c39

Browse files
committed
issue/744: kunlun softplus
issue/744: softplus issue/744: kunlun softplus issue/744: delete F64
1 parent b38d5d1 commit f817c39

File tree

4 files changed

+107
-4
lines changed

4 files changed

+107
-4
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#ifndef __SOFTPLUS_KUNLUN_KERNEL_H__
2+
#define __SOFTPLUS_KUNLUN_KERNEL_H__
3+
4+
namespace op::softplus::kunlun {
5+
6+
typedef struct SoftplusOp {
7+
public:
8+
static constexpr int num_inputs = 1;
9+
template <typename T>
10+
inline __device__ T operator()(const T *inputs) const {
11+
if constexpr (std::is_same_v<T, half>) {
12+
float xf = __half2float(inputs[0]);
13+
float out = (xf > 20.0f) ? xf : log(1 + exp(xf));
14+
return __float2half(out);
15+
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
16+
float xf = __bfloat162float(inputs[0]);
17+
float out = (xf > 20.0f) ? xf : log(1 + exp(xf));
18+
return __float2bfloat16(out);
19+
} else {
20+
float xf = inputs[0];
21+
return (xf > 20.0f) ? xf : log(1 + exp(xf));
22+
}
23+
}
24+
} SoftplusOp;
25+
26+
} // namespace op::softplus::kunlun
27+
28+
#endif // __SOFTPLUS_KUNLUN_KERNEL_H__
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __SOFTPLUS_KUNLUN_H__
2+
#define __SOFTPLUS_KUNLUN_H__
3+
4+
#include "../../../elementwise/kunlun/elementwise_kunlun_api.h"
5+
6+
// Expands the shared elementwise macro for (softplus, kunlun); presumably this
// declares op::softplus::kunlun::Descriptor, whose members are defined in the
// accompanying implementation file -- confirm against elementwise_kunlun_api.h.
ELEMENTWISE_DESCRIPTOR(softplus, kunlun)
7+
8+
#endif
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
2+
#include "kernel.h"
3+
#include "softplus_kunlun.h"
4+
5+
namespace op::softplus::kunlun {
6+
7+
// Out-of-line defaulted destructor for the softplus Kunlun descriptor.
Descriptor::~Descriptor() = default;
8+
9+
/**
 * Builds a softplus descriptor for the Kunlun backend.
 *
 * @param handle_         Opaque handle, reinterpreted as a Kunlun device handle.
 * @param desc_ptr        Out-parameter for the new descriptor (populated by the
 *                        elementwise creation macro -- see its definition).
 * @param out_desc        Descriptor of the output tensor; its dtype selects the kernel.
 * @param input_desc_vec  Input tensor descriptors; element 0 is the single operand.
 * @return INFINI_STATUS_SUCCESS when every check and the creation macro pass;
 *         the CHECK_* macros are expected to return early with an error status otherwise.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    // The single operand and the two shapes that must agree.
    const auto &input_desc = input_desc_vec.at(0);
    const auto &output_shape = out_desc->shape();
    const auto &input_shape = input_desc->shape();

    // Only half, float and bfloat16 are accepted.
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);

    // Softplus is strictly element-wise: output and input shapes must match.
    CHECK_SAME_SHAPE(output_shape, input_shape);

    // Delegate construction to the shared Kunlun elementwise machinery.
    CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}
31+
32+
/**
 * Runs softplus on the Kunlun device for the dtype recorded in this descriptor.
 *
 * @param workspace       Scratch buffer forwarded to the elementwise implementation.
 * @param workspace_size  Size of `workspace`; must be at least _workspace_size.
 * @param output          Destination buffer.
 * @param inputs          Input buffers forwarded to the elementwise implementation.
 * @param stream          Opaque stream forwarded to the elementwise implementation.
 * @return INFINI_STATUS_INSUFFICIENT_WORKSPACE if the workspace is too small,
 *         INFINI_STATUS_BAD_TENSOR_DTYPE for an unhandled dtype, otherwise the
 *         status reported by the elementwise implementation.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    // First template argument passed to the elementwise calculate() for every dtype.
    constexpr int kBlockSize = 8;

    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Dispatch on the element type; each supported case returns directly.
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<kBlockSize, kunlun::SoftplusOp, half>(
            _info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<kBlockSize, kunlun::SoftplusOp, bfloat16_t>(
            _info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<kBlockSize, kunlun::SoftplusOp, float>(
            _info, workspace, output, inputs, stream);
    default:
        break;
    }

    // Reached only for dtypes without a case above.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
56+
} // namespace op::softplus::kunlun

src/infiniop/ops/softplus/operator.cc

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
#ifdef ENABLE_METAX_API
1212
#include "metax/softplus_metax.h"
1313
#endif
14+
#ifdef ENABLE_KUNLUN_API
15+
#include "kunlun/softplus_kunlun.h"
16+
#endif
1417

1518
__C infiniStatus_t infiniopCreateSoftplusDescriptor(
1619
infiniopHandle_t handle,
@@ -43,7 +46,9 @@ __C infiniStatus_t infiniopCreateSoftplusDescriptor(
4346
#ifdef ENABLE_METAX_API
4447
CREATE(INFINI_DEVICE_METAX, metax);
4548
#endif
46-
49+
#ifdef ENABLE_KUNLUN_API
50+
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
51+
#endif
4752
default:
4853
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
4954
}
@@ -74,7 +79,9 @@ __C infiniStatus_t infiniopGetSoftplusWorkspaceSize(infiniopSoftplusDescriptor_t
7479
#ifdef ENABLE_METAX_API
7580
GET(INFINI_DEVICE_METAX, metax);
7681
#endif
77-
82+
#ifdef ENABLE_KUNLUN_API
83+
GET(INFINI_DEVICE_KUNLUN, kunlun);
84+
#endif
7885
default:
7986
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
8087
}
@@ -113,7 +120,9 @@ __C infiniStatus_t infiniopSoftplus(
113120
#ifdef ENABLE_METAX_API
114121
CALCULATE(INFINI_DEVICE_METAX, metax);
115122
#endif
116-
123+
#ifdef ENABLE_KUNLUN_API
124+
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
125+
#endif
117126
default:
118127
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
119128
}
@@ -146,7 +155,9 @@ infiniopDestroySoftplusDescriptor(infiniopSoftplusDescriptor_t desc) {
146155
#ifdef ENABLE_METAX_API
147156
DELETE(INFINI_DEVICE_METAX, metax);
148157
#endif
149-
158+
#ifdef ENABLE_KUNLUN_API
159+
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
160+
#endif
150161
default:
151162
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
152163
}

0 commit comments

Comments
 (0)