Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/build_kernel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ jobs:
- name: Build relu kernel (specific Torch version)
run: ( cd examples/relu-specific-torch && nix build . )

- name: Build relu kernel (compiler flags)
run: ( cd examples/relu-compiler-flags && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )

- name: Test that we can build a test shell (e.g. that gcc corresponds to CUDA-required)
run: ( cd examples/relu && nix build .#devShells.x86_64-linux.test )

Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/build_kernel_xpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ jobs:
# kernels. Also run tests once we have a XPU runner.
- name: Build relu kernel
run: ( cd examples/relu && nix build .\#redistributable.torch29-cxx11-xpu20252-x86_64-linux -L )

- name: Build relu kernel (compiler flags)
run: ( cd examples/relu-compiler-flags && nix build .\#redistributable.torch29-cxx11-xpu20252-x86_64-linux )
37 changes: 37 additions & 0 deletions examples/relu-compiler-flags/build.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Build configuration for the compiler-flags variant of the relu example.
# The kernel sources refuse to compile (via `#error`) unless the macro
# WHO_AM_I_IF_NOT_THE_CANARY is defined, so a successful build shows that the
# per-kernel flag options below actually reached the compiler.
[general]
name = "relu"
universal = false

[torch]
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h",
]

[kernel.activation]
backend = "cuda"
depends = ["torch"]
src = ["relu_cuda/relu.cu"]
# Canary define checked at the top of relu_cuda/relu.cu.
cuda-flags = [ "-DWHO_AM_I_IF_NOT_THE_CANARY" ]

# NOTE(review): this section compiles the same relu_cuda/relu.cu, which hits
# `#error` when WHO_AM_I_IF_NOT_THE_CANARY is undefined, but no flags are set
# here. Presumably a ROCm build of this example fails; confirm, and pass the
# define through the ROCm flag option if the builder provides one.
[kernel.activation_rocm]
backend = "rocm"
rocm-archs = [
"gfx906",
"gfx908",
"gfx90a",
"gfx940",
"gfx941",
"gfx942",
"gfx1030",
"gfx1100",
"gfx1101",
]
depends = ["torch"]
src = ["relu_cuda/relu.cu"]

[kernel.activation_xpu]
backend = "xpu"
depends = ["torch"]
src = ["relu_xpu/relu.cpp"]
# Canary define checked at the top of relu_xpu/relu.cpp.
sycl-flags = [ "-DWHO_AM_I_IF_NOT_THE_CANARY" ]
17 changes: 17 additions & 0 deletions examples/relu-compiler-flags/flake.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Flake for the compiler-flags relu example. All outputs are delegated to
# kernel-builder's `genFlakeOutputs`, called with this flake and its directory.
{
description = "Flake for ReLU kernel";

inputs = {
# kernel-builder is taken from the path two directories above this example.
kernel-builder.url = "path:../..";
};

outputs =
{
self,
kernel-builder,
}:
# `path = ./.` points the builder at this directory (presumably where it
# looks for build.toml and the kernel sources; confirm against kernel-builder).
kernel-builder.lib.genFlakeOutputs {
inherit self;
path = ./.;
};
}
47 changes: 47 additions & 0 deletions examples/relu-compiler-flags/relu_cuda/relu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <cmath>

// Build canary: build.toml passes -DWHO_AM_I_IF_NOT_THE_CANARY through
// `cuda-flags`. If the define did not reach this translation unit, compilation
// stops here instead of silently ignoring the configured flags.
#ifndef WHO_AM_I_IF_NOT_THE_CANARY
#error "Kernel flags are not correctly handled."
#endif

// ReLU over a buffer viewed as rows of `d` floats.
//
// The block index selects the row; the threads of that block walk the row's
// columns starting at their thread index and advancing by the block size, so
// any block size covers the whole row.
//
// out   - destination buffer, same layout as `input`
// input - source buffer
// d     - number of elements per row
__global__ void relu_kernel(float *__restrict__ out,
                            float const *__restrict__ input, const int d) {
  // 64-bit offset of the first element of this block's row.
  const int64_t row = blockIdx.x;
  const int64_t row_start = row * d;

  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const int64_t pos = row_start + col;
    const auto value = input[pos];
    if (value > 0.0f) {
      out[pos] = value;
    } else {
      out[pos] = 0.0f;
    }
  }
}

// Host-side launcher for `relu_kernel`: writes max(input, 0) element-wise
// into `out` on the current CUDA stream of `input`'s device.
//
// out   - destination tensor; must be a contiguous float32 tensor with the
//         same shape and device as `input`
// input - source tensor; must be a contiguous float32 CUDA tensor
//
// Raises (via TORCH_CHECK) when any of the requirements above is violated.
// Empty tensors are accepted and result in no kernel launch.
void relu(torch::Tensor &out, torch::Tensor const &input) {
  TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  // Both tensors are handed to the kernel as float*, so both must be float32
  // (the original checked `input` twice and never looked at `out`).
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
                  out.scalar_type() == at::ScalarType::Float,
              "relu_kernel only supports float32");

  TORCH_CHECK(input.sizes() == out.sizes(),
              "Tensors must have the same shape. Got input shape: ",
              input.sizes(), " and output shape: ", out.sizes());

  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
              "Tensors must have the same data type. Got input dtype: ",
              input.scalar_type(), " and output dtype: ", out.scalar_type());

  TORCH_CHECK(input.device() == out.device(),
              "Tensors must be on the same device. Got input device: ",
              input.device(), " and output device: ", out.device());

  // The kernel addresses `out` with the same flat index as `input`, which is
  // only correct when `out` is contiguous as well.
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");

  // Nothing to do for empty tensors. Returning here also avoids a zero-sized
  // grid launch and the division by zero below when the last dimension is 0.
  if (input.numel() == 0) {
    return;
  }

  // A 0-dim tensor has no last dimension; treat it as a single row of one.
  const int d = input.dim() == 0 ? 1 : input.size(-1);
  const int64_t num_tokens = input.numel() / d;
  dim3 grid(num_tokens);
  // One block per row, capped at 1024 threads; the kernel loops over the
  // remaining columns when the row is longer than the block.
  dim3 block(std::min(d, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
                                          input.data_ptr<float>(), d);
}
44 changes: 44 additions & 0 deletions examples/relu-compiler-flags/relu_xpu/relu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include <sycl/sycl.hpp>
#include <torch/torch.h>

using namespace sycl;

// Build canary: build.toml passes -DWHO_AM_I_IF_NOT_THE_CANARY through
// `sycl-flags`. If the define did not reach this translation unit, compilation
// stops here instead of silently ignoring the configured flags.
#ifndef WHO_AM_I_IF_NOT_THE_CANARY
#error "Kernel flags are not correctly handled."
#endif

// Runs ReLU over `input.numel()` floats with a SYCL kernel and blocks until
// the kernel has finished.
//
// output - tensor whose float storage receives the result, written by flat index
// input  - tensor whose float storage is read by flat index
//
// No validation happens here; the caller is expected to have checked dtype,
// shape and contiguity.
void relu_xpu_impl(torch::Tensor& output, const torch::Tensor& input) {
  // NOTE(review): a default-constructed queue is used rather than one
  // obtained from PyTorch; confirm it targets the device and context that
  // own these allocations.
  sycl::queue q;

  auto src = input.data_ptr<float>();
  auto dst = output.data_ptr<float>();
  const auto count = input.numel();

  // One work-item per element.
  sycl::event done =
      q.parallel_for(sycl::range<1>(count), [=](sycl::id<1> item) {
        const auto pos = item[0];
        dst[pos] = src[pos] > 0.0f ? src[pos] : 0.0f;
      });

  // Block until the result is fully written.
  done.wait();
}

// Element-wise ReLU entry point for the XPU backend: validates both tensors
// and then delegates to `relu_xpu_impl`, which writes max(input, 0) into `out`.
//
// out   - destination tensor; must be contiguous and match `input` in shape,
//         dtype and device
// input - source tensor; must be a contiguous float32 XPU tensor
//
// Raises (via TORCH_CHECK) when any of the requirements above is violated.
// Empty tensors are accepted and result in no kernel launch.
void relu(torch::Tensor& out, const torch::Tensor& input) {
  TORCH_CHECK(input.device().is_xpu(), "input must be a XPU tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(input.scalar_type() == torch::kFloat,
              "Unsupported data type: ", input.scalar_type());

  TORCH_CHECK(input.sizes() == out.sizes(),
              "Tensors must have the same shape. Got input shape: ",
              input.sizes(), " and output shape: ", out.sizes());

  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
              "Tensors must have the same data type. Got input dtype: ",
              input.scalar_type(), " and output dtype: ", out.scalar_type());

  TORCH_CHECK(input.device() == out.device(),
              "Tensors must be on the same device. Got input device: ",
              input.device(), " and output device: ", out.device());

  // The implementation writes `out` through a raw pointer using the same flat
  // index as `input`, which is only correct when `out` is contiguous too.
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");

  // Nothing to compute; skip submitting a zero-sized kernel.
  if (input.numel() == 0) {
    return;
  }

  relu_xpu_impl(out, input);
}
12 changes: 12 additions & 0 deletions examples/relu-compiler-flags/torch-ext/relu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Optional

import torch

from ._ops import ops


def relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply the extension's ``relu`` op to ``x``.

    Args:
        x: Input tensor.
        out: Optional tensor that receives the result. When omitted, a new
            tensor is allocated with ``torch.empty_like(x)``.

    Returns:
        The tensor the result was written to (``out`` if it was given).
    """
    result = torch.empty_like(x) if out is None else out
    ops.relu(result, x)
    return result
17 changes: 17 additions & 0 deletions examples/relu-compiler-flags/torch-ext/torch_binding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

// Defines the `relu` operator for this extension and attaches the
// implementation for whichever backend macro the build defined.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Schema: `out` is marked mutable (`Tensor!`) and the op returns nothing.
  ops.def("relu(Tensor! out, Tensor input) -> ()");

  // The first matching branch wins. If none of the backend macros is defined,
  // only the schema is registered and no implementation is attached.
#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
  // ROCm builds register under the CUDA dispatch key as well.
  ops.impl("relu", torch::kCUDA, &relu);
#elif defined(METAL_KERNEL)
  ops.impl("relu", torch::kMPS, &relu);
#elif defined(XPU_KERNEL)
  ops.impl("relu", torch::kXPU, &relu);
#endif
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
5 changes: 5 additions & 0 deletions examples/relu-compiler-flags/torch-ext/torch_binding.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once

#include <torch/torch.h>

/// Element-wise ReLU: writes `input > 0 ? input : 0` into `out`.
///
/// The definition is supplied by the backend-specific kernel source
/// (relu_cuda/relu.cu or relu_xpu/relu.cpp). Those implementations require
/// `input` to be a contiguous float32 tensor on the backend's device and
/// `out` to match it in shape, dtype and device; violations are reported
/// through TORCH_CHECK.
///
/// @param out   Tensor that receives the result.
/// @param input Tensor to read from.
void relu(torch::Tensor &out, torch::Tensor const &input);
Loading