diff --git a/.github/workflows/build_kernel.yaml b/.github/workflows/build_kernel.yaml
index f2e92d81..e4b859c1 100644
--- a/.github/workflows/build_kernel.yaml
+++ b/.github/workflows/build_kernel.yaml
@@ -47,6 +47,9 @@ jobs:
       - name: Build relu kernel (specific Torch version)
         run: ( cd examples/relu-specific-torch && nix build . )
 
+      - name: Build relu kernel (compiler flags)
+        run: ( cd examples/relu-compiler-flags && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
+
       - name: Test that we can build a test shell (e.g. that gcc corresponds to CUDA-required)
         run: ( cd examples/relu && nix build .#devShells.x86_64-linux.test )
 
diff --git a/.github/workflows/build_kernel_xpu.yaml b/.github/workflows/build_kernel_xpu.yaml
index 49f7d773..ebf40ecf 100644
--- a/.github/workflows/build_kernel_xpu.yaml
+++ b/.github/workflows/build_kernel_xpu.yaml
@@ -27,3 +27,6 @@ jobs:
       # kernels. Also run tests once we have a XPU runner.
       - name: Build relu kernel
         run: ( cd examples/relu && nix build .\#redistributable.torch29-cxx11-xpu20252-x86_64-linux -L )
+
+      - name: Build relu kernel (compiler flags)
+        run: ( cd examples/relu-compiler-flags && nix build .\#redistributable.torch29-cxx11-xpu20252-x86_64-linux )
diff --git a/examples/relu-compiler-flags/build.toml b/examples/relu-compiler-flags/build.toml
new file mode 100644
index 00000000..e493b3e8
--- /dev/null
+++ b/examples/relu-compiler-flags/build.toml
@@ -0,0 +1,37 @@
+[general]
+name = "relu"
+universal = false
+
+[torch]
+src = [
+  "torch-ext/torch_binding.cpp",
+  "torch-ext/torch_binding.h",
+]
+
+[kernel.activation]
+backend = "cuda"
+depends = ["torch"]
+src = ["relu_cuda/relu.cu"]
+cuda-flags = [ "-DWHO_AM_I_IF_NOT_THE_CANARY" ]
+
+[kernel.activation_rocm]
+backend = "rocm"
+rocm-archs = [
+  "gfx906",
+  "gfx908",
+  "gfx90a",
+  "gfx940",
+  "gfx941",
+  "gfx942",
+  "gfx1030",
+  "gfx1100",
+  "gfx1101",
+]
+depends = ["torch"]
+src = ["relu_cuda/relu.cu"]
+
+[kernel.activation_xpu]
+backend = "xpu"
+depends = ["torch"]
+src = ["relu_xpu/relu.cpp"]
+sycl-flags = [ "-DWHO_AM_I_IF_NOT_THE_CANARY" ]
diff --git a/examples/relu-compiler-flags/flake.nix b/examples/relu-compiler-flags/flake.nix
new file mode 100644
index 00000000..bfe8717d
--- /dev/null
+++ b/examples/relu-compiler-flags/flake.nix
@@ -0,0 +1,17 @@
+{
+  description = "Flake for ReLU kernel";
+
+  inputs = {
+    kernel-builder.url = "path:../..";
+  };
+
+  outputs =
+    {
+      self,
+      kernel-builder,
+    }:
+    kernel-builder.lib.genFlakeOutputs {
+      inherit self;
+      path = ./.;
+    };
+}
diff --git a/examples/relu-compiler-flags/relu_cuda/relu.cu b/examples/relu-compiler-flags/relu_cuda/relu.cu
new file mode 100644
index 00000000..6dea98b8
--- /dev/null
+++ b/examples/relu-compiler-flags/relu_cuda/relu.cu
@@ -0,0 +1,61 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>
+
+#include <algorithm>
+
+// Canary: build.toml passes -DWHO_AM_I_IF_NOT_THE_CANARY via `cuda-flags`, so
+// compilation fails loudly if the builder drops per-kernel compiler flags.
+#ifndef WHO_AM_I_IF_NOT_THE_CANARY
+#error "Kernel flags are not correctly handled."
+#endif
+
+// Element-wise ReLU: out[i] = max(input[i], 0). Block `blockIdx.x` handles one
+// row of `d` contiguous floats; its threads step over that row by blockDim.x.
+__global__ void relu_kernel(float *__restrict__ out,
+                            float const *__restrict__ input, const int d) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    auto x = input[token_idx * d + idx];
+    out[token_idx * d + idx] = x > 0.0f ? x : 0.0f;
+  }
+}
+
+// Writes ReLU(input) into `out`. Both tensors must be contiguous float32 CUDA
+// tensors with the same shape and device; violations raise via TORCH_CHECK.
+void relu(torch::Tensor &out, torch::Tensor const &input) {
+  TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
+  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
+  // The kernel indexes `out` linearly, so it has to be contiguous as well.
+  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
+  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
+                  out.scalar_type() == at::ScalarType::Float,
+              "relu_kernel only supports float32");
+
+  TORCH_CHECK(input.sizes() == out.sizes(),
+              "Tensors must have the same shape. Got input shape: ",
+              input.sizes(), " and output shape: ", out.sizes());
+
+  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
+              "Tensors must have the same data type. Got input dtype: ",
+              input.scalar_type(), " and output dtype: ", out.scalar_type());
+
+  TORCH_CHECK(input.device() == out.device(),
+              "Tensors must be on the same device. Got input device: ",
+              input.device(), " and output device: ", out.device());
+
+  // Nothing to launch for empty tensors; this also keeps `numel() / d` below
+  // from dividing by zero when the last dimension is empty.
+  if (input.numel() == 0) {
+    return;
+  }
+
+  int d = input.size(-1);
+  int64_t num_tokens = input.numel() / d;
+  dim3 grid(num_tokens);
+  dim3 block(std::min(d, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
+                                          input.data_ptr<float>(), d);
+}
diff --git a/examples/relu-compiler-flags/relu_xpu/relu.cpp b/examples/relu-compiler-flags/relu_xpu/relu.cpp
new file mode 100644
index 00000000..3be30957
--- /dev/null
+++ b/examples/relu-compiler-flags/relu_xpu/relu.cpp
@@ -0,0 +1,52 @@
+#include <torch/torch.h>
+#include <sycl/sycl.hpp>
+
+using namespace sycl;
+
+// Canary: build.toml passes -DWHO_AM_I_IF_NOT_THE_CANARY via `sycl-flags`, so
+// compilation fails loudly if the builder drops per-kernel compiler flags.
+#ifndef WHO_AM_I_IF_NOT_THE_CANARY
+#error "Kernel flags are not correctly handled."
+#endif
+
+// Applies ReLU element-wise over the flat float buffers of `input` and
+// `output`, blocking until the SYCL kernel has finished.
+void relu_xpu_impl(torch::Tensor& output, const torch::Tensor& input) {
+  // Create SYCL queue directly
+  sycl::queue queue;
+
+  auto input_ptr = input.data_ptr<float>();
+  auto output_ptr = output.data_ptr<float>();
+  auto numel = input.numel();
+
+  // Launch SYCL kernel
+  queue.parallel_for(range<1>(numel), [=](id<1> idx) {
+    auto i = idx[0];
+    output_ptr[i] = input_ptr[i] > 0.0f ? input_ptr[i] : 0.0f;
+  }).wait();
+}
+
+// Writes ReLU(input) into `out`. Both tensors must be contiguous float32 XPU
+// tensors with the same shape and device; violations raise via TORCH_CHECK.
+void relu(torch::Tensor& out, const torch::Tensor& input) {
+  TORCH_CHECK(input.device().is_xpu(), "input must be a XPU tensor");
+  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
+  // The kernel indexes `out` linearly, so it has to be contiguous as well.
+  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
+  TORCH_CHECK(input.scalar_type() == torch::kFloat,
+              "Unsupported data type: ", input.scalar_type());
+
+  TORCH_CHECK(input.sizes() == out.sizes(),
+              "Tensors must have the same shape. Got input shape: ",
+              input.sizes(), " and output shape: ", out.sizes());
+
+  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
+              "Tensors must have the same data type. Got input dtype: ",
+              input.scalar_type(), " and output dtype: ", out.scalar_type());
+
+  TORCH_CHECK(input.device() == out.device(),
+              "Tensors must be on the same device. Got input device: ",
+              input.device(), " and output device: ", out.device());
+
+  relu_xpu_impl(out, input);
+}
diff --git a/examples/relu-compiler-flags/torch-ext/relu/__init__.py b/examples/relu-compiler-flags/torch-ext/relu/__init__.py
new file mode 100644
index 00000000..8050dfd7
--- /dev/null
+++ b/examples/relu-compiler-flags/torch-ext/relu/__init__.py
@@ -0,0 +1,13 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+
+def relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
+    """Apply ReLU to ``x``, writing into ``out`` (allocated if not given)."""
+    if out is None:
+        out = torch.empty_like(x)
+    ops.relu(out, x)
+    return out
\ No newline at end of file
diff --git a/examples/relu-compiler-flags/torch-ext/torch_binding.cpp b/examples/relu-compiler-flags/torch-ext/torch_binding.cpp
new file mode 100644
index 00000000..8b50483a
--- /dev/null
+++ b/examples/relu-compiler-flags/torch-ext/torch_binding.cpp
@@ -0,0 +1,18 @@
+#include <torch/library.h>
+
+#include "registration.h"
+#include "torch_binding.h"
+
+// Define the `relu` op and bind it to the backend selected at compile time.
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def("relu(Tensor! out, Tensor input) -> ()");
+#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
+  ops.impl("relu", torch::kCUDA, &relu);
+#elif defined(METAL_KERNEL)
+  ops.impl("relu", torch::kMPS, &relu);
+#elif defined(XPU_KERNEL)
+  ops.impl("relu", torch::kXPU, &relu);
+#endif
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
\ No newline at end of file
diff --git a/examples/relu-compiler-flags/torch-ext/torch_binding.h b/examples/relu-compiler-flags/torch-ext/torch_binding.h
new file mode 100644
index 00000000..3bcf2904
--- /dev/null
+++ b/examples/relu-compiler-flags/torch-ext/torch_binding.h
@@ -0,0 +1,6 @@
+#pragma once
+
+#include <torch/torch.h>
+
+// Writes ReLU(input) into `out`; each backend's kernel source defines it.
+void relu(torch::Tensor &out, torch::Tensor const &input);
\ No newline at end of file