diff --git a/lib/torch-extension/default.nix b/lib/torch-extension/default.nix index 41e90860..4cc583dc 100644 --- a/lib/torch-extension/default.nix +++ b/lib/torch-extension/default.nix @@ -130,7 +130,13 @@ stdenv.mkDerivation (prevAttrs: { libcusparse ] ) - ++ lib.optionals rocmSupport (with rocmPackages; [ hipsparselt ]) + ++ lib.optionals rocmSupport ( + with rocmPackages; + [ + hipsparselt + rocwmma-devel + ] + ) ++ lib.optionals xpuSupport ([ oneapi-torch-dev onednn-xpu