diff --git a/.github/workflows/build_kernel.yaml b/.github/workflows/build_kernel.yaml index 5a7f9462..922bd5f5 100644 --- a/.github/workflows/build_kernel.yaml +++ b/.github/workflows/build_kernel.yaml @@ -33,6 +33,11 @@ jobs: - name: Copy relu kernel run: cp -rL examples/relu/result relu-kernel + - name: Build relu kernel (CPU) + run: ( cd examples/relu && nix build .\#redistributable.torch29-cxx11-cpu-x86_64-linux ) + - name: Copy relu kernel (CPU) + run: cp -rL examples/relu/result relu-kernel-cpu + - name: Build cutlass GEMM kernel run: ( cd examples/cutlass-gemm && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux ) - name: Copy cutlass GEMM kernel @@ -66,6 +71,7 @@ jobs: activation-kernel cutlass-gemm-kernel relu-kernel + relu-kernel-cpu relu-backprop-compile-kernel silu-and-mul-universal-kernel diff --git a/build-variants.json b/build-variants.json index 86e49cdc..4409d1ed 100644 --- a/build-variants.json +++ b/build-variants.json @@ -1,11 +1,19 @@ { "aarch64-darwin": { + "cpu": [ + "torch28-cpu-aarch64-darwin", + "torch29-cpu-aarch64-darwin" + ], "metal": [ "torch28-metal-aarch64-darwin", "torch29-metal-aarch64-darwin" ] }, "aarch64-linux": { + "cpu": [ + "torch28-cxx11-cpu-aarch64-linux", + "torch29-cxx11-cpu-aarch64-linux" + ], "cuda": [ "torch28-cxx11-cu129-aarch64-linux", "torch29-cxx11-cu126-aarch64-linux", @@ -14,6 +22,10 @@ ] }, "x86_64-linux": { + "cpu": [ + "torch28-cxx11-cpu-x86_64-linux", + "torch29-cxx11-cpu-x86_64-linux" + ], "cuda": [ "torch28-cxx11-cu126-x86_64-linux", "torch28-cxx11-cu128-x86_64-linux", diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index b4dfa075..ecbdd9ec 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -32,6 +32,7 @@ impl Build { self.kernels .values() .map(|kernel| match kernel { + Kernel::Cpu { .. } => Backend::Cpu, Kernel::Cuda { .. } => Backend::Cuda, Kernel::Metal { .. } => Backend::Metal, Kernel::Rocm { .. 
} => Backend::Rocm, @@ -96,6 +97,13 @@ impl Torch { #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case", tag = "backend")] pub enum Kernel { + #[serde(rename_all = "kebab-case")] + Cpu { + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, #[serde(rename_all = "kebab-case")] Cuda { cuda_capabilities: Option>, @@ -135,7 +143,8 @@ pub enum Kernel { impl Kernel { pub fn cxx_flags(&self) -> Option<&[String]> { match self { - Kernel::Cuda { cxx_flags, .. } + Kernel::Cpu { cxx_flags, .. } + | Kernel::Cuda { cxx_flags, .. } | Kernel::Metal { cxx_flags, .. } | Kernel::Rocm { cxx_flags, .. } | Kernel::Xpu { cxx_flags, .. } => cxx_flags.as_deref(), @@ -144,7 +153,8 @@ impl Kernel { pub fn include(&self) -> Option<&[String]> { match self { - Kernel::Cuda { include, .. } + Kernel::Cpu { include, .. } + | Kernel::Cuda { include, .. } | Kernel::Metal { include, .. } | Kernel::Rocm { include, .. } | Kernel::Xpu { include, .. } => include.as_deref(), @@ -153,6 +163,7 @@ impl Kernel { pub fn backend(&self) -> Backend { match self { + Kernel::Cpu { .. } => Backend::Cpu, Kernel::Cuda { .. } => Backend::Cuda, Kernel::Metal { .. } => Backend::Metal, Kernel::Rocm { .. } => Backend::Rocm, @@ -162,7 +173,8 @@ impl Kernel { pub fn depends(&self) -> &[Dependencies] { match self { - Kernel::Cuda { depends, .. } + Kernel::Cpu { depends, .. } + | Kernel::Cuda { depends, .. } | Kernel::Metal { depends, .. } | Kernel::Rocm { depends, .. } | Kernel::Xpu { depends, .. } => depends, @@ -171,7 +183,8 @@ impl Kernel { pub fn src(&self) -> &[String] { match self { - Kernel::Cuda { src, .. } + Kernel::Cpu { src, .. } + | Kernel::Cuda { src, .. } | Kernel::Metal { src, .. } | Kernel::Rocm { src, .. } | Kernel::Xpu { src, .. 
} => src, @@ -182,6 +195,7 @@ impl Kernel { #[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case")] pub enum Backend { + Cpu, Cuda, Metal, Rocm, @@ -191,6 +205,7 @@ pub enum Backend { impl Display for Backend { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Backend::Cpu => write!(f, "cpu"), Backend::Cuda => write!(f, "cuda"), Backend::Metal => write!(f, "metal"), Backend::Rocm => write!(f, "rocm"), @@ -204,6 +219,7 @@ impl FromStr for Backend { fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { + "cpu" => Ok(Backend::Cpu), "cuda" => Ok(Backend::Cuda), "metal" => Ok(Backend::Metal), "rocm" => Ok(Backend::Rocm), diff --git a/build2cmake/src/main.rs b/build2cmake/src/main.rs index 3d9be9e2..528b37ec 100644 --- a/build2cmake/src/main.rs +++ b/build2cmake/src/main.rs @@ -10,7 +10,8 @@ use minijinja::Environment; mod torch; use torch::{ - write_torch_ext_cuda, write_torch_ext_metal, write_torch_ext_universal, write_torch_ext_xpu, + write_torch_ext_cpu, write_torch_ext_cuda, write_torch_ext_metal, write_torch_ext_universal, + write_torch_ext_xpu, }; mod config; @@ -178,6 +179,7 @@ fn generate_torch( }; let file_set = match backend { + Backend::Cpu => write_torch_ext_cpu(&env, &build, target_dir.clone(), ops_id)?, Backend::Cuda | Backend::Rocm => { write_torch_ext_cuda(&env, backend, &build, target_dir.clone(), ops_id)? } @@ -376,6 +378,7 @@ fn get_generated_files( for backend in build.backends() { let set = match backend { + Backend::Cpu => write_torch_ext_cpu(env, build, target_dir.clone(), ops_id.clone())?, Backend::Cuda | Backend::Rocm => { write_torch_ext_cuda(env, backend, build, target_dir.clone(), ops_id.clone())? 
} diff --git a/build2cmake/src/templates/cpu/kernel.cmake b/build2cmake/src/templates/cpu/kernel.cmake new file mode 100644 index 00000000..2e9d2442 --- /dev/null +++ b/build2cmake/src/templates/cpu/kernel.cmake @@ -0,0 +1,24 @@ +set({{kernel_name}}_SRC + {{ sources }} +) + +{% if includes %} +# TODO: check if CLion support this: +# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories +set_source_files_properties( + {{'${' + kernel_name + '_SRC}'}} + PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}") +{% endif %} + +{% if cxx_flags %} +foreach(_KERNEL_SRC {{'${' + kernel_name + '_SRC}'}}) + set_property( + SOURCE ${_KERNEL_SRC} + APPEND PROPERTY + COMPILE_OPTIONS "$<$:{{ cxx_flags }}>" + ) +endforeach() +{% endif %} + +# Add C++ sources to main source list +list(APPEND SRC {{'"${' + kernel_name + '_SRC}"'}}) diff --git a/build2cmake/src/templates/cpu/preamble.cmake b/build2cmake/src/templates/cpu/preamble.cmake new file mode 100644 index 00000000..bbd064f2 --- /dev/null +++ b/build2cmake/src/templates/cpu/preamble.cmake @@ -0,0 +1,28 @@ +cmake_minimum_required(VERSION 3.26) +project({{name}} LANGUAGES CXX) + +set(CMAKE_OSX_DEPLOYMENT_TARGET "15.0" CACHE STRING "Minimum macOS deployment version") + +install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) + +include(FetchContent) +file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists +message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") + +include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) + +if(DEFINED Python3_EXECUTABLE) + # Allow passing through the interpreter (e.g. from setup.py). 
+ find_package(Python3 COMPONENTS Development Development.SABIModule Interpreter) + if (NOT Python3_FOUND) + message(FATAL_ERROR "Unable to find python matching: ${Python3_EXECUTABLE}.") + endif() +else() + find_package(Python3 REQUIRED COMPONENTS Development Development.SABIModule Interpreter) +endif() + +append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") + +find_package(Torch REQUIRED) + +add_compile_definitions(CPU_KERNEL) diff --git a/build2cmake/src/templates/cpu/setup.py b/build2cmake/src/templates/cpu/setup.py new file mode 100644 index 00000000..08164f1a --- /dev/null +++ b/build2cmake/src/templates/cpu/setup.py @@ -0,0 +1,121 @@ +import logging +import os +from shutil import which, move +import subprocess +import sys +from pathlib import Path + +from setuptools import Extension, find_packages, setup +from setuptools.command.build_ext import build_ext + +logger = logging.getLogger(__name__) + + +def is_sccache_available() -> bool: + return which("sccache") is not None + + +def is_ccache_available() -> bool: + return which("ccache") is not None + + +def is_ninja_available() -> bool: + return which("ninja") is not None + + +class CMakeExtension(Extension): + def __init__(self, name: str, sourcedir: str = "") -> None: + super().__init__(name, sources=[], py_limited_api=True) + self.sourcedir = os.fspath(Path(sourcedir).resolve()) + + +class CMakeBuild(build_ext): + def build_extension(self, ext: CMakeExtension) -> None: + ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) + extdir = ext_fullpath.parent.resolve() + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + cmake_generator = os.environ.get("CMAKE_GENERATOR", "") + + # Set Python3_EXECUTABLE instead if you use PYBIND11_FINDPYTHON + # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code + # from Python.
+ cmake_args = [ + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", + f"-DPython3_EXECUTABLE={sys.executable}", + f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm + ] + build_args = [] + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + if not cmake_generator or cmake_generator == "Ninja": + try: + import ninja + + ninja_executable_path = Path(ninja.BIN_DIR) / "ninja" + cmake_args += [ + "-GNinja", + f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", + ] + except ImportError: + pass + + if is_sccache_available(): + cmake_args += [ + "-DCMAKE_C_COMPILER_LAUNCHER=sccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache", + "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache", + "-DCMAKE_OBJC_COMPILER_LAUNCHER=sccache", + "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=sccache", + ] + elif is_ccache_available(): + cmake_args += [ + "-DCMAKE_C_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", + "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache", + "-DCMAKE_OBJC_COMPILER_LAUNCHER=ccache", + "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=ccache", + ] + + num_jobs = os.getenv("MAX_JOBS", None) + if num_jobs is not None: + num_jobs = int(num_jobs) + logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs) + else: + try: + # os.sched_getaffinity() isn't universally available, so fall + # back to os.cpu_count() if we get an error here. + num_jobs = len(os.sched_getaffinity(0)) + except AttributeError: + num_jobs = os.cpu_count() + + build_temp = Path(self.build_temp) / ext.name + if not build_temp.exists(): + build_temp.mkdir(parents=True) + + subprocess.run( + ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True + ) + subprocess.run( + ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True + ) + + +setup( + name="{{ name }}", + # The version is just a stub, it's not used by the final build artefact. 
+ version="0.1.0", + ext_modules=[CMakeExtension("{{ name }}.{{ ops_name }}")], + cmdclass={"build_ext": CMakeBuild}, + packages=find_packages(where="torch-ext", include=["{{ name }}*"]), + package_dir={"": "torch-ext"}, +{% if data_globs %} + package_data={"{{ name }}": [ {{ data_globs }} ]}, +{% endif %} + zip_safe=False, + install_requires=["torch"], + python_requires=">=3.9", +) diff --git a/build2cmake/src/templates/cpu/torch-binding.cmake b/build2cmake/src/templates/cpu/torch-binding.cmake new file mode 100644 index 00000000..799fa0d3 --- /dev/null +++ b/build2cmake/src/templates/cpu/torch-binding.cmake @@ -0,0 +1,13 @@ +set(TORCH_{{name}}_SRC + {{ src|join(' ') }} +) + +{% if includes %} +# TODO: check if CLion support this: +# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories +set_source_files_properties( + {{'${TORCH_' + name + '_SRC}'}} + PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}") +{% endif %} + +list(APPEND SRC {{'"${TORCH_' + name + '_SRC}"'}}) diff --git a/build2cmake/src/templates/cpu/torch-extension.cmake b/build2cmake/src/templates/cpu/torch-extension.cmake new file mode 100644 index 00000000..9d667e4a --- /dev/null +++ b/build2cmake/src/templates/cpu/torch-extension.cmake @@ -0,0 +1,9 @@ +define_gpu_extension_target( + {{ ops_name }} + DESTINATION {{ ops_name }} + LANGUAGE ${GPU_LANG} + SOURCES ${SRC} + COMPILE_FLAGS ${GPU_FLAGS} + ARCHITECTURES ${GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) diff --git a/build2cmake/src/torch/cpu.rs b/build2cmake/src/torch/cpu.rs new file mode 100644 index 00000000..bad4ea38 --- /dev/null +++ b/build2cmake/src/torch/cpu.rs @@ -0,0 +1,271 @@ +use std::{io::Write, path::PathBuf}; + +use eyre::{bail, Context, Result}; +use itertools::Itertools; +use minijinja::{context, Environment}; + +use super::kernel_ops_identifier; +use crate::{ + config::{Build, Kernel, Torch}, + fileset::FileSet, +}; + +static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); +static 
REGISTRATION_H: &str = include_str!("../templates/registration.h"); + +pub fn write_torch_ext_cpu( + env: &Environment, + build: &Build, + target_dir: PathBuf, + ops_id: Option, +) -> Result { + let torch_ext = match build.torch.as_ref() { + Some(torch_ext) => torch_ext, + None => bail!("Build configuration does not have `torch` section"), + }; + + let mut file_set = FileSet::default(); + + let ops_name = kernel_ops_identifier(&target_dir, &build.general.name, ops_id); + + write_cmake( + env, + build, + torch_ext, + &build.general.name, + &ops_name, + &mut file_set, + )?; + + write_setup_py( + env, + torch_ext, + &build.general.name, + &ops_name, + &mut file_set, + )?; + + write_ops_py(env, &build.general.name, &ops_name, &mut file_set)?; + + write_pyproject_toml(env, &mut file_set)?; + + write_torch_registration_macros(&mut file_set)?; + + Ok(file_set) +} + +fn write_cmake( + env: &Environment, + build: &Build, + torch: &Torch, + name: &str, + ops_name: &str, + file_set: &mut FileSet, +) -> Result<()> { + let mut utils_path = PathBuf::new(); + utils_path.push("cmake"); + utils_path.push("utils.cmake"); + file_set + .entry(utils_path.clone()) + .extend_from_slice(CMAKE_UTILS.as_bytes()); + + let cmake_writer = file_set.entry("CMakeLists.txt"); + + render_preamble(env, name, cmake_writer)?; + + // Add deps once we have any non-CUDA deps. + // render_deps(env, build, cmake_writer)?; + + render_binding(env, torch, name, cmake_writer)?; + + for (kernel_name, kernel) in build + .kernels + .iter() + .filter(|(_, kernel)| matches!(kernel, Kernel::Cpu { .. })) + { + render_kernel(env, kernel_name, kernel, cmake_writer)?; + } + + render_extension(env, name, ops_name, cmake_writer)?; + + Ok(()) +} + +fn render_binding( + env: &Environment, + torch: &Torch, + name: &str, + write: &mut impl Write, +) -> Result<()> { + env.get_template("cpu/torch-binding.cmake") + .wrap_err("Cannot get Torch binding template")? + .render_to_write( + context! 
{ + includes => torch.include.as_ref().map(prefix_and_join_includes), + name => name, + src => torch.src + }, + &mut *write, + ) + .wrap_err("Cannot render Torch binding template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +pub fn render_extension( + env: &Environment, + name: &str, + ops_name: &str, + write: &mut impl Write, +) -> Result<()> { + env.get_template("cpu/torch-extension.cmake") + .wrap_err("Cannot get Torch extension template")? + .render_to_write( + context! { + name => name, + ops_name => ops_name, + }, + &mut *write, + ) + .wrap_err("Cannot render Torch extension template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +pub fn render_kernel( + env: &Environment, + kernel_name: &str, + kernel: &Kernel, + write: &mut impl Write, +) -> Result<()> { + // Easier to do in Rust than Jinja. + let sources = kernel + .src() + .iter() + .map(|src| format!("\"{src}\"")) + .collect_vec() + .join("\n"); + + env.get_template("cpu/kernel.cmake") + .wrap_err("Cannot get kernel template")? + .render_to_write( + context! { + cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")), + includes => kernel.include().map(prefix_and_join_includes), + kernel_name => kernel_name, + sources => sources, + }, + &mut *write, + ) + .wrap_err("Cannot render kernel template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> { + env.get_template("cpu/preamble.cmake") + .wrap_err("Cannot get CMake prelude template")? + .render_to_write( + context! 
{ + name => name, + }, + &mut *write, + ) + .wrap_err("Cannot render CMake prelude template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +fn write_ops_py( + env: &Environment, + name: &str, + ops_name: &str, + file_set: &mut FileSet, +) -> Result<()> { + let mut path = PathBuf::new(); + path.push("torch-ext"); + path.push(name); + path.push("_ops.py"); + let writer = file_set.entry(path); + + env.get_template("_ops.py") + .wrap_err("Cannot get _ops.py template")? + .render_to_write( + context! { + ops_name => ops_name, + }, + writer, + ) + .wrap_err("Cannot render _ops.py template")?; + + Ok(()) +} + +fn write_pyproject_toml(env: &Environment, file_set: &mut FileSet) -> Result<()> { + let writer = file_set.entry("pyproject.toml"); + + env.get_template("pyproject.toml") + .wrap_err("Cannot get pyproject.toml template")? + .render_to_write(context! {}, writer) + .wrap_err("Cannot render pyproject.toml template")?; + + Ok(()) +} + +fn write_setup_py( + env: &Environment, + torch: &Torch, + name: &str, + ops_name: &str, + file_set: &mut FileSet, +) -> Result<()> { + let writer = file_set.entry("setup.py"); + + let data_globs = torch.data_globs().map(|globs| globs.join(", ")); + + env.get_template("cpu/setup.py") + .wrap_err("Cannot get setup.py template")? + .render_to_write( + context! 
{ + data_globs => data_globs, + ops_name => ops_name, + name => name, + version => "0.1.0", + }, + writer, + ) + .wrap_err("Cannot render setup.py template")?; + + Ok(()) +} + +fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> { + let mut path = PathBuf::new(); + path.push("torch-ext"); + path.push("registration.h"); + file_set + .entry(path) + .extend_from_slice(REGISTRATION_H.as_bytes()); + + Ok(()) +} + +fn prefix_and_join_includes(includes: impl AsRef<[S]>) -> String +where + S: AsRef, +{ + includes + .as_ref() + .iter() + .map(|include| format!("${{CMAKE_SOURCE_DIR}}/{}", include.as_ref())) + .collect_vec() + .join(";") +} diff --git a/build2cmake/src/torch/mod.rs b/build2cmake/src/torch/mod.rs index d3896938..3637c53a 100644 --- a/build2cmake/src/torch/mod.rs +++ b/build2cmake/src/torch/mod.rs @@ -1,3 +1,6 @@ +mod cpu; +pub use cpu::write_torch_ext_cpu; + mod cuda; pub use cuda::write_torch_ext_cuda; diff --git a/docs/build-variants.md b/docs/build-variants.md index 39b5c975..c691958b 100644 --- a/docs/build-variants.md +++ b/docs/build-variants.md @@ -5,11 +5,21 @@ architecture (e.g. x86_64). For compliance with a compute framework and architecture combination, all the build variants listed below must be available. This list will be updated as new PyTorch versions are released.
- `torch29-cxx11-cu128-aarch64-linux` - `torch29-cxx11-cu130-aarch64-linux` +## CPU x86_64-linux + +- `torch28-cxx11-cpu-x86_64-linux` +- `torch29-cxx11-cpu-x86_64-linux` + ## CUDA x86_64-linux - `torch28-cxx11-cu126-x86_64-linux` diff --git a/examples/relu/build.toml b/examples/relu/build.toml index cfe490ba..84eb068a 100644 --- a/examples/relu/build.toml +++ b/examples/relu/build.toml @@ -42,3 +42,8 @@ src = ["relu_cuda/relu.cu"] backend = "xpu" depends = ["torch"] src = ["relu_xpu/relu.cpp"] + +[kernel.relu_cpu] +backend = "cpu" +depends = ["torch"] +src = ["relu_cpu/relu_cpu.cpp"] diff --git a/examples/relu/relu_cpu/relu_cpu.cpp b/examples/relu/relu_cpu/relu_cpu.cpp new file mode 100644 index 00000000..6197a9f1 --- /dev/null +++ b/examples/relu/relu_cpu/relu_cpu.cpp @@ -0,0 +1,57 @@ +#include + +#ifdef __SSE__ +#include +#endif + +#ifdef __ARM_NEON +#include +#endif + +#ifdef __SSE__ +void relu_forward_sse(float* out, const float* input, size_t size) { + size_t i = 0; + + for (; i + 4 <= size; i += 4) { + // Use unaligned load/store: input/out may point into a tensor view at an offset that is not 16-byte aligned. + __m128 vec_input = _mm_loadu_ps(input + i); + __m128 vec_zero = _mm_setzero_ps(); + __m128 vec_output = _mm_max_ps(vec_input, vec_zero); + _mm_storeu_ps(out + i, vec_output); + } + + for (; i < size; ++i) { + out[i] = input[i] > 0 ? input[i] : 0; + } +} +#endif + +#ifdef __ARM_NEON +void relu_forward_neon(float* out, const float* input, size_t size) { + size_t i = 0; + + for (; i + 4 <= size; i += 4) { + float32x4_t vec_input = vld1q_f32(input + i); + float32x4_t vec_output = vmaxq_f32(vec_input, vdupq_n_f32(0)); + vst1q_f32(out + i, vec_output); + } + + for (; i < size; ++i) { + out[i] = input[i] > 0 ? 
input[i] : 0; + } +} +#endif + +void relu(torch::Tensor &out, torch::Tensor const &input) { + TORCH_CHECK(out.dtype() == torch::kFloat32, "Output tensor must be of dtype float"); + TORCH_CHECK(input.dtype() == torch::kFloat32, "Input tensor must be of dtype float"); + TORCH_CHECK(out.numel() == input.numel(), "Input and output tensors must have the same number of elements"); + +#if defined(__SSE__) + relu_forward_sse(out.data_ptr(), input.data_ptr(), input.numel()); +#elif defined(__ARM_NEON) + relu_forward_neon(out.data_ptr(), input.data_ptr(), input.numel()); +#else + #error "Unsupported architecture; please use a CPU with SSE or ARM NEON support." +#endif +} diff --git a/examples/relu/tests/test_relu.py b/examples/relu/tests/test_relu.py index d2adc057..65544aa4 100644 --- a/examples/relu/tests/test_relu.py +++ b/examples/relu/tests/test_relu.py @@ -11,7 +11,9 @@ def test_relu(): device = torch.device("mps") elif hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device("xpu") - else: + elif torch.version.cuda is not None and torch.cuda.is_available(): device = torch.device("cuda") + else: + device = torch.device("cpu") x = torch.randn(1024, 1024, dtype=torch.float32, device=device) torch.testing.assert_allclose(F.relu(x), relu.relu(x)) diff --git a/examples/relu/torch-ext/torch_binding.cpp b/examples/relu/torch-ext/torch_binding.cpp index 8b50483a..1765d92d 100644 --- a/examples/relu/torch-ext/torch_binding.cpp +++ b/examples/relu/torch-ext/torch_binding.cpp @@ -5,7 +5,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("relu(Tensor! 
out, Tensor input) -> ()"); -#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) +#if defined(CPU_KERNEL) + ops.impl("relu", torch::kCPU, &relu); +#elif defined(CUDA_KERNEL) || defined(ROCM_KERNEL) ops.impl("relu", torch::kCUDA, &relu); #elif defined(METAL_KERNEL) ops.impl("relu", torch::kMPS, relu); @@ -14,4 +16,4 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #endif } -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/flake.lock b/flake.lock index 3a4a185b..f1af9f5a 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "flake-compat": { "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", "owner": "edolstra", "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", "type": "github" }, "original": { @@ -73,11 +73,11 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1760814603, - "narHash": "sha256-i5uuhnJPxOrd0dC8+btp31WMfzPDL8Uwz0TPG2n6nHE=", + "lastModified": 1761756835, + "narHash": "sha256-Vjrv8ZIhkQRgQ3MHGVFaj/fUcE4yuGr+vnoKYRwWmYw=", "owner": "huggingface", "repo": "hf-nix", - "rev": "c0b62ec3d0abb11dd2d960e3dfee3a46fc46d111", + "rev": "6839b6998be18679992978c2f3abddc902276280", "type": "github" }, "original": { diff --git a/lib/build-sets.nix b/lib/build-sets.nix index 8d407471..c5dcc901 100644 --- a/lib/build-sets.nix +++ b/lib/build-sets.nix @@ -12,6 +12,7 @@ let inherit (import ./torch-version-utils.nix { inherit lib; }) flattenSystems + isCpu isCuda isMetal isRocm @@ -63,6 +64,7 @@ let # Construct the nixpkgs package set for the given versions. pkgsForVersions = buildConfig@{ + cpu ? false, cudaVersion ? null, metal ? false, rocmVersion ? 
null, @@ -75,7 +77,9 @@ let }: let pkgs = - if isCuda buildConfig then + if isCpu buildConfig then + pkgsForCpu + else if isCuda buildConfig then pkgsByCudaVer.${cudaVersion} else if isRocm buildConfig then pkgsByRocmVer.${rocmVersion} @@ -126,8 +130,13 @@ let ); pkgsByXpuVer = pkgsForXpuVersions xpuVersions; - pkgsForMetal = import nixpkgs { + pkgsForMetal = pkgsForCpu; + + pkgsForCpu = import nixpkgs { inherit system; + config = { + allowUnfree = true; + }; overlays = [ hf-nix overlay diff --git a/lib/build-variants.nix b/lib/build-variants.nix index 5f614221..98bb834d 100644 --- a/lib/build-variants.nix +++ b/lib/build-variants.nix @@ -2,6 +2,7 @@ let inherit (import ./torch-version-utils.nix { inherit lib; }) flattenSystems + isCpu isCuda isMetal isRocm @@ -11,23 +12,27 @@ in rec { computeFramework = buildConfig: - if buildConfig ? cudaVersion then + if buildConfig.cpu or false then + "cpu" + else if buildConfig ? cudaVersion then "cuda" - else if buildConfig ? metal then + else if buildConfig.metal or false then "metal" else if buildConfig ? "rocmVersion" then "rocm" else if buildConfig ? 
xpuVersion then "xpu" else - throw "Could not find compute framework: no CUDA, ROCm, XPU version specified and Metal is not enabled"; + throw "Could not find compute framework: no CUDA, ROCm, XPU version specified and CPU and Metal are not enabled"; buildName = let inherit (import ./version-utils.nix { inherit lib; }) abiString flattenVersion; computeString = version: - if isCuda version then + if isCpu version then + "cpu" + else if isCuda version then "cu${flattenVersion (lib.versions.majorMinor version.cudaVersion)}" else if isRocm version then "rocm${flattenVersion (lib.versions.majorMinor version.rocmVersion)}" diff --git a/lib/build.nix b/lib/build.nix index ce8edc73..dfcdb72a 100644 --- a/lib/build.nix +++ b/lib/build.nix @@ -17,6 +17,7 @@ let builtins.readFile ../build2cmake/src/cuda_supported_archs.json ); inherit (import ./torch-version-utils.nix { inherit lib; }) + isCpu isCuda isMetal isRocm @@ -48,6 +49,7 @@ rec { kernels = lib.attrValues (buildToml.kernel or { }); kernelBackend = kernel: kernel.backend; init = { + cpu = false; cuda = false; metal = false; rocm = false; @@ -79,7 +81,8 @@ rec { buildSet: let backendSupported = - (isCuda buildSet.buildConfig && backends'.cuda) + (isCpu buildSet.buildConfig && backends'.cpu) + || (isCuda buildSet.buildConfig && backends'.cuda) || (isRocm buildSet.buildConfig && backends'.rocm) || (isMetal buildSet.buildConfig && backends'.metal) || (isXpu buildSet.buildConfig && backends'.xpu) @@ -146,6 +149,7 @@ rec { else extension.mkExtension { inherit + buildConfig doGetKernelCheck extraDeps nvccThreads diff --git a/lib/torch-extension/arch.nix b/lib/torch-extension/arch.nix index fb06d0d1..20e72be7 100644 --- a/lib/torch-extension/arch.nix +++ b/lib/torch-extension/arch.nix @@ -26,6 +26,8 @@ }: { + buildConfig, + # Whether to do ABI checks. doAbiCheck ? true, @@ -48,6 +50,12 @@ src, }: +# Extra validation - the environment should correspond to the build config. +assert (buildConfig ? 
cudaVersion) -> cudaSupport; +assert (buildConfig ? rocmVersion) -> rocmSupport; +assert (buildConfig ? xpuVersion) -> xpuSupport; +assert (buildConfig.metal or false) -> stdenv.hostPlatform.isDarwin; + let # On Darwin, we need the host's xcrun for `xcrun metal` to compile Metal shaders. # It's not supported by the nixpkgs shim. @@ -57,6 +65,8 @@ let /usr/bin/xcrun $@ ''; + metalSupport = buildConfig.metal or false; + in stdenv.mkDerivation (prevAttrs: { @@ -73,8 +83,10 @@ stdenv.mkDerivation (prevAttrs: { "rocm" else if xpuSupport then "xpu" - else + else if metalSupport then "metal" + else + "cpu" } --ops-id ${rev} build.toml ''; @@ -176,7 +188,7 @@ stdenv.mkDerivation (prevAttrs: { (lib.cmakeFeature "CMAKE_HIP_COMPILER_ROCM_ROOT" "${clr}") (lib.cmakeFeature "HIP_ROOT_DIR" "${clr}") ] - ++ lib.optionals stdenv.hostPlatform.isDarwin [ + ++ lib.optionals metalSupport [ # Use host compiler for Metal. Not included in the redistributable SDK. (lib.cmakeFeature "METAL_COMPILER" "${xcrunHost}/bin/xcrunHost") ]; @@ -207,7 +219,7 @@ stdenv.mkDerivation (prevAttrs: { getKernelCheck = extensionName; # We need access to the host system on Darwin for the Metal compiler. - __noChroot = stdenv.hostPlatform.isDarwin; + __noChroot = metalSupport; passthru = { inherit torch; diff --git a/lib/torch-version-utils.nix b/lib/torch-version-utils.nix index 6624c9da..7bda0313 100644 --- a/lib/torch-version-utils.nix +++ b/lib/torch-version-utils.nix @@ -7,6 +7,7 @@ ++ map (system: (builtins.removeAttrs version [ "systems" ]) // { inherit system; }) version.systems ) [ ]; + isCpu = version: version.cpu or false; isCuda = version: version ? cudaVersion; isMetal = version: version.metal or false; isRocm = version: version ? 
rocmVersion; diff --git a/overlay.nix b/overlay.nix index 427895a2..8907f6e7 100644 --- a/overlay.nix +++ b/overlay.nix @@ -25,8 +25,8 @@ final: prev: { src = final.fetchFromGitHub { owner = "huggingface"; repo = "kernels"; - rev = "5d21b86a5d611100c10c10b79ffa7965edf567fd"; - sha256 = "sha256-lKQUVbjhpeXKj1SeZRxgPSsOtBUZ7zQeO6pRoA1h+W8="; + rev = "0e18dbf076fc44de5dac4027616e9f3d9e2da45a"; + sha256 = "sha256-6N1W3jLQIS1yEAdNR2X9CuFdMw4Ia0yzBBVQ4Kujv8U="; }; }); } diff --git a/scripts/gen_variants_markdown.py b/scripts/gen_variants_markdown.py index e5ad3358..a6f8d8d9 100755 --- a/scripts/gen_variants_markdown.py +++ b/scripts/gen_variants_markdown.py @@ -4,6 +4,7 @@ from pathlib import Path _PLATFORM_NAMES = { + "cpu": "CPU", "cuda": "CUDA", "metal": "Metal", "rocm": "ROCm", diff --git a/tests/Dockerfile.test-kernel b/tests/Dockerfile.test-kernel index 3fafaf62..90a09ef5 100644 --- a/tests/Dockerfile.test-kernel +++ b/tests/Dockerfile.test-kernel @@ -64,14 +64,12 @@ RUN uv add numpy pytest # Copy kernels and tests COPY relu-kernel ./relu-kernel +COPY relu-kernel-cpu ./relu-kernel-cpu COPY cutlass-gemm-kernel ./cutlass-gemm-kernel COPY silu-and-mul-universal-kernel ./silu-and-mul-universal-kernel COPY examples/relu/tests ./relu_tests -COPY examples/cutlass-gemm/tests ./tests/cutlass_gemm_tests +COPY examples/cutlass-gemm/tests ./cutlass_gemm_tests # Run tests -ENV PYTHONPATH="relu-kernel:cutlass-gemm-kernel:silu-and-mul-universal-kernel:$PYTHONPATH" -CMD ["/bin/sh", "-c", ".venv/bin/pytest", "relu_tests", "cutlass_gemm_tests"] - -# We only care about importing, the kernel is trivial. 
-CMD ["/bin/sh", "-c", ".venv/bin/python", "-c", "'import silu_and_mul_universal'"] +ADD tests/run-tests.sh ./run-tests.sh +CMD ["sh", "run-tests.sh"] diff --git a/tests/run-tests.sh b/tests/run-tests.sh new file mode 100644 index 00000000..f449a2f6 --- /dev/null +++ b/tests/run-tests.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Stop at the first failing step so that a failed test run fails the container. +set -e + +PYTHONPATH="relu-kernel:cutlass-gemm-kernel:$PYTHONPATH" \ + .venv/bin/pytest relu_tests cutlass_gemm_tests + +# We only care about importing, the kernel is trivial. +PYTHONPATH="silu-and-mul-universal-kernel:$PYTHONPATH" \ + .venv/bin/python -c "import silu_and_mul_universal" + +PYTHONPATH="relu-kernel-cpu:$PYTHONPATH" \ + CUDA_VISIBLE_DEVICES="" \ + .venv/bin/pytest relu_tests diff --git a/versions.nix b/versions.nix index 59352cca..deb295dd 100644 --- a/versions.nix +++ b/versions.nix @@ -62,6 +62,17 @@ bundleBuild = true; sourceBuild = true; } + { + torchVersion = "2.8"; + cxx11Abi = true; + cpu = true; + systems = [ + "aarch64-darwin" + "x86_64-linux" + "aarch64-linux" + ]; + bundleBuild = true; + } { torchVersion = "2.9"; xpuVersion = "2025.2.1"; @@ -127,6 +138,17 @@ bundleBuild = true; sourceBuild = true; } + { + torchVersion = "2.9"; + cxx11Abi = true; + cpu = true; + systems = [ + "aarch64-darwin" + "x86_64-linux" + "aarch64-linux" + ]; + bundleBuild = true; + } # Non-standard versions; not included in bundle builds. {