Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/build_kernel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,22 @@ jobs:
env:
USER: github_runner
- name: Build activation kernel
run: ( cd examples/activation && nix build .\#redistributable.torch27-cxx11-cu126-x86_64-linux )
run: ( cd examples/activation && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
- name: Copy activation kernel
run: cp -rL examples/activation/result activation-kernel

- name: Build cutlass GEMM kernel
run: ( cd examples/cutlass-gemm && nix build .\#redistributable.torch27-cxx11-cu126-x86_64-linux )
run: ( cd examples/cutlass-gemm && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
- name: Copy cutlass GEMM kernel
run: cp -rL examples/cutlass-gemm/result cutlass-gemm-kernel

- name: Build relu kernel
run: ( cd examples/relu && nix build .\#redistributable.torch27-cxx11-cu126-x86_64-linux )
run: ( cd examples/relu && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
- name: Copy relu kernel
run: cp -rL examples/relu/result relu-kernel

- name: Build relu-backprop-compile kernel
run: ( cd examples/relu-backprop-compile && nix build .\#redistributable.torch27-cxx11-cu126-x86_64-linux )
run: ( cd examples/relu-backprop-compile && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
- name: Copy relu-backprop-compile kernel
run: cp -rL examples/relu-backprop-compile/result relu-backprop-compile-kernel

Expand All @@ -51,7 +51,7 @@ jobs:
run: ( cd examples/relu && nix build .#devShells.x86_64-linux.test )

- name: Build silu-and-mul-universal kernel
run: ( cd examples/silu-and-mul-universal && nix build .\#redistributable.torch27-cxx11-cu126-x86_64-linux )
run: ( cd examples/silu-and-mul-universal && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
- name: Copy silu-and-mul-universal kernel
run: cp -rL examples/silu-and-mul-universal/result silu-and-mul-universal-kernel

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build_kernel_macos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ jobs:
# For now we only test that there are no regressions in building macOS
# kernels. Also run tests once we have a macOS runner.
- name: Build relu kernel
run: ( cd examples/relu && nix build .\#redistributable.torch27-metal-aarch64-darwin -L )
run: ( cd examples/relu && nix build .\#redistributable.torch29-metal-aarch64-darwin -L )
2 changes: 1 addition & 1 deletion .github/workflows/build_kernel_rocm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ jobs:
# For now we only test that there are no regressions in building ROCm
# kernels. Also run tests once we have a ROCm runner.
- name: Build relu kernel
run: ( cd examples/relu && nix build .\#redistributable.torch27-cxx11-rocm63-x86_64-linux -L )
run: ( cd examples/relu && nix build .\#redistributable.torch29-cxx11-rocm63-x86_64-linux -L )
2 changes: 1 addition & 1 deletion .github/workflows/build_kernel_xpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ jobs:
# For now we only test that there are no regressions in building XPU
# kernels. Also run tests once we have a XPU runner.
- name: Build relu kernel
run: ( cd examples/relu && nix build .\#redistributable.torch28-cxx11-xpu20251-x86_64-linux -L )
run: ( cd examples/relu && nix build .\#redistributable.torch29-cxx11-xpu20252-x86_64-linux -L )
7 changes: 0 additions & 7 deletions build-variants.json
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
{
"aarch64-darwin": {
"metal": [
"torch27-metal-aarch64-darwin",
"torch28-metal-aarch64-darwin",
"torch29-metal-aarch64-darwin"
]
},
"aarch64-linux": {
"cuda": [
"torch27-cxx11-cu128-aarch64-linux",
"torch28-cxx11-cu129-aarch64-linux",
"torch29-cxx11-cu126-aarch64-linux",
"torch29-cxx11-cu128-aarch64-linux",
Expand All @@ -17,9 +15,6 @@
},
"x86_64-linux": {
"cuda": [
"torch27-cxx11-cu118-x86_64-linux",
"torch27-cxx11-cu126-x86_64-linux",
"torch27-cxx11-cu128-x86_64-linux",
"torch28-cxx11-cu126-x86_64-linux",
"torch28-cxx11-cu128-x86_64-linux",
"torch28-cxx11-cu129-x86_64-linux",
Expand All @@ -28,14 +23,12 @@
"torch29-cxx11-cu130-x86_64-linux"
],
"rocm": [
"torch27-cxx11-rocm63-x86_64-linux",
"torch28-cxx11-rocm63-x86_64-linux",
"torch28-cxx11-rocm64-x86_64-linux",
"torch29-cxx11-rocm63-x86_64-linux",
"torch29-cxx11-rocm64-x86_64-linux"
],
"xpu": [
"torch27-cxx11-xpu20250-x86_64-linux",
"torch28-cxx11-xpu20251-x86_64-linux",
"torch29-cxx11-xpu20252-x86_64-linux"
]
Expand Down
7 changes: 0 additions & 7 deletions docs/build-variants.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,18 @@ available. This list will be updated as new PyTorch versions are released.

## Metal aarch64-darwin

- `torch27-metal-aarch64-darwin`
- `torch28-metal-aarch64-darwin`
- `torch29-metal-aarch64-darwin`

## CUDA aarch64-linux

- `torch27-cxx11-cu128-aarch64-linux`
- `torch28-cxx11-cu129-aarch64-linux`
- `torch29-cxx11-cu126-aarch64-linux`
- `torch29-cxx11-cu128-aarch64-linux`
- `torch29-cxx11-cu130-aarch64-linux`

## CUDA x86_64-linux

- `torch27-cxx11-cu118-x86_64-linux`
- `torch27-cxx11-cu126-x86_64-linux`
- `torch27-cxx11-cu128-x86_64-linux`
- `torch28-cxx11-cu126-x86_64-linux`
- `torch28-cxx11-cu128-x86_64-linux`
- `torch28-cxx11-cu129-x86_64-linux`
Expand All @@ -33,15 +28,13 @@ available. This list will be updated as new PyTorch versions are released.

## ROCm x86_64-linux

- `torch27-cxx11-rocm63-x86_64-linux`
- `torch28-cxx11-rocm63-x86_64-linux`
- `torch28-cxx11-rocm64-x86_64-linux`
- `torch29-cxx11-rocm63-x86_64-linux`
- `torch29-cxx11-rocm64-x86_64-linux`

## XPU x86_64-linux

- `torch27-cxx11-xpu20250-x86_64-linux`
- `torch28-cxx11-xpu20251-x86_64-linux`
- `torch29-cxx11-xpu20252-x86_64-linux`

Expand Down
2 changes: 1 addition & 1 deletion docs/docker.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ To load a kernel locally, you should add the kernel build that is compatible wit

```bash
# PyTorch 2.6 and CUDA 12.6
export PYTHONPATH="result/torch26-cxx11-cu126-x86_64-linux"
export PYTHONPATH="result/torch29-cxx11-cu126-x86_64-linux"
Comment on lines 188 to +189
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nit : the torch version in the comment

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 I'll update the comments in a separate PR

```

The kernel can then be imported as a Python module:
Expand Down
2 changes: 1 addition & 1 deletion docs/nix.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ using:

```bash
$ rm -rf .venv # Remove existing venv if any.
$ nix develop .#devShells.torch27-cxx11-rocm63-x86_64-linux
$ nix develop .#devShells.torch29-cxx11-rocm64-x86_64-linux
```

## Shell for testing a kernel
Expand Down
2 changes: 1 addition & 1 deletion examples/relu-specific-torch/flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
path = ./.;
torchVersions = defaultVersions: [
{
torchVersion = "2.7";
torchVersion = "2.9";
cudaVersion = "12.8";
cxx11Abi = true;
systems = [
Expand Down
6 changes: 3 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 10 additions & 3 deletions lib/build-sets.nix
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ let
cxx11Abi,
system,
bundleBuild ? false,
sourceBuild ? false,
}:
let
pkgs =
Expand All @@ -84,9 +85,15 @@ let
pkgsByXpuVer.${xpuVersion}
else
throw "No compute framework set in Torch version";
torch = pkgs.python3.pkgs."torch_${flattenVersion torchVersion}".override {
inherit cxx11Abi;
};
torch =
if sourceBuild then
pkgs.python3.pkgs."torch_${flattenVersion torchVersion}".override {
inherit cxx11Abi;
}
else
pkgs.python3.pkgs."torch-bin_${flattenVersion torchVersion}".override {
inherit cxx11Abi;
};
extension = pkgs.callPackage ./torch-extension { inherit torch; };
in
{
Expand Down
45 changes: 24 additions & 21 deletions lib/build.nix
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,30 @@ let
isRocm
isXpu
;
inherit (import ./build-variants.nix { inherit lib; }) computeFramework;
in
rec {
resolveDeps = import ./deps.nix { inherit lib; };

readToml = path: builtins.fromTOML (builtins.readFile path);

validateBuildConfig =
buildConfig:
buildToml:
let
kernels = lib.attrValues (buildConfig.kernel or { });
hasOldUniversal = builtins.hasAttr "universal" (buildConfig.torch or { });
kernels = lib.attrValues (buildToml.kernel or { });
hasOldUniversal = builtins.hasAttr "universal" (buildToml.torch or { });
hasLanguage = lib.any (kernel: kernel ? language) kernels;

in
assert lib.assertMsg (!hasOldUniversal && !hasLanguage) ''
build.toml seems to be of an older version, update it with:
build2cmake update-build build.toml'';
buildConfig;
buildToml;

backends =
buildConfig:
buildToml:
let
kernels = lib.attrValues (buildConfig.kernel or { });
kernels = lib.attrValues (buildToml.kernel or { });
kernelBackend = kernel: kernel.backend;
init = {
cuda = false;
Expand All @@ -66,11 +67,11 @@ rec {

# Filter buildsets that are applicable to a given kernel build config.
filterApplicableBuildSets =
buildConfig: buildSets:
buildToml: buildSets:
let
backends' = backends buildConfig;
minCuda = buildConfig.general.cuda-minver or "11.8";
maxCuda = buildConfig.general.cuda-maxver or "99.9";
backends' = backends buildToml;
minCuda = buildToml.general.cuda-minver or "11.8";
maxCuda = buildToml.general.cuda-maxver or "99.9";
versionBetween =
minver: maxver: ver:
builtins.compareVersions ver minver >= 0 && builtins.compareVersions ver maxver <= 0;
Expand All @@ -82,7 +83,7 @@ rec {
|| (isRocm buildSet.buildConfig && backends'.rocm)
|| (isMetal buildSet.buildConfig && backends'.metal)
|| (isXpu buildSet.buildConfig && backends'.xpu)
|| (buildConfig.general.universal or false);
|| (buildToml.general.universal or false);
cudaVersionSupported =
!(isCuda buildSet.buildConfig)
|| versionBetween minCuda maxCuda buildSet.pkgs.cudaPackages.cudaMajorMinorVersion;
Expand Down Expand Up @@ -111,11 +112,13 @@ rec {
}:
let
inherit (lib) fileset;
buildConfig = readBuildConfig path;
kernels = buildConfig.kernel or { };
buildToml = readBuildConfig path;
kernels = lib.filterAttrs (_: kernel: computeFramework buildConfig == kernel.backend) (
buildToml.kernel or { }
);
extraDeps = resolveDeps {
inherit pkgs torch;
deps = lib.unique (lib.flatten (lib.mapAttrsToList (_: buildConfig: buildConfig.depends) kernels));
deps = lib.unique (lib.flatten (lib.mapAttrsToList (_: kernel: kernel.depends) kernels));
};

# Use the mkSourceSet function to get the source
Expand All @@ -125,11 +128,11 @@ rec {
listMax = lib.foldl' lib.max 1;
nvccThreads = listMax (
lib.mapAttrsToList (
_: buildConfig: builtins.length (buildConfig.cuda-capabilities or supportedCudaCapabilities)
) buildConfig.kernel
_: kernel: builtins.length (kernel.cuda-capabilities or supportedCudaCapabilities)
) buildToml.kernel
);
in
if buildConfig.general.universal then
if buildToml.general.universal then
# No torch extension sources? Treat it as a noarch package.

extension.mkNoArchExtension {
Expand All @@ -138,7 +141,7 @@ rec {
rev
doGetKernelCheck
;
extensionName = buildConfig.general.name;
extensionName = buildToml.general.name;
}
else
extension.mkExtension {
Expand All @@ -151,7 +154,7 @@ rec {
rev
;

extensionName = buildConfig.general.name;
extensionName = buildToml.general.name;
doAbiCheck = true;
};

Expand Down Expand Up @@ -198,9 +201,9 @@ rec {
;
bundleOnly = true;
};
buildConfig = readBuildConfig path;
buildToml = readBuildConfig path;
namePaths =
if buildConfig.general.universal then
if buildToml.general.universal then
# Noarch, just get the first extension.
{ "torch-universal" = builtins.head (builtins.attrValues extensions); }
else
Expand Down
2 changes: 1 addition & 1 deletion lib/deps.nix
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ let
];
"torch" = [
torch
torch.cxxdev
#torch.cxxdev
];
Comment thread
danieldk marked this conversation as resolved.
"cutlass_sycl" = [ torch.xpuPackages.cutlass-sycl ];
};
Expand Down
15 changes: 7 additions & 8 deletions lib/torch-extension/arch.nix
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ stdenv.mkDerivation (prevAttrs: {
++ lib.optionals rocmSupport (
with rocmPackages;
[
hipcub-devel
hipsparselt
rocprim-devel
rocthrust-devel
rocwmma-devel
]
)
Expand All @@ -145,14 +148,7 @@ stdenv.mkDerivation (prevAttrs: {
env =
lib.optionalAttrs cudaSupport {
CUDAToolkit_ROOT = "${lib.getDev cudaPackages.cuda_nvcc}";
TORCH_CUDA_ARCH_LIST =
if cudaPackages.cudaOlder "12.8" then
"7.0;7.5;8.0;8.6;8.9;9.0"
else if cudaPackages.cudaOlder "13.0" then
"7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0"
else
# sm_101 has been renamed to sm_110 in CUDA 13.
"7.5;8.0;8.6;8.9;9.0;10.0;11.0;12.0";
TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" torch.cudaCapabilities;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

much cleaner ! very nice

}
// lib.optionalAttrs rocmSupport {
PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" torch.rocmArchs;
Expand All @@ -167,6 +163,9 @@ stdenv.mkDerivation (prevAttrs: {

cmakeFlags = [
(lib.cmakeFeature "Python_EXECUTABLE" "${python3.withPackages (ps: [ torch ])}/bin/python")
# Fix: file RPATH_CHANGE could not write new RPATH, we are rewriting
# rpaths anyway.
(lib.cmakeBool "CMAKE_SKIP_RPATH" true)
]
++ lib.optionals cudaSupport [
(lib.cmakeFeature "CMAKE_CUDA_HOST_COMPILER" "${stdenv.cc}/bin/g++")
Expand Down
Loading
Loading