Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
with buildSet.pkgs;
(
allOutputs buildSet.torch
++ lib.concatMap allOutputs buildSet.extension.extraBuildDeps
++ allOutputs build2cmake
++ allOutputs kernel-abi-check
++ allOutputs python3Packages.kernels
Expand Down
2 changes: 2 additions & 0 deletions lib/build-sets.nix
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ let
torch = pkgs.python3.pkgs."torch_${flattenVersion torchVersion}".override {
inherit cxx11Abi;
};
extension = pkgs.callPackage ./torch-extension { inherit torch; };
in
{
inherit
buildConfig
extension
pkgs
torch
bundleBuild
Expand Down
1 change: 1 addition & 0 deletions lib/build-version.nix
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
buildConfig,
extension,
pkgs,
torch,
bundleBuild,
Expand Down
40 changes: 10 additions & 30 deletions lib/build.nix
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,6 @@ let
isRocm
isXpu
;
mkStdenv =
buildSet: oldLinuxCompat:
let
inherit (buildSet) pkgs torch;
in
if pkgs.stdenv.hostPlatform.isDarwin then
pkgs.stdenv
else if oldLinuxCompat then
# Uses CUDA stdenv when we are building for CUDA.
pkgs.stdenvGlibc_2_27
else if torch.cudaSupport then
torch.cudaPackages.backendStdenv
else
pkgs.stdenv;

in
rec {
resolveDeps = import ./deps.nix { inherit lib; };
Expand Down Expand Up @@ -113,6 +98,7 @@ rec {
mkTorchExtension =
{
buildConfig,
extension,
pkgs,
torch,
bundleBuild,
Expand All @@ -122,7 +108,6 @@ rec {
rev,
doGetKernelCheck,
stripRPath ? false,
oldLinuxCompat ? false,
}:
let
inherit (lib) fileset;
Expand All @@ -143,34 +128,32 @@ rec {
_: buildConfig: builtins.length (buildConfig.cuda-capabilities or supportedCudaCapabilities)
) buildConfig.kernel
);
stdenv = mkStdenv { inherit pkgs torch; } oldLinuxCompat;
in
if buildConfig.general.universal then
# No torch extension sources? Treat it as a noarch package.
pkgs.callPackage ./torch-extension-noarch ({

extension.mkNoArchExtension {
inherit
src
rev
torch
doGetKernelCheck
;
extensionName = buildConfig.general.name;
})
}
else
pkgs.callPackage ./torch-extension ({
extension.mkExtension {
inherit
doGetKernelCheck
extraDeps
nvccThreads
src
stdenv
stripRPath
torch
rev
;

extensionName = buildConfig.general.name;
doAbiCheck = oldLinuxCompat;
});
doAbiCheck = true;
};

# Build multiple Torch extensions.
mkDistTorchExtensions =
Expand All @@ -189,7 +172,6 @@ rec {
value = mkTorchExtension buildSet {
inherit path rev doGetKernelCheck;
stripRPath = true;
oldLinuxCompat = true;
};
};
applicableBuildSets' =
Expand Down Expand Up @@ -247,8 +229,7 @@ rec {
let
pkgs = buildSet.pkgs;
rocmSupport = pkgs.config.rocmSupport or false;
stdenv = mkStdenv buildSet false;
mkShell = pkgs.mkShell.override { inherit stdenv; };
mkShell = pkgs.mkShell.override { inherit (buildSet.extension) stdenv; };
in
{
name = torchBuildVersion buildSet;
Expand Down Expand Up @@ -288,8 +269,7 @@ rec {
pkgs = buildSet.pkgs;
rocmSupport = pkgs.config.rocmSupport or false;
xpuSupport = pkgs.config.xpuSupport or false;
stdenv = mkStdenv buildSet false;
mkShell = pkgs.mkShell.override { inherit stdenv; };
mkShell = pkgs.mkShell.override { inherit (buildSet.extension) stdenv; };
in
{
name = torchBuildVersion buildSet;
Expand Down
219 changes: 219 additions & 0 deletions lib/torch-extension/arch.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
# Builder for an architecture-specific Torch extension.
#
# This file is a two-stage function:
#
#   1. The first attribute set holds the toolchain and library dependencies
#      (intended to be filled in by `callPackage`-style injection).
#   2. The second attribute set holds the per-extension arguments.
#
# The result is a `stdenv.mkDerivation` that generates CMake build files
# with `build2cmake`, builds them with CMake/Ninja and installs the Python
# package from `torch-ext/<extensionName>` together with the compiled ops.
{
  # Backend selection. Defaults follow the backend support of `torch`.
  cudaSupport ? torch.cudaSupport,
  rocmSupport ? torch.rocmSupport,
  xpuSupport ? torch.xpuSupport,

  lib,
  stdenv,
  cudaPackages,
  cmake,
  cmakeNvccThreadsHook,
  ninja,
  build2cmake,
  get-kernel-check,
  kernel-abi-check,
  python3,
  rewrite-nix-paths-macho,
  rocmPackages,
  writeScriptBin,
  xpuPackages,

  apple-sdk_15,
  clr,
  oneapi-torch-dev,
  onednn-xpu,
  torch,
}:

{
  # Whether to do ABI checks.
  # NOTE(review): passed through as a derivation attribute; presumably read by
  # the kernel-abi-check hook -- confirm.
  doAbiCheck ? true,

  # Whether to run get-kernel-check.
  doGetKernelCheck ? true,

  # Name of the extension; also the name of the Python package directory
  # below `torch-ext/`.
  extensionName,

  # Extra dependencies (such as CUTLASS).
  extraDeps ? [ ],

  # Passed through as a derivation attribute.
  # NOTE(review): presumably consumed by cmakeNvccThreadsHook -- confirm.
  nvccThreads,

  # Whether to strip rpath for non-nix use.
  stripRPath ? false,

  # Revision to bake into the ops name.
  rev,

  # Extension source tree; must contain `build.toml` and `torch-ext/`.
  src,
}:

let
  # On Darwin, we need the host's xcrun for `xcrun metal` to compile Metal shaders.
  # It's not supported by the nixpkgs shim.
  #
  # `writeScriptBin` does not add an interpreter line, so the script carries
  # its own shebang. Arguments are forwarded as "$@" so that arguments
  # containing whitespace (e.g. paths with spaces) are not split again.
  xcrunHost = writeScriptBin "xcrunHost" ''
    #!/bin/sh
    # Use system SDK for Metal files.
    unset DEVELOPER_DIR
    /usr/bin/xcrun "$@"
  '';

in

# The function argument of `mkDerivation` is the *final* attribute set
# (fixed point). It is currently unused.
stdenv.mkDerivation (finalAttrs: {
  name = "${extensionName}-torch-ext";

  inherit doAbiCheck nvccThreads src;

  # Generate build files.
  # The backend is picked in priority order CUDA > ROCm > XPU, falling back
  # to Metal when none of them is enabled.
  postPatch = ''
    build2cmake generate-torch --backend ${
      if cudaSupport then
        "cuda"
      else if rocmSupport then
        "rocm"
      else if xpuSupport then
        "xpu"
      else
        "metal"
    } --ops-id ${rev} build.toml
  '';

  # hipify copies files, but its target is run in the CMake build and install
  # phases. Since some of the files come from the Nix store, this fails the
  # second time around.
  preInstall = ''
    chmod -R u+w .
  '';

  nativeBuildInputs = [
    kernel-abi-check
    cmake
    ninja
    build2cmake
  ]
  ++ lib.optionals doGetKernelCheck [
    get-kernel-check
  ]
  ++ lib.optionals cudaSupport [
    cmakeNvccThreadsHook
    cudaPackages.cuda_nvcc
  ]
  ++ lib.optionals rocmSupport [
    clr
  ]
  ++ lib.optionals xpuSupport ([
    xpuPackages.ocloc
    oneapi-torch-dev
  ])
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    rewrite-nix-paths-macho
  ];

  buildInputs = [
    torch
    torch.cxxdev
  ]
  ++ lib.optionals cudaSupport (
    with cudaPackages;
    [
      cuda_cudart

      # Make dependent on build configuration dependencies once
      # the Torch dependency is gone.
      cuda_cccl
      libcublas
      libcusolver
      libcusparse
    ]
  )
  ++ lib.optionals rocmSupport (
    with rocmPackages;
    [
      hipsparselt
      rocwmma-devel
    ]
  )
  ++ lib.optionals xpuSupport ([
    oneapi-torch-dev
    onednn-xpu
  ])
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    apple-sdk_15
  ]
  ++ extraDeps;

  # Backend-specific environment variables for the build.
  env =
    lib.optionalAttrs cudaSupport {
      CUDAToolkit_ROOT = "${lib.getDev cudaPackages.cuda_nvcc}";
      # The capability list depends on the CUDA version in use.
      TORCH_CUDA_ARCH_LIST =
        if cudaPackages.cudaOlder "12.8" then
          "7.0;7.5;8.0;8.6;8.9;9.0"
        else if cudaPackages.cudaOlder "13.0" then
          "7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0"
        else
          # sm_101 has been renamed to sm_110 in CUDA 13.
          "7.5;8.0;8.6;8.9;9.0;10.0;11.0;12.0";
    }
    // lib.optionalAttrs rocmSupport {
      PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" torch.rocmArchs;
    }
    // lib.optionalAttrs xpuSupport {
      MKLROOT = oneapi-torch-dev;
      SYCL_ROOT = oneapi-torch-dev;
    };

  # If we use the default setup, CMAKE_CUDA_HOST_COMPILER gets set to nixpkgs g++.
  dontSetupCUDAToolkitCompilers = true;

  cmakeFlags = [
    # Point CMake at a Python interpreter that has `torch` available.
    (lib.cmakeFeature "Python_EXECUTABLE" "${python3.withPackages (ps: [ torch ])}/bin/python")
  ]
  ++ lib.optionals cudaSupport [
    # Use the compiler of the stdenv this derivation is built with.
    (lib.cmakeFeature "CMAKE_CUDA_HOST_COMPILER" "${stdenv.cc}/bin/g++")
  ]
  ++ lib.optionals rocmSupport [
    # Ensure that we use HIP from our CLR override and not HIP from
    # the symlink-joined ROCm toolkit.
    (lib.cmakeFeature "CMAKE_HIP_COMPILER_ROCM_ROOT" "${clr}")
    (lib.cmakeFeature "HIP_ROOT_DIR" "${clr}")
  ]
  ++ lib.optionals xpuSupport [
    (lib.cmakeFeature "ONEDNN_XPU_INCLUDE_DIR" "${onednn-xpu}/include")
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # Use host compiler for Metal. Not included in the redistributable SDK.
    (lib.cmakeFeature "METAL_COMPILER" "${xcrunHost}/bin/xcrunHost")
  ];

  # Copy the Python package from the source tree into $out, then move the
  # contents of the installed `_<extensionName>_*` directory into the package
  # directory and remove the now-empty directory.
  postInstall = ''
    (
      cd ..
      cp -r torch-ext/${extensionName} $out/
    )
    cp $out/_${extensionName}_*/* $out/${extensionName}
    rm -rf $out/_${extensionName}_*
  ''
  # Linux: clear the rpath of every shared object.
  + (lib.optionalString (stripRPath && stdenv.hostPlatform.isLinux)) ''
    find $out/${extensionName} -name '*.so' \
      -exec patchelf --set-rpath "" {} \;
  ''
  # Darwin: rewrite Nix store paths and add a loader-relative rpath.
  + (lib.optionalString (stripRPath && stdenv.hostPlatform.isDarwin)) ''
    find $out/${extensionName} -name '*.so' \
      -exec rewrite-nix-paths-macho {} \;

    # Stub some rpath.
    find $out/${extensionName} -name '*.so' \
      -exec install_name_tool -add_rpath "@loader_path/lib" {} \;
  '';

  doInstallCheck = true;

  # NOTE(review): presumably read by the get-kernel-check hook to know which
  # extension to check -- confirm.
  getKernelCheck = extensionName;

  # We need access to the host system on Darwin for the Metal compiler.
  __noChroot = stdenv.hostPlatform.isDarwin;

  passthru = {
    inherit torch;
  };
})
Loading
Loading