#include "cuda_kernel.hh"

#ifdef USE_CUDA
#include "../../generator/nvrtc_repo.h"
#include "kernel/cuda/threads_distributer.cuh"
#endif

namespace refactor::kernel {
    using K = SelectCuda;

    K::SelectCuda(decltype(dataType) dataType_,
                  decltype(selectType) selectType_,
                  decltype(broadcaster) broadcaster_,
                  decltype(inputsNum) inputsNum_) noexcept
        : dataType(dataType_),
          selectType(selectType_),
          broadcaster(broadcaster_),
          inputsNum(inputsNum_) {}

    auto K::build(SelectType selectType_, TensorRefs inputs_) noexcept -> KernelBox {
#ifndef USE_CUDA
        return nullptr;
#endif

        return std::make_unique<K>(inputs_[0].get().dataType, selectType_, Broadcaster(inputs_), inputs_.size());
    }

    auto K::typeId() noexcept -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto K::kernelTypeId() const noexcept -> size_t {
        return typeId();
    }
    auto K::description() const noexcept -> std::string_view {
        return "Performing select operation on Nvidia GPU";
    }

#ifdef USE_CUDA

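    // The device code below is never compiled with the host: the two strings
    // are fmt templates that `lower` turns into CUDA C++ and JIT-compiles
    // through NVRTC. `{dt}`, `{op}`, `{inputsNum}`, `{n}` and `{rank}` are
    // substituted by fmt::format; doubled braces (`{{`, `}}`) escape the
    // literal braces of the generated code.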
    constexpr static const char *NO_BROADCAST = R"~(
struct Inputs {{
    {dt} const *const addr[{inputsNum}];
}};

__device__ __forceinline__ static {dt} fn({dt} a, {dt} b) {{
    return {op};
}}

extern "C" __global__ void kernel(
    {dt} *__restrict__ output,
    Inputs inputs
) {{
    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
              step = blockDim.x * gridDim.x;
         tid < {n};
         tid += step) {{
        output[tid] = inputs.addr[0][tid];
        for (auto idx = 1; idx < {inputsNum}; ++idx) {{
            output[tid] = fn(inputs.addr[idx][tid], output[tid]);
        }}
    }}
}}
)~";
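
    // For illustration only: with {dt} = "float", {op} = "a > b ? a : b"
    // (SelectType::Max), {inputsNum} = 2 and {n} = 1024, fmt::format expands
    // NO_BROADCAST into roughly this CUDA source:
    //
    //     struct Inputs {
    //         float const *const addr[2];
    //     };
    //
    //     __device__ __forceinline__ static float fn(float a, float b) {
    //         return a > b ? a : b;
    //     }
    //
    //     extern "C" __global__ void kernel(
    //         float *__restrict__ output,
    //         Inputs inputs
    //     ) {
    //         for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
    //                   step = blockDim.x * gridDim.x;
    //              tid < 1024;
    //              tid += step) {
    //             output[tid] = inputs.addr[0][tid];
    //             for (auto idx = 1; idx < 2; ++idx) {
    //                 output[tid] = fn(inputs.addr[idx][tid], output[tid]);
    //             }
    //         }
    //     }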

    constexpr static const char *BROADCAST = R"~(
struct Inputs {{
    {dt} const *const addr[{inputsNum}];
}};

struct Strides {{
    unsigned int s[({inputsNum} + 1) * {rank}];
}};

__device__ __forceinline__ static {dt} fn({dt} a, {dt} b) {{
    return {op};
}}

extern "C" __global__ void kernel(
    {dt} *__restrict__ output,
    Inputs inputs,
    Strides strides
) {{
    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
              step = blockDim.x * gridDim.x;
         tid < {n};
         tid += step) {{
        auto rem = tid;
        size_t ans[{inputsNum}]{{}};
        for (auto i = 0; i < {rank}; ++i) {{
            auto dim = strides.s + ({inputsNum} + 1) * i;
            auto quot = rem / dim[{inputsNum}];
            for (auto j = 0; j < {inputsNum}; ++j) {{ ans[j] += dim[j] * quot; }}
            rem %= dim[{inputsNum}];
        }}
        output[tid] = inputs.addr[0][ans[0]];
        for (auto idx = 1; idx < {inputsNum}; ++idx) {{
            output[tid] = fn(inputs.addr[idx][ans[idx]], output[tid]);
        }}
    }}
}}
)~";
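
    // How the index math works, assuming Broadcaster stores, for each of the
    // `rank` output dimensions, the stride of every input followed by the
    // stride of the output (the layout the kernel reads: `dim[j]` for input j,
    // `dim[{inputsNum}]` for the output). Dimensions an input broadcasts along
    // carry stride 0. Hypothetical example: inputs of shapes [2, 3] and [1, 3]
    // broadcast to [2, 3], so rank = 2 and
    //
    //     strides.s = {3, 0, 3,   // dim 0: in0, in1, out
    //                  1, 1, 1};  // dim 1: in0, in1, out
    //
    // For tid = 4 (output coordinate [1, 1]):
    //     dim 0: quot = 4 / 3 = 1, ans = {3, 0}, rem = 1
    //     dim 1: quot = 1 / 1 = 1, ans = {4, 1}, rem = 0
    // so the kernel combines inputs.addr[0][4] with inputs.addr[1][1].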

    // The DataType parameter is currently unused, so it is left unnamed to
    // avoid an unused-parameter warning.
    constexpr static std::string_view op(SelectType op, DataType) {
        switch (op) {
            case SelectType::Max:
                return "a > b ? a : b";
            case SelectType::Min:
                return "a < b ? a : b";
            default:
                UNREACHABLE();
        }
    }
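
    // Note that `op` only has to combine two operands: both kernel templates
    // fold `fn` left to right across all inputs, so supporting a new
    // SelectType should only require one more case returning a binary
    // expression over `a` and `b`.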

    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
        using namespace runtime;

        auto postfix = fmt::format("_{}_{}", dataType.name(), opName(selectType));
        auto dt_ = nvrtc::dataType(dataType);
        auto op_ = op(selectType, dataType);
        auto params = cuda::ThreadsDistributer()(broadcaster.outputsCount);
        auto name = fmt::format("select{}", postfix);

        if (!broadcaster.needBroadcast()) {
            auto code = fmt::format(NO_BROADCAST,
                                    fmt::arg("dt", dt_),
                                    fmt::arg("op", op_),
                                    fmt::arg("inputsNum", inputsNum),
                                    fmt::arg("n", params.n));
            return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel")]//
                (Resources &, void *, void const *const *inputs, void *const *outputs) {
                    // Each entry of `args` is the address of one kernel argument:
                    // `&output` for the output pointer, and `inputs` itself for the
                    // `Inputs` struct, whose layout is exactly the pointer array.
                    auto output = outputs[0];
                    void *args[]{&output, const_cast<void **>(inputs)};
                    h->launch(params.gridSize, 1, 1,
                              params.blockSize, 1, 1,
                              0, args);
                };
        } else {
            // Each dimension stores one stride per input plus one for the output.
            auto rank = broadcaster.strides.size() / (broadcaster.inputsCount + 1);
            auto code = fmt::format(
                BROADCAST,
                fmt::arg("dt", dt_),
                fmt::arg("op", op_),
                fmt::arg("inputsNum", inputsNum),
                fmt::arg("n", params.n),
                fmt::arg("rank", rank));
            return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
                    strides = broadcaster.strides]//
                (Resources &, void *, void const *const *inputs, void *const *outputs) {
                    // Same marshalling as above; the captured stride vector's data
                    // doubles as the `Strides` struct passed by value.
                    void *args[]{const_cast<void **>(outputs), const_cast<void **>(inputs), const_cast<dim_t *>(strides.data())};
                    h->launch(params.gridSize, 1, 1,
                              params.blockSize, 1, 1,
                              0, args);
                };
        }
    }
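
    // A minimal usage sketch (hypothetical tensor setup; device buffers and
    // the Resources object are elided) of how a built kernel gets lowered and
    // launched:
    //
    //     auto kernel = SelectCuda::build(SelectType::Max, inputs);
    //     auto routine = kernel->lower(res).routine;
    //     // `ins`/`outs` hold one device pointer per tensor.
    //     routine(res, nullptr, ins.data(), outs.data());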

#endif

}// namespace refactor::kernel