From 36a884de7339418c93d1c4e488a261b8093bf89d Mon Sep 17 00:00:00 2001 From: shadeMe Date: Fri, 17 Oct 2025 11:30:23 +0200 Subject: [PATCH 1/2] feat: Add support for ROCm/HIP flags --- build2cmake/src/config/v2.rs | 4 +++- build2cmake/src/templates/cuda/kernel.cmake | 13 +++++++++++++ build2cmake/src/torch/cuda.rs | 16 ++++++++++++++-- examples/relu-compiler-flags/build.toml | 10 ++++------ 4 files changed, 34 insertions(+), 9 deletions(-) diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index 16385818..606f3177 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -109,6 +109,7 @@ pub enum Kernel { cxx_flags: Option>, depends: Vec, rocm_archs: Option>, + hip_flags: Option>, include: Option>, src: Vec, }, @@ -257,7 +258,7 @@ fn convert_kernels(v1_kernels: HashMap) -> Result) -> Result:{{ hip_flags }}>" + ) + endif() + endforeach() + {% endif %} + hip_archs_loose_intersection({{kernel_name}}_ARCHS "{{ rocm_archs|join(";") }}" "${ROCM_ARCHS}") message(STATUS "Archs for kernel {{kernel_name}}: {{ '${' + kernel_name + '_ARCHS}'}}") diff --git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs index 08d4fb51..65fdb082 100644 --- a/build2cmake/src/torch/cuda.rs +++ b/build2cmake/src/torch/cuda.rs @@ -321,7 +321,7 @@ pub fn render_kernel( .collect_vec() .join("\n"); - let (cuda_capabilities, rocm_archs, cuda_flags, cuda_minver) = match kernel { + let (cuda_capabilities, rocm_archs, cuda_flags, hip_flags, cuda_minver) = match kernel { Kernel::Cuda { cuda_capabilities, cuda_flags, @@ -331,9 +331,20 @@ pub fn render_kernel( cuda_capabilities.as_deref(), None, cuda_flags.as_deref(), + None, cuda_minver.as_ref(), ), - Kernel::Rocm { rocm_archs, .. } => (None, rocm_archs.as_deref(), None, None), + Kernel::Rocm { + rocm_archs, + hip_flags, + .. + } => ( + None, + rocm_archs.as_deref(), + None, + hip_flags.as_deref(), + None, + ), _ => unreachable!("Unsupported kernel type for CUDA rendering"), }; @@ -346,6 +357,7 @@ pub fn render_kernel( cuda_minver => cuda_minver.map(ToString::to_string), cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")), rocm_archs => rocm_archs, + hip_flags => hip_flags.map(|flags| flags.join(";")), includes => kernel.include().map(prefix_and_join_includes), kernel_name => kernel_name, supports_hipify => matches!(kernel, Kernel::Rocm{ .. }), diff --git a/examples/relu-compiler-flags/build.toml b/examples/relu-compiler-flags/build.toml index e493b3e8..e99d7ae3 100644 --- a/examples/relu-compiler-flags/build.toml +++ b/examples/relu-compiler-flags/build.toml @@ -3,16 +3,13 @@ name = "relu" universal = false [torch] -src = [ - "torch-ext/torch_binding.cpp", - "torch-ext/torch_binding.h", -] +src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"] [kernel.activation] backend = "cuda" depends = ["torch"] src = ["relu_cuda/relu.cu"] -cuda-flags = [ "-DWHO_AM_I_IF_NOT_THE_CANARY" ] +cuda-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"] [kernel.activation_rocm] backend = "rocm" @@ -29,9 +26,10 @@ rocm-archs = [ ] depends = ["torch"] src = ["relu_cuda/relu.cu"] +hip-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"] [kernel.activation_xpu] backend = "xpu" depends = ["torch"] src = ["relu_xpu/relu.cpp"] -sycl-flags = [ "-DWHO_AM_I_IF_NOT_THE_CANARY" ] +sycl-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"] From 1de6d4f84974d7f9456420fc83641066c3311f48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 17 Oct 2025 09:55:52 +0000 Subject: [PATCH 2/2] CI: add relu-compiler-flags to ROCm build --- .github/workflows/build_kernel_rocm.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/build_kernel_rocm.yaml b/.github/workflows/build_kernel_rocm.yaml index 9e46b7ec..f03774e3 100644 --- a/.github/workflows/build_kernel_rocm.yaml +++ b/.github/workflows/build_kernel_rocm.yaml @@ -27,3 +27,6 @@ jobs: # kernels. Also run tests once we have a ROCm runner. - name: Build relu kernel run: ( cd examples/relu && nix build .\#redistributable.torch29-cxx11-rocm63-x86_64-linux -L ) + + - name: Build relu kernel (compiler flags) + run: ( cd examples/relu-compiler-flags && nix build .\#redistributable.torch29-cxx11-rocm63-x86_64-linux )