Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/build_kernel_rocm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ jobs:
# kernels. Also run tests once we have a ROCm runner.
- name: Build relu kernel
run: ( cd examples/relu && nix build .\#redistributable.torch29-cxx11-rocm63-x86_64-linux -L )

- name: Build relu kernel (compiler flags)
run: ( cd examples/relu-compiler-flags && nix build .\#redistributable.torch29-cxx11-rocm63-x86_64-linux )
4 changes: 3 additions & 1 deletion build2cmake/src/config/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ pub enum Kernel {
cxx_flags: Option<Vec<String>>,
depends: Vec<Dependencies>,
rocm_archs: Option<Vec<String>>,
hip_flags: Option<Vec<String>>,
include: Option<Vec<String>>,
src: Vec<String>,
},
Expand Down Expand Up @@ -257,7 +258,7 @@ fn convert_kernels(v1_kernels: HashMap<String, v1::Kernel>) -> Result<HashMap<St

for (name, kernel) in v1_kernels {
if kernel.language == Language::CudaHipify {
// We need to add an affix to avoid confflict with the CUDA kernel.
// We need to add an affix to avoid conflict with the CUDA kernel.
let rocm_name = format!("{name}_rocm");
if kernels.contains_key(&rocm_name) {
bail!("Found an existing kernel with name `{rocm_name}` while expanding `{name}`")
Expand All @@ -268,6 +269,7 @@ fn convert_kernels(v1_kernels: HashMap<String, v1::Kernel>) -> Result<HashMap<St
Kernel::Rocm {
cxx_flags: None,
rocm_archs: kernel.rocm_archs,
hip_flags: None,
depends: kernel.depends.clone(),
include: kernel.include.clone(),
src: kernel.src.clone(),
Expand Down
13 changes: 13 additions & 0 deletions build2cmake/src/templates/cuda/kernel.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ if(GPU_LANG STREQUAL "CUDA")
list(APPEND SRC {{'"${' + kernel_name + '_SRC}"'}})
{% if supports_hipify %}
elseif(GPU_LANG STREQUAL "HIP")
{% if hip_flags %}

foreach(_KERNEL_SRC {{'${' + kernel_name + '_SRC}'}})
if(_KERNEL_SRC MATCHES ".*\\.(cu|hip)$")
set_property(
SOURCE ${_KERNEL_SRC}
APPEND PROPERTY
COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:HIP>:{{ hip_flags }}>"
)
endif()
endforeach()
{% endif %}

hip_archs_loose_intersection({{kernel_name}}_ARCHS "{{ rocm_archs|join(";") }}" "${ROCM_ARCHS}")
message(STATUS "Archs for kernel {{kernel_name}}: {{ '${' + kernel_name + '_ARCHS}'}}")

Expand Down
16 changes: 14 additions & 2 deletions build2cmake/src/torch/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ pub fn render_kernel(
.collect_vec()
.join("\n");

let (cuda_capabilities, rocm_archs, cuda_flags, cuda_minver) = match kernel {
let (cuda_capabilities, rocm_archs, cuda_flags, hip_flags, cuda_minver) = match kernel {
Kernel::Cuda {
cuda_capabilities,
cuda_flags,
Expand All @@ -331,9 +331,20 @@ pub fn render_kernel(
cuda_capabilities.as_deref(),
None,
cuda_flags.as_deref(),
None,
cuda_minver.as_ref(),
),
Kernel::Rocm { rocm_archs, .. } => (None, rocm_archs.as_deref(), None, None),
Kernel::Rocm {
rocm_archs,
hip_flags,
..
} => (
None,
rocm_archs.as_deref(),
None,
hip_flags.as_deref(),
None,
),
_ => unreachable!("Unsupported kernel type for CUDA rendering"),
};

Expand All @@ -346,6 +357,7 @@ pub fn render_kernel(
cuda_minver => cuda_minver.map(ToString::to_string),
cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")),
rocm_archs => rocm_archs,
hip_flags => hip_flags.map(|flags| flags.join(";")),
includes => kernel.include().map(prefix_and_join_includes),
kernel_name => kernel_name,
supports_hipify => matches!(kernel, Kernel::Rocm{ .. }),
Expand Down
10 changes: 4 additions & 6 deletions examples/relu-compiler-flags/build.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@ name = "relu"
universal = false

[torch]
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h",
]
src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]

[kernel.activation]
backend = "cuda"
depends = ["torch"]
src = ["relu_cuda/relu.cu"]
cuda-flags = [ "-DWHO_AM_I_IF_NOT_THE_CANARY" ]
cuda-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"]

[kernel.activation_rocm]
backend = "rocm"
Expand All @@ -29,9 +26,10 @@ rocm-archs = [
]
depends = ["torch"]
src = ["relu_cuda/relu.cu"]
hip-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"]

[kernel.activation_xpu]
backend = "xpu"
depends = ["torch"]
src = ["relu_xpu/relu.cpp"]
sycl-flags = [ "-DWHO_AM_I_IF_NOT_THE_CANARY" ]
sycl-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"]