From 4e83a7772d260cfbec779de9f326435a9641b784 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 10 Oct 2025 09:56:17 +0000 Subject: [PATCH] Add `ci` output for generated flakes This output is like `bundle`, but only builds one variant for each framework. --- flake.nix | 2 +- lib/build.nix | 66 ++++++++++++++++++--------------------- lib/gen-flake-outputs.nix | 48 ++++++++++++++++++++++++---- 3 files changed, 73 insertions(+), 43 deletions(-) diff --git a/flake.nix b/flake.nix index bdcb298c..716dc0a4 100644 --- a/flake.nix +++ b/flake.nix @@ -46,7 +46,6 @@ system: buildSet: import lib/build.nix { inherit (nixpkgs) lib; - buildSets = buildSetPerSystem.${system}; } ) buildSetPerSystem; @@ -104,6 +103,7 @@ pythonNativeCheckInputs ; build = buildPerSystem.${system}; + buildSets = buildSetPerSystem.${system}; } ); } diff --git a/lib/build.nix b/lib/build.nix index 6674719e..121f4795 100644 --- a/lib/build.nix +++ b/lib/build.nix @@ -1,13 +1,13 @@ { lib, - # List of build sets. Each build set is a attrset of the form - # - # { pkgs = , torch = } - # - # The Torch derivation is built as-is. So e.g. the ABI version should - # already be set. - buildSets, +# Every `buildSets` argument is a list of build sets. Each build set is +# a attrset of the form +# +# { pkgs = , torch = } +# +# The Torch derivation is built as-is. So e.g. the ABI version should +# already be set. }: let @@ -106,10 +106,11 @@ rec { in builtins.filter supportedBuildSet buildSets; - applicableBuildSets = path: filterApplicableBuildSets (readBuildConfig path) buildSets; + applicableBuildSets = + { path, buildSets }: filterApplicableBuildSets (readBuildConfig path) buildSets; # Build a single Torch extension. - buildTorchExtension = + mkTorchExtension = { buildConfig, pkgs, @@ -172,56 +173,47 @@ rec { }); # Build multiple Torch extensions. - buildNixTorchExtensions = - { path, rev }: - let - extensionForTorch = - { path, rev }: - buildSet: { - name = torchBuildVersion buildSet; - value = buildTorchExtension buildSet { inherit path rev; }; - }; - in - builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) (applicableBuildSets path)); - - # Build multiple Torch extensions. - buildDistTorchExtensions = + mkDistTorchExtensions = { path, rev, doGetKernelCheck, bundleOnly, + buildSets, }: let extensionForTorch = { path, rev }: buildSet: { name = torchBuildVersion buildSet; - value = buildTorchExtension buildSet { + value = mkTorchExtension buildSet { inherit path rev doGetKernelCheck; stripRPath = true; oldLinuxCompat = true; }; }; applicableBuildSets' = - if bundleOnly then - builtins.filter (buildSet: buildSet.bundleBuild) (applicableBuildSets path) - else - (applicableBuildSets path); + if bundleOnly then builtins.filter (buildSet: buildSet.bundleBuild) buildSets else buildSets; in builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) applicableBuildSets'); - buildTorchExtensionBundle = + mkTorchExtensionBundle = { path, rev, doGetKernelCheck, + buildSets, }: let # We just need to get any nixpkgs for use by the path join. pkgs = (builtins.head buildSets).pkgs; - extensions = buildDistTorchExtensions { - inherit path rev doGetKernelCheck; + extensions = mkDistTorchExtensions { + inherit + buildSets + path + rev + doGetKernelCheck + ; bundleOnly = true; }; buildConfig = readBuildConfig path; @@ -243,6 +235,7 @@ rec { { path, rev, + buildSets, doGetKernelCheck, pythonCheckInputs, pythonNativeCheckInputs, @@ -271,18 +264,19 @@ rec { ++ (pythonCheckInputs python3.pkgs); shellHook = '' export PYTHONPATH=''${PYTHONPATH}:${ - buildTorchExtension buildSet { inherit path rev doGetKernelCheck; } + mkTorchExtension buildSet { inherit path rev doGetKernelCheck; } } ''; }; }; in - builtins.listToAttrs (lib.map (shellForBuildSet { inherit path rev; }) (applicableBuildSets path)); + builtins.listToAttrs (lib.map (shellForBuildSet { inherit path rev; }) buildSets); - torchDevShells = + mkTorchDevShells = { path, rev, + buildSets, doGetKernelCheck, pythonCheckInputs, pythonNativeCheckInputs, @@ -309,7 +303,7 @@ rec { ] ++ (pythonNativeCheckInputs python3.pkgs); buildInputs = with pkgs; [ python3.pkgs.pytest ] ++ (pythonCheckInputs python3.pkgs); - inputsFrom = [ (buildTorchExtension buildSet { inherit path rev doGetKernelCheck; }) ]; + inputsFrom = [ (mkTorchExtension buildSet { inherit path rev doGetKernelCheck; }) ]; env = lib.optionalAttrs rocmSupport { PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" buildSet.torch.rocmArchs; HIP_PATH = pkgs.rocmPackages.clr; @@ -318,5 +312,5 @@ rec { }; }; in - builtins.listToAttrs (lib.map shellForBuildSet (applicableBuildSets path)); + builtins.listToAttrs (lib.map shellForBuildSet buildSets); } diff --git a/lib/gen-flake-outputs.nix b/lib/gen-flake-outputs.nix index 58b8c1be..d7c1a82b 100644 --- a/lib/gen-flake-outputs.nix +++ b/lib/gen-flake-outputs.nix @@ -7,6 +7,7 @@ runCommand, path, + buildSets, rev ? null, self ? null, @@ -37,6 +38,8 @@ let revUnderscored = builtins.replaceStrings [ "-" ] [ "_" ] flakeRev; + applicableBuildSets = build.applicableBuildSets { inherit path buildSets; }; + # For picking a default shell, etc. we want to use the following logic: # # - Prefer bundle builds over non-bundle builds. @@ -46,7 +49,7 @@ let # Enrich the build configs with generic attributes for framework # order/version. Also make bundleBuild attr explicit. - buildSets = map ( + addSortOrder = map ( set: let inherit (set) buildConfig; @@ -57,12 +60,23 @@ let buildConfig // { bundleBuild = buildConfig.bundleBuild or false; + framework = + if buildConfig ? cudaVersion then + "cuda" + else if buildConfig ? rocmVersion then + "rocm" + else if buildConfig ? xpuVersion then + "xpu" + else if system == "aarch64-darwin" then + "metal" + else + throw "Cannot determine framework for build set"; frameworkOrder = if buildConfig ? cudaVersion then 0 else 1; frameworkVersion = buildConfig.cudaVersion or buildConfig.rocmVersion or buildConfig.xpuVersion or "0.0"; }; } - ) (build.applicableBuildSets path); + ); configCompare = setA: setB: let @@ -77,25 +91,27 @@ let builtins.compareVersions a.torchVersion b.torchVersion > 0 else builtins.compareVersions a.frameworkVersion b.frameworkVersion < 0; - buildSetsSorted = lib.sort configCompare buildSets; + buildSetsSorted = lib.sort configCompare (addSortOrder applicableBuildSets); bestBuildSet = if buildSetsSorted == [ ] then throw "No build variant is compatible with this system" else builtins.head buildSetsSorted; shellTorch = buildName bestBuildSet.buildConfig; + headOrEmpty = l: if l == [ ] then [ ] else [ (builtins.head l) ]; in { devShells = rec { default = devShells.${shellTorch}; test = testShells.${shellTorch}; - devShells = build.torchDevShells { + devShells = build.mkTorchDevShells { inherit path doGetKernelCheck pythonCheckInputs pythonNativeCheckInputs ; + buildSets = applicableBuildSets; rev = revUnderscored; }; testShells = build.torchExtensionShells { @@ -105,13 +121,15 @@ in pythonCheckInputs pythonNativeCheckInputs ; + buildSets = applicableBuildSets; rev = revUnderscored; }; }; packages = let - bundle = build.buildTorchExtensionBundle { + bundle = build.mkTorchExtensionBundle { inherit path doGetKernelCheck; + buildSets = applicableBuildSets; rev = revUnderscored; }; in @@ -140,6 +158,23 @@ in chmod -R +w build ''; + ci = + let + setsWithFramework = + framework: builtins.filter (set: set.buildConfig.framework == framework) buildSetsSorted; + # It is too costly to build all variants in CI, so we just build one per framework. + onePerFramework = + (headOrEmpty (setsWithFramework "cuda")) + ++ (headOrEmpty (setsWithFramework "metal")) + ++ (headOrEmpty (setsWithFramework "rocm")) + ++ (headOrEmpty (setsWithFramework "xpu")); + in + build.mkTorchExtensionBundle { + inherit path doGetKernelCheck; + buildSets = onePerFramework; + rev = revUnderscored; + }; + kernels = bestBuildSet.pkgs.python3.withPackages ( ps: with ps; [ @@ -151,10 +186,11 @@ in meta.mainProgram = "kernels"; }; - redistributable = build.buildDistTorchExtensions { + redistributable = build.mkDistTorchExtensions { inherit path doGetKernelCheck; bundleOnly = false; rev = revUnderscored; + buildSets = applicableBuildSets; }; }; }