Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
system: buildSet:
import lib/build.nix {
inherit (nixpkgs) lib;
buildSets = buildSetPerSystem.${system};
}
) buildSetPerSystem;

Expand Down Expand Up @@ -104,6 +103,7 @@
pythonNativeCheckInputs
;
build = buildPerSystem.${system};
buildSets = buildSetPerSystem.${system};
}
);
}
Expand Down
66 changes: 30 additions & 36 deletions lib/build.nix
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
{
lib,

# List of build sets. Each build set is a attrset of the form
#
# { pkgs = <nixpkgs>, torch = <torch drv> }
#
# The Torch derivation is built as-is. So e.g. the ABI version should
# already be set.
buildSets,
# Every `buildSets` argument is a list of build sets. Each build set is
# a attrset of the form
#
# { pkgs = <nixpkgs>, torch = <torch drv> }
#
# The Torch derivation is built as-is. So e.g. the ABI version should
# already be set.
}:

let
Expand Down Expand Up @@ -106,10 +106,11 @@ rec {
in
builtins.filter supportedBuildSet buildSets;

applicableBuildSets = path: filterApplicableBuildSets (readBuildConfig path) buildSets;
applicableBuildSets =
{ path, buildSets }: filterApplicableBuildSets (readBuildConfig path) buildSets;

# Build a single Torch extension.
buildTorchExtension =
mkTorchExtension =
{
buildConfig,
pkgs,
Expand Down Expand Up @@ -172,56 +173,47 @@ rec {
});

# Build multiple Torch extensions.
buildNixTorchExtensions =
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this function, we do not use it anywhere.

{ path, rev }:
let
extensionForTorch =
{ path, rev }:
buildSet: {
name = torchBuildVersion buildSet;
value = buildTorchExtension buildSet { inherit path rev; };
};
in
builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) (applicableBuildSets path));

# Build multiple Torch extensions.
buildDistTorchExtensions =
mkDistTorchExtensions =
{
path,
rev,
doGetKernelCheck,
bundleOnly,
buildSets,
}:
let
extensionForTorch =
{ path, rev }:
buildSet: {
name = torchBuildVersion buildSet;
value = buildTorchExtension buildSet {
value = mkTorchExtension buildSet {
inherit path rev doGetKernelCheck;
stripRPath = true;
oldLinuxCompat = true;
};
};
applicableBuildSets' =
if bundleOnly then
builtins.filter (buildSet: buildSet.bundleBuild) (applicableBuildSets path)
else
(applicableBuildSets path);
if bundleOnly then builtins.filter (buildSet: buildSet.bundleBuild) buildSets else buildSets;
in
builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) applicableBuildSets');

buildTorchExtensionBundle =
mkTorchExtensionBundle =
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the mk naming pattern is a bit nicer, so renamed everything here.

{
path,
rev,
doGetKernelCheck,
buildSets,
}:
let
# We just need to get any nixpkgs for use by the path join.
pkgs = (builtins.head buildSets).pkgs;
extensions = buildDistTorchExtensions {
inherit path rev doGetKernelCheck;
extensions = mkDistTorchExtensions {
inherit
buildSets
path
rev
doGetKernelCheck
;
bundleOnly = true;
};
buildConfig = readBuildConfig path;
Expand All @@ -243,6 +235,7 @@ rec {
{
path,
rev,
buildSets,
doGetKernelCheck,
pythonCheckInputs,
pythonNativeCheckInputs,
Expand Down Expand Up @@ -271,18 +264,19 @@ rec {
++ (pythonCheckInputs python3.pkgs);
shellHook = ''
export PYTHONPATH=''${PYTHONPATH}:${
buildTorchExtension buildSet { inherit path rev doGetKernelCheck; }
mkTorchExtension buildSet { inherit path rev doGetKernelCheck; }
}
'';
};
};
in
builtins.listToAttrs (lib.map (shellForBuildSet { inherit path rev; }) (applicableBuildSets path));
builtins.listToAttrs (lib.map (shellForBuildSet { inherit path rev; }) buildSets);

torchDevShells =
mkTorchDevShells =
{
path,
rev,
buildSets,
doGetKernelCheck,
pythonCheckInputs,
pythonNativeCheckInputs,
Expand All @@ -309,7 +303,7 @@ rec {
]
++ (pythonNativeCheckInputs python3.pkgs);
buildInputs = with pkgs; [ python3.pkgs.pytest ] ++ (pythonCheckInputs python3.pkgs);
inputsFrom = [ (buildTorchExtension buildSet { inherit path rev doGetKernelCheck; }) ];
inputsFrom = [ (mkTorchExtension buildSet { inherit path rev doGetKernelCheck; }) ];
env = lib.optionalAttrs rocmSupport {
PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" buildSet.torch.rocmArchs;
HIP_PATH = pkgs.rocmPackages.clr;
Expand All @@ -318,5 +312,5 @@ rec {
};
};
in
builtins.listToAttrs (lib.map shellForBuildSet (applicableBuildSets path));
builtins.listToAttrs (lib.map shellForBuildSet buildSets);
}
48 changes: 42 additions & 6 deletions lib/gen-flake-outputs.nix
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
runCommand,

path,
buildSets,
rev ? null,
self ? null,

Expand Down Expand Up @@ -37,6 +38,8 @@ let

revUnderscored = builtins.replaceStrings [ "-" ] [ "_" ] flakeRev;

applicableBuildSets = build.applicableBuildSets { inherit path buildSets; };

# For picking a default shell, etc. we want to use the following logic:
#
# - Prefer bundle builds over non-bundle builds.
Expand All @@ -46,7 +49,7 @@ let

# Enrich the build configs with generic attributes for framework
# order/version. Also make bundleBuild attr explicit.
buildSets = map (
addSortOrder = map (
set:
let
inherit (set) buildConfig;
Expand All @@ -57,12 +60,23 @@ let

buildConfig // {
bundleBuild = buildConfig.bundleBuild or false;
framework =
if buildConfig ? cudaVersion then
"cuda"
else if buildConfig ? rocmVersion then
"rocm"
else if buildConfig ? xpuVersion then
"xpu"
else if system == "aarch64-darwin" then
"metal"
else
throw "Cannot determine framework for build set";
frameworkOrder = if buildConfig ? cudaVersion then 0 else 1;
frameworkVersion =
buildConfig.cudaVersion or buildConfig.rocmVersion or buildConfig.xpuVersion or "0.0";
};
}
) (build.applicableBuildSets path);
);
configCompare =
setA: setB:
let
Expand All @@ -77,25 +91,27 @@ let
builtins.compareVersions a.torchVersion b.torchVersion > 0
else
builtins.compareVersions a.frameworkVersion b.frameworkVersion < 0;
buildSetsSorted = lib.sort configCompare buildSets;
buildSetsSorted = lib.sort configCompare (addSortOrder applicableBuildSets);
bestBuildSet =
if buildSetsSorted == [ ] then
throw "No build variant is compatible with this system"
else
builtins.head buildSetsSorted;
shellTorch = buildName bestBuildSet.buildConfig;
headOrEmpty = l: if l == [ ] then [ ] else [ (builtins.head l) ];
in
{
devShells = rec {
default = devShells.${shellTorch};
test = testShells.${shellTorch};
devShells = build.torchDevShells {
devShells = build.mkTorchDevShells {
inherit
path
doGetKernelCheck
pythonCheckInputs
pythonNativeCheckInputs
;
buildSets = applicableBuildSets;
rev = revUnderscored;
};
testShells = build.torchExtensionShells {
Expand All @@ -105,13 +121,15 @@ in
pythonCheckInputs
pythonNativeCheckInputs
;
buildSets = applicableBuildSets;
rev = revUnderscored;
};
};
packages =
let
bundle = build.buildTorchExtensionBundle {
bundle = build.mkTorchExtensionBundle {
inherit path doGetKernelCheck;
buildSets = applicableBuildSets;
rev = revUnderscored;
};
in
Expand Down Expand Up @@ -140,6 +158,23 @@ in
chmod -R +w build
'';

ci =
let
setsWithFramework =
framework: builtins.filter (set: set.buildConfig.framework == framework) buildSetsSorted;
# It is too costly to build all variants in CI, so we just build one per framework.
onePerFramework =
(headOrEmpty (setsWithFramework "cuda"))
++ (headOrEmpty (setsWithFramework "metal"))
++ (headOrEmpty (setsWithFramework "rocm"))
++ (headOrEmpty (setsWithFramework "xpu"));
in
build.mkTorchExtensionBundle {
inherit path doGetKernelCheck;
buildSets = onePerFramework;
rev = revUnderscored;
};

kernels =
bestBuildSet.pkgs.python3.withPackages (
ps: with ps; [
Expand All @@ -151,10 +186,11 @@ in
meta.mainProgram = "kernels";
};

redistributable = build.buildDistTorchExtensions {
redistributable = build.mkDistTorchExtensions {
inherit path doGetKernelCheck;
bundleOnly = false;
rev = revUnderscored;
buildSets = applicableBuildSets;
};
};
}
Loading