Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@
buildVariants =
(import ./lib/build-variants.nix {
inherit (nixpkgs) lib;
torchVersions = torchVersions';
}).buildVariants;
}).buildVariants
torchVersions';
in
builtins.toJSON buildVariants;
genFlakeOutputs =
Expand Down Expand Up @@ -104,7 +104,6 @@
pythonNativeCheckInputs
;
build = buildPerSystem.${system};
buildSet = buildSetPerSystem.${system};
}
);
}
Expand Down
22 changes: 13 additions & 9 deletions lib/build-variants.nix
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{ lib, torchVersions }:
{ lib }:
let
inherit (import ./torch-version-utils.nix { inherit lib; })
flattenSystems
Expand All @@ -22,8 +22,7 @@ rec {
else
throw "Could not find compute framework: no CUDA, ROCm, XPU version specified and Metal is not enabled";

# Build variants included in bundle builds.
buildVariants =
buildName =
let
inherit (import ./version-utils.nix { inherit lib; }) abiString flattenVersion;
computeString =
Expand All @@ -38,12 +37,17 @@ rec {
"xpu${flattenVersion (lib.versions.majorMinor version.xpuVersion)}"
else
throw "No compute framework set in Torch version";
buildName =
version:
if version.system == "aarch64-darwin" then
"torch${flattenVersion version.torchVersion}-${computeString version}-${version.system}"
else
"torch${flattenVersion version.torchVersion}-${abiString version.cxx11Abi}-${computeString version}-${version.system}";
in
version:
if version.system == "aarch64-darwin" then
"torch${flattenVersion version.torchVersion}-${computeString version}-${version.system}"
else
"torch${flattenVersion version.torchVersion}-${abiString version.cxx11Abi}-${computeString version}-${version.system}";

# Build variants included in bundle builds.
buildVariants =
torchVersions:
let
bundleBuildVersions = lib.filter (version: version.bundleBuild or false);
in
lib.foldl' (
Expand Down
27 changes: 14 additions & 13 deletions lib/build.nix
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ rec {
lib.foldl (backends: kernel: backends // { ${kernelBackend kernel} = true; }) init kernels;

readBuildConfig = path: validateBuildConfig (readToml (path + "/build.toml"));
tracedReadBuildConfig = path: readBuildConfig path;

srcFilter =
src: name: type:
Expand All @@ -81,7 +80,7 @@ rec {
mkSourceSet = import ./source-set.nix { inherit lib; };

# Filter buildsets that are applicable to a given kernel build config.
applicableBuildSets =
filterApplicableBuildSets =
buildConfig: buildSets:
let
backends' = backends buildConfig;
Expand All @@ -107,6 +106,8 @@ rec {
in
builtins.filter supportedBuildSet buildSets;

applicableBuildSets = path: filterApplicableBuildSets (readBuildConfig path) buildSets;

# Build a single Torch extension.
buildTorchExtension =
{
Expand Down Expand Up @@ -180,17 +181,16 @@ rec {
name = torchBuildVersion buildSet;
value = buildTorchExtension buildSet { inherit path rev; };
};
filteredBuildSets = applicableBuildSets (readBuildConfig path) buildSets;
in
builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) filteredBuildSets);
builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) (applicableBuildSets path));

# Build multiple Torch extensions.
buildDistTorchExtensions =
{
buildSets,
path,
rev,
doGetKernelCheck,
bundleOnly,
}:
let
extensionForTorch =
Expand All @@ -203,9 +203,13 @@ rec {
oldLinuxCompat = true;
};
};
filteredBuildSets = applicableBuildSets (readBuildConfig path) buildSets;
applicableBuildSets' =
if bundleOnly then
builtins.filter (buildSet: buildSet.bundleBuild) (applicableBuildSets path)
else
(applicableBuildSets path);
in
builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) filteredBuildSets);
builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) applicableBuildSets');

buildTorchExtensionBundle =
{
Expand All @@ -216,10 +220,9 @@ rec {
let
# We just need to get any nixpkgs for use by the path join.
pkgs = (builtins.head buildSets).pkgs;
bundleBuildSets = builtins.filter (buildSet: buildSet.bundleBuild) buildSets;
extensions = buildDistTorchExtensions {
inherit path rev doGetKernelCheck;
buildSets = bundleBuildSets;
bundleOnly = true;
};
buildConfig = readBuildConfig path;
namePaths =
Expand Down Expand Up @@ -273,9 +276,8 @@ rec {
'';
};
};
filteredBuildSets = applicableBuildSets (readBuildConfig path) buildSets;
in
builtins.listToAttrs (lib.map (shellForBuildSet { inherit path rev; }) filteredBuildSets);
builtins.listToAttrs (lib.map (shellForBuildSet { inherit path rev; }) (applicableBuildSets path));

torchDevShells =
{
Expand Down Expand Up @@ -315,7 +317,6 @@ rec {
venvDir = "./.venv";
};
};
filteredBuildSets = applicableBuildSets (readBuildConfig path) buildSets;
in
builtins.listToAttrs (lib.map shellForBuildSet filteredBuildSets);
builtins.listToAttrs (lib.map shellForBuildSet (applicableBuildSets path));
}
65 changes: 42 additions & 23 deletions lib/gen-flake-outputs.nix
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
lib,
build,
buildSet,
system,

writeScriptBin,
Expand All @@ -16,6 +16,8 @@
}:

let
inherit (import ./build-variants.nix { inherit lib; }) buildName;

supportedFormat = ''
kernel-builder.lib.genFlakeOutputs {
inherit self;
Expand All @@ -34,8 +36,45 @@ let
throw "Flake's `self` must be passed to `genFlakeOutputs` as follows:\n\n${supportedFormat}";

revUnderscored = builtins.replaceStrings [ "-" ] [ "_" ] flakeRev;

# For picking a default shell, etc. we want to use the following logic:
#
# - Prefer bundle builds over non-bundle builds.
# - Prefer CUDA over other frameworks.
# - Prefer newer Torch versions over older.
# - Prefer older frameworks over newer (best compatibility).

# Enrich the build configs with generic attributes for framework
# order/version. Also make bundleBuild attr explicit.
buildConfigs = map (
set:
let
inherit (set) buildConfig;
in
buildConfig
// {
bundleBuild = buildConfig.bundleBuild or false;
frameworkOrder = if buildConfig ? cudaVersion then 0 else 1;
frameworkVersion =
buildConfig.cudaVersion or buildConfig.rocmVersion or buildConfig.xpuVersion or "0.0";
}
) (build.applicableBuildSets path);
configCompare =
a: b:
if a.bundleBuild != b.bundleBuild then
a.bundleBuild
else if a.frameworkOrder != b.frameworkOrder then
a.frameworkOrder < b.frameworkOrder
else if a.torchVersion != b.torchVersion then
builtins.compareVersions a.torchVersion b.torchVersion > 0
else
builtins.compareVersions a.frameworkVersion b.frameworkVersion < 0;
buildConfigsSorted = lib.sort configCompare buildConfigs;
shellTorch =
if system == "aarch64-darwin" then "torch28-metal-${system}" else "torch28-cxx11-cu126-${system}";
if buildConfigsSorted == [ ] then
throw "No build variant is compatible with this system"
else
buildName (builtins.head buildConfigsSorted);
in

{
Expand Down Expand Up @@ -90,28 +129,8 @@ in
};
redistributable = build.buildDistTorchExtensions {
inherit path doGetKernelCheck;
buildSets = buildSet;
bundleOnly = false;
rev = revUnderscored;
};
buildTree =
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This output has been broken for a while, so I don't think anyone uses it. IIRC it precedes shell support, it's more convenient to nix develop now and build2cmake there.

let
src = build.mkSourceSet path;
in
runCommand "torch-extension-build-tree"
{
nativeBuildInputs = [ buildSet.pkgs.build2cmake ];
inherit src;
meta = {
description = "Build tree for torch extension with source files and CMake configuration";
};
}
''
# Copy sources
install -dm755 $out/src
cp -r $src/. $out/src/

# Generate cmake files
build2cmake generate-torch --ops-id "${revUnderscored}" $src/build.toml $out --force
'';
};
}
Loading