Skip to content

Commit e237349

Browse files
committed
add max/min cuda kernel and test
1 parent 5e13361 commit e237349

File tree

8 files changed

+337
-43
lines changed

8 files changed

+337
-43
lines changed

src/04kernel/include/kernel/collectors/select.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ namespace refactor::kernel {
1010
Min,
1111
};
1212

13+
std::string_view opName(SelectType type);
14+
1315
struct SelectCollector final : public InfoCollector {
1416
SelectType selectType;
1517

src/04kernel/src/attributes/broadcaster.cc

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,17 @@ namespace refactor::kernel {
8686
}()) {}
8787

8888
void Broadcaster::locate(dim_t k, dim_t ans[]) const noexcept {
    // Fast path: no broadcasting, so every input is addressed by the
    // output's flat index directly.
    if (!needBroadcast()) {
        std::fill_n(ans, inputsCount, k);
        return;
    }
    // Decompose the flat output index `k` one dimension at a time.
    // `strides` is laid out in rows of (inputsCount + 1) values: one
    // stride per input plus the output stride in the last slot, which
    // is the divisor for this dimension.
    long rem = k;
    std::fill_n(ans, inputsCount, 0);
    for (auto i : range0_(strides.size() / (inputsCount + 1))) {
        auto dim = strides.data() + (inputsCount + 1) * i;
        auto div = std::div(rem, dim[inputsCount]);
        for (auto j : range0_(inputsCount)) { ans[j] += dim[j] * div.quot; }
        rem = div.rem;
    }
}
98102

src/04kernel/src/collectors/select.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,27 @@
11
#include "kernel/collectors/select.h"
2+
#include "../kernels/select/cpu_kernel.hh"
3+
#include "../kernels/select/cuda_kernel.hh"
24

35
namespace refactor::kernel {
46

7+
#define REGISTER(T) \
8+
if (auto ptr = T::build(selectType, inputs); ptr) { \
9+
ans.emplace_back(std::move(ptr)); \
10+
}
11+
12+
#define CASE(OP) \
13+
case SelectType::OP: \
14+
return #OP
15+
16+
std::string_view opName(SelectType type) {
17+
switch (type) {
18+
CASE(Max);
19+
CASE(Min);
20+
default:
21+
UNREACHABLE();
22+
}
23+
}
24+
525
SelectCollector::SelectCollector(decltype(_target) target, SelectType type) noexcept
626
: InfoCollector(target), selectType(type) {}
727

@@ -10,8 +30,10 @@ namespace refactor::kernel {
1030
std::vector<KernelBox> ans;
1131
switch (_target) {
1232
case decltype(_target)::Cpu:
33+
REGISTER(SelectCpu)
1334
break;
1435
case decltype(_target)::Nvidia:
36+
REGISTER(SelectCuda)
1537
break;
1638
default:
1739
UNREACHABLEX(void, "Unknown target");

src/04kernel/src/kernels/concat/cuda_kernel.cc

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,6 @@ extern "C" __global__ void kernel(
8181
}
8282
auto segments = ss.str();
8383

84-
ss.str("");
85-
for (auto i : range0_(inputCount)) {
86-
ss << std::endl
87-
<< " reinterpret_cast<char const *>(inputs[" << i << "]), ";
88-
}
89-
auto castInputs = ss.str();
90-
9184
ss.str("");
9285
ss << "Concat_" << info.blockCount << ',' << unit;
9386
for (auto seg : info.segments) {

src/04kernel/src/kernels/select/cpu_kernel.cc

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -50,37 +50,21 @@ namespace refactor::kernel {
5050
UNREACHABLE();
5151
}
5252

53-
if (broadcaster.needBroadcast()) {
54-
return [broadcaster, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
55-
auto output = reinterpret_cast<T *>(outputs[0]);
56-
for (auto i : range0_(broadcaster.outputsCount)) {
57-
std::vector<dim_t> ans(broadcaster.inputsCount);
58-
broadcaster.locate(i, ans.data());
59-
for (auto inputIdx : range0_(inputsNum)) {
60-
auto input = reinterpret_cast<const T *>(inputs[inputIdx]);
61-
if (inputIdx == 0) {
62-
output[i] = input[ans[inputIdx]];
63-
} else {
64-
output[i] = op(output[i], input[ans[inputIdx]]);
65-
}
53+
return [broadcaster, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
54+
auto output = reinterpret_cast<T *>(outputs[0]);
55+
for (auto i : range0_(broadcaster.outputsCount)) {
56+
std::vector<dim_t> ans(broadcaster.inputsCount);
57+
broadcaster.locate(i, ans.data());
58+
for (auto inputIdx : range0_(inputsNum)) {
59+
auto input = reinterpret_cast<const T *>(inputs[inputIdx]);
60+
if (inputIdx == 0) {
61+
output[i] = input[ans[inputIdx]];
62+
} else {
63+
output[i] = op(output[i], input[ans[inputIdx]]);
6664
}
6765
}
68-
};
69-
} else {
70-
return [n = broadcaster.outputsCount, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
71-
auto output = reinterpret_cast<T *>(outputs[0]);
72-
for (auto i : range0_(n)) {
73-
for (auto inputIdx : range0_(inputsNum)) {
74-
auto input = reinterpret_cast<const T *>(inputs[inputIdx]);
75-
if (inputIdx == 0) {
76-
output[i] = input[i];
77-
} else {
78-
output[i] = op(output[i], input[i]);
79-
}
80-
}
81-
};
82-
};
83-
}
66+
}
67+
};
8468
}
8569

8670
auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
#include "cuda_kernel.hh"

#ifdef USE_CUDA
#include "../../generator/nvrtc_repo.h"
#include "kernel/cuda/threads_distributer.cuh"
#endif

namespace refactor::kernel {
    using K = SelectCuda;

    K::SelectCuda(decltype(dataType) dataType_,
                  decltype(selectType) selectType_,
                  decltype(broadcaster) broadcaster_,
                  decltype(inputsNum) inputsNum_) noexcept
        : dataType(dataType_),
          selectType(selectType_),
          broadcaster(broadcaster_),
          inputsNum(inputsNum_) {}

    // Builds the CUDA select kernel, or nullptr when CUDA is disabled.
    // The element type is taken from the first input; the Broadcaster
    // records how every input's shape maps onto the output.
    auto K::build(SelectType selectType_, TensorRefs inputs_) noexcept -> KernelBox {
#ifndef USE_CUDA
        return nullptr;
#endif

        return std::make_unique<K>(inputs_[0].get().dataType, selectType_, Broadcaster(inputs_), inputs_.size());
    }

    auto K::typeId() noexcept -> size_t {
        // The address of a function-local static is unique per kernel type.
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto K::kernelTypeId() const noexcept -> size_t {
        return typeId();
    }
    auto K::description() const noexcept -> std::string_view {
        return "Performing select operation on Nvidia GPU";
    }

#ifdef USE_CUDA

    // NVRTC source template for congruent inputs: each input is read at the
    // output's flat index in a grid-stride loop and folded through fn.
    constexpr static const char *NO_BROADCAST = R"~(
struct Inputs {{
    {dt} const *const addr[{inputsNum}];
}};

__device__ __forceinline__ static {dt} fn({dt} a, {dt} b) {{
    return {op};
}}

extern "C" __global__ void kernel(
    {dt} *__restrict__ output,
    Inputs inputs
) {{
    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
              step = blockDim.x * gridDim.x;
         tid < {n};
         tid += step) {{
        output[tid] = inputs.addr[0][tid];
        for (auto idx = 1; idx < {inputsNum}; ++idx) {{
            output[tid] = fn(inputs.addr[idx][tid], output[tid]);
        }}
    }}
}}
)~";

    // NVRTC source template for broadcasting: `strides.s` holds one row of
    // (inputsNum + 1) values per output dimension — a stride per input plus
    // the output stride in the last slot — mirroring Broadcaster::locate
    // on the host side.
    constexpr static const char *BROADCAST = R"~(
struct Inputs {{
    {dt} const *const addr[{inputsNum}];
}};

struct Strides {{
    unsigned int s[({inputsNum}+1) * {rank}];
}};

__device__ __forceinline__ static {dt} fn({dt} a, {dt} b) {{
    return {op};
}}

extern "C" __global__ void kernel(
    {dt} *__restrict__ output,
    Inputs inputs,
    Strides strides
) {{
    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
              step = blockDim.x * gridDim.x;
         tid < {n};
         tid += step) {{
        auto rem = tid;
        size_t ans[{inputsNum}]{{}};
        for (auto i = 0; i < {rank}; ++i) {{
            auto dim = strides.s + ({inputsNum} + 1) * i;
            auto quot = rem / dim[{inputsNum}];
            for (auto j = 0; j < {inputsNum}; ++j) {{ ans[j] += dim[j] * quot; }}
            rem %= dim[{inputsNum}];
        }}
        output[tid] = inputs.addr[0][ans[0]];
        for (auto idx = 1; idx < {inputsNum}; ++idx) {{
            output[tid] = fn(inputs.addr[idx][ans[idx]], output[tid]);
        }}
    }}
}}
)~";

    // Maps the select type to the binary expression folded over the inputs.
    // `dt` is currently unused — presumably reserved for type-specific
    // expressions (e.g. fmax for floats); confirm before removing.
    constexpr static std::string_view op(SelectType op, [[maybe_unused]] DataType dt) {
        switch (op) {
            case SelectType::Max:
                return "a > b ? a : b";
            case SelectType::Min:
                return "a < b ? a : b";
            default:
                UNREACHABLE();
        }
    }

    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
        using namespace runtime;

        auto postfix = fmt::format("_{}_{}", dataType.name(), opName(selectType));
        auto dt_ = nvrtc::dataType(dataType);
        auto op_ = op(selectType, dataType);
        auto params = cuda::ThreadsDistributer()(broadcaster.outputsCount);
        // Both branches used the identical name; hoisted out of the if/else.
        auto name = fmt::format("select{}", postfix);

        if (!broadcaster.needBroadcast()) {
            auto code = fmt::format(NO_BROADCAST,
                                    fmt::arg("dt", dt_),
                                    fmt::arg("op", op_),
                                    fmt::arg("inputsNum", inputsNum),
                                    fmt::arg("n", params.n));
            return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel")]//
                (Resources &, void *, void const *const *inputs, void *const *outputs) {
                    // Device signature: (T *output, Inputs inputs). `inputs`
                    // already points at the array of input pointers, which
                    // matches the POD `Inputs` struct passed by value.
                    auto output = outputs[0];
                    void *args[]{&output, const_cast<void **>(inputs)};
                    h->launch(params.gridSize, 1, 1,
                              params.blockSize, 1, 1,
                              0, args);
                };
        } else {
            auto rank = broadcaster.strides.size() / (broadcaster.inputsCount + 1);
            auto code = fmt::format(
                BROADCAST,
                fmt::arg("dt", dt_),
                fmt::arg("op", op_),
                fmt::arg("inputsNum", inputsNum),
                fmt::arg("n", params.n),
                fmt::arg("rank", rank));
            // The stride table is copied into the closure so it outlives
            // this kernel object.
            return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
                    strides = broadcaster.strides]//
                (Resources &, void *, void const *const *inputs, void *const *outputs) {
                    // Device signature: (T *output, Inputs inputs, Strides strides);
                    // both structs are passed by value from host memory.
                    void *args[]{const_cast<void **>(outputs), const_cast<void **>(inputs), const_cast<dim_t *>(strides.data())};
                    h->launch(params.gridSize, 1, 1,
                              params.blockSize, 1, 1,
                              0, args);
                };
        }
    }

#endif

}// namespace refactor::kernel
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#ifndef KERNEL_SELECT_CUDA_KERNEL_HH
#define KERNEL_SELECT_CUDA_KERNEL_HH

#include "kernel/attributes/broadcaster.h"
#include "kernel/collectors/select.h"
#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

    // Element-wise Max/Min selection over several input tensors on Nvidia
    // GPUs; the device code is compiled at runtime (see cuda_kernel.cc).
    struct SelectCuda final : public Kernel {
        DataType dataType;      // element type, taken from the first input
        SelectType selectType;  // Max or Min
        Broadcaster broadcaster;// maps each input's shape onto the output
        size_t inputsNum;       // number of input tensors

        SelectCuda(decltype(dataType), decltype(selectType), decltype(broadcaster), decltype(inputsNum)) noexcept;

        // Returns nullptr when the build does not enable CUDA.
        static KernelBox build(SelectType, TensorRefs) noexcept;
        static size_t typeId() noexcept;

        size_t kernelTypeId() const noexcept final;
        std::string_view description() const noexcept final;
#ifdef USE_CUDA
        RoutineWorkspace lower(Resources &) const noexcept final;
#endif
    };

}// namespace refactor::kernel

#endif// KERNEL_SELECT_CUDA_KERNEL_HH

0 commit comments

Comments
 (0)