Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions lib/torch-extension/arch.nix
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ assert (buildConfig ? xpuVersion) -> xpuSupport;
assert (buildConfig.metal or false) -> stdenv.hostPlatform.isDarwin;

let
extensionName = builtins.replaceStrings [ "-" ] [ "_" ] kernelName;
moduleName = builtins.replaceStrings [ "-" ] [ "_" ] kernelName;

# On Darwin, we need the host's xcrun for `xcrun metal` to compile Metal shaders.
# It's not supported by the nixpkgs shim.
Expand All @@ -80,11 +80,11 @@ let
in

stdenv.mkDerivation (prevAttrs: {
name = "${extensionName}-torch-ext";
name = "${kernelName}-torch-ext";

inherit
doAbiCheck
extensionName
moduleName
nvccThreads
src
;
Expand Down Expand Up @@ -232,15 +232,15 @@ stdenv.mkDerivation (prevAttrs: {
postInstall = ''
(
cd ..
cp -r torch-ext/${extensionName}/* $out/
cp -r torch-ext/${moduleName}/* $out/
)
mv $out/_${extensionName}_*/* $out/
rm -d $out/_${extensionName}_${rev}
mv $out/_${moduleName}_*/* $out/
rm -d $out/_${moduleName}_${rev}

# Set up a compatibility module for older kernels versions, remove when
# the updated kernels has been around for a while.
mkdir $out/${extensionName}
cp ${./compat.py} $out/${extensionName}/__init__.py
mkdir $out/${moduleName}
cp ${./compat.py} $out/${moduleName}/__init__.py
''
+ (lib.optionalString (stripRPath && stdenv.hostPlatform.isLinux)) ''
find $out/ -name '*.so' \
Expand Down
10 changes: 5 additions & 5 deletions lib/torch-extension/no-arch.nix
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
}:

let
extensionName = builtins.replaceStrings [ "-" ] [ "_" ] kernelName;
moduleName = builtins.replaceStrings [ "-" ] [ "_" ] kernelName;
in

stdenv.mkDerivation (prevAttrs: {
name = "${kernelName}-torch-ext";

inherit extensionName src;
inherit moduleName src;

# Add Torch as a dependency, so that devshells for universal kernels
# also get torch as a build input.
Expand All @@ -54,9 +54,9 @@ stdenv.mkDerivation (prevAttrs: {

installPhase = ''
mkdir -p $out
cp -r torch-ext/${extensionName}/* $out/
mkdir $out/${extensionName}
cp ${./compat.py} $out/${extensionName}/__init__.py
cp -r torch-ext/${moduleName}/* $out/
mkdir $out/${moduleName}
cp ${./compat.py} $out/${moduleName}/__init__.py
'';

doInstallCheck = true;
Expand Down
8 changes: 4 additions & 4 deletions pkgs/get-kernel-check/get-kernel-check-hook.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ echo "Sourcing get-kernel-check-hook.sh"
_getKernelCheckHook() {
echo "Checking loading kernel with get_kernel"

if [ -z ${extensionName+x} ]; then
echo "extensionName must be set in derivation"
if [ -z ${moduleName+x} ]; then
echo "moduleName must be set in derivation"
exit 1
fi

echo "Check whether the kernel can be loaded with get-kernel: ${extensionName}"
echo "Check whether the kernel can be loaded with get-kernel: ${moduleName}"

# We strip the full library paths from the extension. Unfortunately,
# in a Nix environment, the library dependencies cannot be found
Expand All @@ -35,7 +35,7 @@ _getKernelCheckHook() {
mkdir -p "${TMPDIR}/build"
ln -s "$out" "${TMPDIR}/build/${BUILD_VARIANT}"

python -c "from pathlib import Path; import kernels; kernels.get_local_kernel(Path('${TMPDIR}'), '${extensionName}')"
python -c "from pathlib import Path; import kernels; kernels.get_local_kernel(Path('${TMPDIR}'), '${moduleName}')"
}

postInstallCheckHooks+=(_getKernelCheckHook)
12 changes: 6 additions & 6 deletions pkgs/kernel-layout-check/kernel-layout-check-hook.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@ echo "Sourcing kernel-layout-check-hook.sh"
kernelLayoutCheckHook() {
echo "Checking kernel layout"

if [ -z ${extensionName+x} ]; then
echo "extensionName must be set in derivation"
if [ -z ${moduleName+x} ]; then
echo "moduleName must be set in derivation"
exit 1
fi

if [ ! -f source/torch-ext/${extensionName}/__init__.py ]; then
echo "Python module at source/torch-ext/${extensionName} must contain __init__.py"
if [ ! -f source/torch-ext/${moduleName}/__init__.py ]; then
echo "Python module at source/torch-ext/${moduleName} must contain __init__.py"
exit 1
fi

# TODO: remove once the old location is removed from kernels.
if [ -e source/torch-ext/${extensionName}/${extensionName} ]; then
echo "Python module at source/torch-ext/${extensionName} must not have ${extensionName} file or directory."
if [ -e source/torch-ext/${moduleName}/${moduleName} ]; then
echo "Python module at source/torch-ext/${moduleName} must not have ${moduleName} file or directory."
exit 1
fi
}
Expand Down
Loading