From b8a27cbb4454add682c51eb296e60a0dce160ba5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 1 Apr 2026 13:25:56 +0000 Subject: [PATCH 1/2] kernels: add a function `system_variants` This function lists all the variants that are available on the current system. --- kernels/src/kernels/variants.py | 246 ++++++++++++++++++++++++-------- kernels/tests/test_variants.py | 62 +++++++- 2 files changed, 243 insertions(+), 65 deletions(-) diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index 8bfdddb8..22d44e33 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -1,3 +1,4 @@ +import itertools import logging import platform import re @@ -16,6 +17,7 @@ XPU, Backend, ROCm, + _backend, _select_backend, parse_backend, ) @@ -28,32 +30,77 @@ @dataclass(unsafe_hash=True) class Torch: - _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"torch(\d+?)(\d+)") + """Versioned Torch framework (arch variants).""" - version: Version | None + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile( + r"torch(\d+?)(\d+)(?:-(cxx11|cxx98))?" 
+ ) + + version: Version + cxx11_abi: bool | None + + @staticmethod + def possible_variants() -> list["Torch"]: + if has_torch: + import torch + + torch_version = parse(torch.__version__) + torch_version = Version(f"{torch_version.major}.{torch_version.minor}") + + os_ = platform.system().lower() + if os_ == "linux": + cxx11_abi = torch.compiled_with_cxx11_abi() + return [ + Torch(version=torch_version, cxx11_abi=cxx11_abi), + Torch(version=torch_version, cxx11_abi=None), + ] + else: + return [Torch(version=torch_version, cxx11_abi=None)] + else: + return [] @property def variant_str(self) -> str: - if self.version is None: - return "torch" - return f"torch{self.version.major}{self.version.minor}" + base = f"torch{self.version.major}{self.version.minor}" + if self.cxx11_abi is None: + return base + return f"{base}-{'cxx11' if self.cxx11_abi else 'cxx98'}" @staticmethod def parse(s: str) -> "Torch": - if s == "torch": - return Torch(version=None) m = Torch._VARIANT_REGEX.fullmatch(s) if not m: raise ValueError(f"Invalid Torch variant string: {s!r}") - return Torch(version=Version(f"{m.group(1)}.{m.group(2)}")) + version = Version(f"{m.group(1)}.{m.group(2)}") + abi_str = m.group(3) + if abi_str is None: + cxx11_abi = None + else: + cxx11_abi = abi_str != "cxx98" + return Torch(version=version, cxx11_abi=cxx11_abi) @dataclass(unsafe_hash=True) class TvmFfi: + """Versioned tvm-ffi framework (arch variants).""" + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"tvm-ffi(\d+?)(\d+)") version: Version + @staticmethod + def possible_variants() -> list["TvmFfi"]: + if has_tvm_ffi: + import tvm_ffi + + tvm_ffi_version = parse(tvm_ffi.__version__) + tvm_ffi_version = Version( + f"{tvm_ffi_version.major}.{tvm_ffi_version.minor}" + ) + return [TvmFfi(version=tvm_ffi_version)] + else: + return [] + @property def variant_str(self) -> str: return f"tvm-ffi{self.version.major}{self.version.minor}" @@ -66,41 +113,60 @@ def parse(s: str) -> "TvmFfi": return 
TvmFfi(version=Version(f"{m.group(1)}.{m.group(2)}")) +@strict +@dataclass(unsafe_hash=True) +class TorchNoarch: + """Versionless Torch framework (noarch variants).""" + + @staticmethod + def possible_variants() -> list["TorchNoarch"]: + if has_torch: + return [TorchNoarch()] + else: + return [] + + @property + def variant_str(self) -> str: + return "torch" + + @strict @dataclass(unsafe_hash=True) class Arch: - """Aarch kernel information.""" + """Arch kernel information.""" backend: Backend platform: str os: str - cxx11_abi: bool | None @property def variant_str(self) -> str: - if self.cxx11_abi is None: - return f"{self.backend.variant_str}-{self.platform}-{self.os}" - else: - return f"{'cxx11' if self.cxx11_abi else 'cxx98'}-{self.backend.variant_str}-{self.platform}-{self.os}" + return f"{self.backend.variant_str}-{self.platform}-{self.os}" + + @staticmethod + def possible_variants() -> list["Arch"]: + cpu = platform.machine() + os = platform.system().lower() + + if os == "darwin": + cpu = "aarch64" if cpu == "arm64" else cpu + elif os == "windows": + cpu = "x86_64" if cpu == "AMD64" else cpu + + backend = _backend() + + return [Arch(backend=backend, platform=cpu, os=os)] @staticmethod def parse(parts: list[str]) -> "Arch": - # Handle Linux with cxx11 marker. - if len(parts) == 4: - # In the future, we want to remove the marker and use cxx11 as - # the default. We check on cxx98 for this reason. 
- cxx11_abi = parts[0] != "cxx98" - parts = parts[1:] - elif len(parts) == 3: - cxx11_abi = None - else: + if len(parts) != 3: raise ValueError(f"Invalid arch variant parts: {parts!r}") backend = parse_backend(parts[0]) platform = parts[1] os = parts[2] - return Arch(backend=backend, platform=platform, os=os, cxx11_abi=cxx11_abi) + return Arch(backend=backend, platform=platform, os=os) @strict @@ -110,6 +176,13 @@ class Noarch: backend_name: str + @staticmethod + def possible_variants() -> list["Noarch"]: + backend = _backend() + noarch_backend_name = "npu" if backend.name == "cann" else backend.name + names = {noarch_backend_name, "universal"} + return [Noarch(backend_name=name) for name in sorted(names)] + @property def variant_str(self) -> str: return self.backend_name @@ -121,37 +194,92 @@ def parse(s: str) -> "Noarch": @strict @dataclass(unsafe_hash=True) -class Variant: - """Kernel build variant.""" +class ArchVariant: + """Arch kernel build variant.""" framework: Torch | TvmFfi - arch: Arch | Noarch + arch: Arch + + @staticmethod + def possible_variants() -> list["ArchVariant"]: + frameworks: list[Torch | TvmFfi] = ( + Torch.possible_variants() + TvmFfi.possible_variants() + ) + archs = Arch.possible_variants() + return [ + ArchVariant(framework=fw, arch=arch) + for fw, arch in itertools.product(frameworks, archs) + ] @property def variant_str(self) -> str: return f"{self.framework.variant_str}-{self.arch.variant_str}" + + +@strict +@dataclass(unsafe_hash=True) +class NoarchVariant: + """Noarch kernel build variant.""" + + framework: TorchNoarch + arch: Noarch + @staticmethod - def parse(variant_str: str) -> "Variant": - parts = variant_str.split("-") - - arch: Arch | Noarch - framework: Torch | TvmFfi - - if parts[0] == "torch": - # noarch: e.g. 
"torch-cpu" - framework = Torch.parse(parts[0]) - arch = Noarch.parse("-".join(parts[1:])) - elif parts[0].startswith("torch"): - framework = Torch.parse(parts[0]) - arch = Arch.parse(parts[1:]) - elif parts[0] == "tvm" and parts[1].startswith("ffi"): - framework = TvmFfi.parse(f"tvm-{parts[1]}") - arch = Arch.parse(parts[2:]) - else: - raise ValueError(f"Unknown framework in variant string: {variant_str!r}") + def possible_variants() -> list["NoarchVariant"]: + frameworks = TorchNoarch.possible_variants() + archs = Noarch.possible_variants() + return [ + NoarchVariant(framework=fw, arch=arch) + for fw, arch in itertools.product(frameworks, archs) + ] + + @property + def variant_str(self) -> str: + return f"{self.framework.variant_str}-{self.arch.variant_str}" + - return Variant(framework=framework, arch=arch) +Variant = ArchVariant | NoarchVariant + + +def system_variants() -> list[Variant]: + """Return all possible build variants for the current system. + + Warning: this function should only be used internally (so don't export + at the top-level) and for informational purposes, such as user + feedback. When loading kernels, etc. rely on what is on disk and + use `parse_variant` + `resolve_variant`, since this uses our + priority order, etc.""" + result: list[Variant] = ( + ArchVariant.possible_variants() + NoarchVariant.possible_variants() + ) + return _sort_variants(result) + + +def parse_variant(variant_str: str) -> Variant: + """Parse a variant string into an ArchVariant or NoarchVariant.""" + parts = variant_str.split("-") + + if parts[0] == "torch": + # noarch: e.g. 
"torch-cpu" + return NoarchVariant( + framework=TorchNoarch(), arch=Noarch.parse("-".join(parts[1:])) + ) + elif parts[0].startswith("torch"): + if len(parts) >= 2 and parts[1] in ("cxx11", "cxx98"): + framework_str = f"{parts[0]}-{parts[1]}" + arch_parts = parts[2:] + else: + framework_str = parts[0] + arch_parts = parts[1:] + return ArchVariant( + framework=Torch.parse(framework_str), arch=Arch.parse(arch_parts) + ) + elif parts[0] == "tvm" and len(parts) >= 2 and parts[1].startswith("ffi"): + return ArchVariant( + framework=TvmFfi.parse(f"tvm-{parts[1]}"), arch=Arch.parse(parts[2:]) + ) + else: + raise ValueError(f"Unknown framework in variant string: {variant_str!r}") def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]: @@ -162,10 +290,10 @@ def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]: item.path.split("/")[-1] for item in tree if isinstance(item, RepoFolder) } - variants = [] + variants: list[Variant] = [] for variant_str in variant_strs: try: - variants.append(Variant.parse(variant_str)) + variants.append(parse_variant(variant_str)) except ValueError: logging.warning( f"Repository {repo_id} (revision: {revision}) contains invalid build variant variant: {variant_str!r}" @@ -181,10 +309,10 @@ def get_variants_local(repo_path: Path) -> list[Variant]: except Exception: return [] - variants = [] + variants: list[Variant] = [] for variant_str in variant_strs: try: - variants.append(Variant.parse(variant_str)) + variants.append(parse_variant(variant_str)) except ValueError: pass return variants @@ -228,7 +356,7 @@ def resolve_variants( if has_tvm_ffi: import tvm_ffi - # Parse Torch version and strip patch/tags. + # Parse tvm-ffi version and strip patch/tags. 
tvm_ffi_version = parse(tvm_ffi.__version__) tvm_ffi_version = Version(f"{tvm_ffi_version.major}.{tvm_ffi_version.minor}") @@ -275,9 +403,9 @@ def _filter_variants( tvm_ffi_version: Version | None, ) -> list[Variant]: """Return only the variants applicable to the current system.""" - result = [] + result: list[Variant] = [] for v in variants: - if isinstance(v.arch, Arch): + if isinstance(v, ArchVariant): # Skip non-matching CPU or OS. if v.arch.platform != cpu or v.arch.os != os: continue @@ -286,7 +414,10 @@ def _filter_variants( if isinstance(v.framework, Torch): if v.framework.version != torch_version: continue - if v.arch.cxx11_abi != torch_cxx11_abi: + if ( + v.framework.cxx11_abi is not None + and v.framework.cxx11_abi != torch_cxx11_abi + ): continue elif isinstance(v.framework, TvmFfi): if v.framework.version != tvm_ffi_version: @@ -302,8 +433,7 @@ def _filter_variants( continue elif v.arch.backend.variant_str != selected_backend.variant_str: continue - else: - assert isinstance(v.arch, Noarch) + elif isinstance(v, NoarchVariant): # Only noarch variants with a matching backend or "universal" # are applicable. noarch_backend_name = ( @@ -330,7 +460,7 @@ def _sort_variants( """ def sort_key(v: Variant) -> tuple: - if isinstance(v.arch, Arch): + if isinstance(v, ArchVariant): framework_order = 0 if isinstance(v.framework, Torch) else 1 if isinstance(v.arch.backend, (CUDA, ROCm, XPU, CANN)): # Order by backend version in reverse (higher is better). 
@@ -339,7 +469,7 @@ def sort_key(v: Variant) -> tuple: backend_order = 0 return (framework_order, backend_order) else: - assert isinstance(v.arch, Noarch) + assert isinstance(v, NoarchVariant) universal_order = 1 if v.arch.backend_name == "universal" else 0 return (2, universal_order) diff --git a/kernels/tests/test_variants.py b/kernels/tests/test_variants.py index b0d17616..cdb88bcc 100644 --- a/kernels/tests/test_variants.py +++ b/kernels/tests/test_variants.py @@ -3,7 +3,13 @@ from packaging.version import Version from kernels.backends import CPU, CUDA, ROCm -from kernels.variants import Variant, _resolve_variant_for_system, get_variants +from kernels.variants import ( + _resolve_variant_for_system, + get_variants, + parse_variant, + resolve_variants, + system_variants, +) VARIANT_STRINGS = [ "torch25-cxx98-cu118-aarch64-linux", @@ -69,6 +75,17 @@ "torch29-cxx11-rocm63-x86_64-linux", "torch29-cxx11-rocm64-x86_64-linux", "torch29-cxx11-xpu20252-x86_64-linux", + "torch29-cpu-aarch64-linux", + "torch29-cpu-x86_64-linux", + "torch29-cu126-aarch64-linux", + "torch29-cu126-x86_64-linux", + "torch29-cu128-aarch64-linux", + "torch29-cu128-x86_64-linux", + "torch29-cu130-aarch64-linux", + "torch29-cu130-x86_64-linux", + "torch29-rocm63-x86_64-linux", + "torch29-rocm64-x86_64-linux", + "torch29-xpu20252-x86_64-linux", "torch29-metal-aarch64-darwin", "torch210-cpu-aarch64-darwin", "torch210-cu128-x86_64-windows", @@ -83,6 +100,17 @@ "torch210-cxx11-rocm70-x86_64-linux", "torch210-cxx11-rocm71-x86_64-linux", "torch210-cxx11-xpu20253-x86_64-linux", + "torch210-cpu-aarch64-linux", + "torch210-cpu-x86_64-linux", + "torch210-cu126-aarch64-linux", + "torch210-cu126-x86_64-linux", + "torch210-cu128-aarch64-linux", + "torch210-cu128-x86_64-linux", + "torch210-cu130-aarch64-linux", + "torch210-cu130-x86_64-linux", + "torch210-rocm70-x86_64-linux", + "torch210-rocm71-x86_64-linux", + "torch210-xpu20253-x86_64-linux", "torch210-metal-aarch64-darwin", 
"torch210-xpu20253-x86_64-windows", ] @@ -91,13 +119,13 @@ @pytest.mark.parametrize("variant_str", VARIANT_STRINGS) def test_arch_variants(variant_str: str): # Roundtrip parse and generate variant string. - assert Variant.parse(variant_str).variant_str == variant_str + assert parse_variant(variant_str).variant_str == variant_str @pytest.mark.parametrize("variant_str", NOARCH_VARIANT_STRINGS) def test_noarch_variants(variant_str: str): # Roundtrip parse and generate variant string. - assert Variant.parse(variant_str).variant_str == variant_str + assert parse_variant(variant_str).variant_str == variant_str def test_get_variants(): @@ -109,7 +137,7 @@ def test_get_variants(): RESOLVE_VARIANTS = [ - Variant.parse(s) + parse_variant(s) for s in [ "torch210-cxx11-cu128-x86_64-linux", "torch210-cxx11-cu126-x86_64-linux", @@ -269,7 +297,7 @@ def test_resolve_no_match(): RESOLVE_VARIANTS_UNIVERSAL = [ - Variant.parse(s) + parse_variant(s) for s in [ "torch210-cxx11-cu128-x86_64-linux", "torch-universal", @@ -309,7 +337,7 @@ def test_resolve_universal_is_last_resort(): def test_resolve_specific_noarch_preferred_over_universal(): # Backend-specific noarch is preferred over universal. 
- variants = [Variant.parse(s) for s in ["torch-universal", "torch-cuda"]] + variants = [parse_variant(s) for s in ["torch-universal", "torch-cuda"]] result = _resolve_variant_for_system( variants=variants, selected_backend=CUDA(Version("12.8")), @@ -324,7 +352,7 @@ def test_resolve_specific_noarch_preferred_over_universal(): RESOLVE_VARIANTS_NO_NOARCH = [ - Variant.parse(s) + parse_variant(s) for s in [ "torch210-cxx11-cu126-x86_64-linux", "torch210-cxx11-cu128-x86_64-linux", @@ -359,3 +387,23 @@ def test_resolve_cuda_no_different_major_no_noarch(): tvm_ffi_version=None, ) assert result == [] + + +def test_possible_variants_roundtrip(): + """Every variant produced by possible_variants() should round-trip through parse.""" + variants = system_variants() + for v in variants: + assert parse_variant(v.variant_str).variant_str == v.variant_str + + +def test_possible_variants_no_duplicates(): + variants = system_variants() + variant_strs = [v.variant_str for v in variants] + assert len(variant_strs) == len(set(variant_strs)) + + +def test_possible_variants_all_resolve(): + """All generated variants should be accepted by resolve_variants.""" + variants = system_variants() + resolved = resolve_variants(variants) + assert set(v.variant_str for v in resolved) == set(v.variant_str for v in variants) From 54351d3640804118d5fb6a6c722233ae2d3ae431 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 1 Apr 2026 14:13:04 +0000 Subject: [PATCH 2/2] Test fixes --- kernels/tests/test_variants.py | 52 ++++++++++++++++------------------ 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/kernels/tests/test_variants.py b/kernels/tests/test_variants.py index cdb88bcc..4673fbe7 100644 --- a/kernels/tests/test_variants.py +++ b/kernels/tests/test_variants.py @@ -26,6 +26,17 @@ "torch29-cxx11-rocm63-x86_64-linux", "torch29-cxx11-rocm64-x86_64-linux", "torch29-cxx11-xpu20252-x86_64-linux", + "torch29-cpu-aarch64-linux", + "torch29-cpu-x86_64-linux", + 
"torch29-cu126-aarch64-linux", + "torch29-cu126-x86_64-linux", + "torch29-cu128-aarch64-linux", + "torch29-cu128-x86_64-linux", + "torch29-cu130-aarch64-linux", + "torch29-cu130-x86_64-linux", + "torch29-rocm63-x86_64-linux", + "torch29-rocm64-x86_64-linux", + "torch29-xpu20252-x86_64-linux", "torch29-metal-aarch64-darwin", "torch210-cpu-aarch64-darwin", "torch210-cu128-x86_64-windows", @@ -40,6 +51,17 @@ "torch210-cxx11-rocm70-x86_64-linux", "torch210-cxx11-rocm71-x86_64-linux", "torch210-cxx11-xpu20253-x86_64-linux", + "torch210-cpu-aarch64-linux", + "torch210-cpu-x86_64-linux", + "torch210-cu126-aarch64-linux", + "torch210-cu126-x86_64-linux", + "torch210-cu128-aarch64-linux", + "torch210-cu128-x86_64-linux", + "torch210-cu130-aarch64-linux", + "torch210-cu130-x86_64-linux", + "torch210-rocm70-x86_64-linux", + "torch210-rocm71-x86_64-linux", + "torch210-xpu20253-x86_64-linux", "torch210-metal-aarch64-darwin", "torch210-xpu20253-x86_64-windows", "tvm-ffi01-cpu-x86_64-linux", @@ -75,17 +97,6 @@ "torch29-cxx11-rocm63-x86_64-linux", "torch29-cxx11-rocm64-x86_64-linux", "torch29-cxx11-xpu20252-x86_64-linux", - "torch29-cpu-aarch64-linux", - "torch29-cpu-x86_64-linux", - "torch29-cu126-aarch64-linux", - "torch29-cu126-x86_64-linux", - "torch29-cu128-aarch64-linux", - "torch29-cu128-x86_64-linux", - "torch29-cu130-aarch64-linux", - "torch29-cu130-x86_64-linux", - "torch29-rocm63-x86_64-linux", - "torch29-rocm64-x86_64-linux", - "torch29-xpu20252-x86_64-linux", "torch29-metal-aarch64-darwin", "torch210-cpu-aarch64-darwin", "torch210-cu128-x86_64-windows", @@ -100,17 +111,6 @@ "torch210-cxx11-rocm70-x86_64-linux", "torch210-cxx11-rocm71-x86_64-linux", "torch210-cxx11-xpu20253-x86_64-linux", - "torch210-cpu-aarch64-linux", - "torch210-cpu-x86_64-linux", - "torch210-cu126-aarch64-linux", - "torch210-cu126-x86_64-linux", - "torch210-cu128-aarch64-linux", - "torch210-cu128-x86_64-linux", - "torch210-cu130-aarch64-linux", - "torch210-cu130-x86_64-linux", - 
"torch210-rocm70-x86_64-linux", - "torch210-rocm71-x86_64-linux", - "torch210-xpu20253-x86_64-linux", "torch210-metal-aarch64-darwin", "torch210-xpu20253-x86_64-windows", ] @@ -389,21 +389,19 @@ def test_resolve_cuda_no_different_major_no_noarch(): assert result == [] -def test_possible_variants_roundtrip(): - """Every variant produced by possible_variants() should round-trip through parse.""" +def test_system_variants_roundtrip(): variants = system_variants() for v in variants: assert parse_variant(v.variant_str).variant_str == v.variant_str -def test_possible_variants_no_duplicates(): +def test_system_variants_no_duplicates(): variants = system_variants() variant_strs = [v.variant_str for v in variants] assert len(variant_strs) == len(set(variant_strs)) -def test_possible_variants_all_resolve(): - """All generated variants should be accepted by resolve_variants.""" +def test_system_variants_all_resolve(): variants = system_variants() resolved = resolve_variants(variants) assert set(v.variant_str for v in resolved) == set(v.variant_str for v in variants)