-
Notifications
You must be signed in to change notification settings - Fork 60
kernels: add a function system_variants
#423
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| import itertools | ||
| import logging | ||
| import platform | ||
| import re | ||
|
|
@@ -16,6 +17,7 @@ | |
| XPU, | ||
| Backend, | ||
| ROCm, | ||
| _backend, | ||
| _select_backend, | ||
| parse_backend, | ||
| ) | ||
|
|
@@ -28,32 +30,77 @@ | |
|
|
||
| @dataclass(unsafe_hash=True) | ||
| class Torch: | ||
| _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"torch(\d+?)(\d+)") | ||
| """Versioned Torch framework (arch variants).""" | ||
|
|
||
| version: Version | None | ||
| _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile( | ||
| r"torch(\d+?)(\d+)(?:-(cxx11|cxx98))?" | ||
| ) | ||
|
|
||
| version: Version | ||
| cxx11_abi: bool | None | ||
|
|
||
| @staticmethod | ||
| def possible_variants() -> list["Torch"]: | ||
| if has_torch: | ||
| import torch | ||
|
|
||
| torch_version = parse(torch.__version__) | ||
| torch_version = Version(f"{torch_version.major}.{torch_version.minor}") | ||
|
|
||
| os_ = platform.system().lower() | ||
| if os_ == "linux": | ||
| cxx11_abi = torch.compiled_with_cxx11_abi() | ||
| return [ | ||
| Torch(version=torch_version, cxx11_abi=cxx11_abi), | ||
| Torch(version=torch_version, cxx11_abi=None), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Curious why include |
||
| ] | ||
| else: | ||
| return [Torch(version=torch_version, cxx11_abi=None)] | ||
| else: | ||
| return [] | ||
|
|
||
| @property | ||
| def variant_str(self) -> str: | ||
| if self.version is None: | ||
| return "torch" | ||
| return f"torch{self.version.major}{self.version.minor}" | ||
| base = f"torch{self.version.major}{self.version.minor}" | ||
| if self.cxx11_abi is None: | ||
| return base | ||
| return f"{base}-{'cxx11' if self.cxx11_abi else 'cxx98'}" | ||
|
|
||
| @staticmethod | ||
| def parse(s: str) -> "Torch": | ||
| if s == "torch": | ||
| return Torch(version=None) | ||
| m = Torch._VARIANT_REGEX.fullmatch(s) | ||
| if not m: | ||
| raise ValueError(f"Invalid Torch variant string: {s!r}") | ||
| return Torch(version=Version(f"{m.group(1)}.{m.group(2)}")) | ||
| version = Version(f"{m.group(1)}.{m.group(2)}") | ||
| abi_str = m.group(3) | ||
| if abi_str is None: | ||
| cxx11_abi = None | ||
| else: | ||
| cxx11_abi = abi_str != "cxx98" | ||
| return Torch(version=version, cxx11_abi=cxx11_abi) | ||
|
|
||
|
|
||
| @dataclass(unsafe_hash=True) | ||
| class TvmFfi: | ||
| """Versioned tvm-ffi framework (arch variants).""" | ||
|
|
||
| _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"tvm-ffi(\d+?)(\d+)") | ||
|
|
||
| version: Version | ||
|
|
||
| @staticmethod | ||
| def possible_variants() -> list["TvmFfi"]: | ||
| if has_tvm_ffi: | ||
| import tvm_ffi | ||
|
|
||
| tvm_ffi_version = parse(tvm_ffi.__version__) | ||
| tvm_ffi_version = Version( | ||
| f"{tvm_ffi_version.major}.{tvm_ffi_version.minor}" | ||
| ) | ||
| return [TvmFfi(version=tvm_ffi_version)] | ||
| else: | ||
| return [] | ||
|
|
||
| @property | ||
| def variant_str(self) -> str: | ||
| return f"tvm-ffi{self.version.major}{self.version.minor}" | ||
|
|
@@ -66,41 +113,60 @@ def parse(s: str) -> "TvmFfi": | |
| return TvmFfi(version=Version(f"{m.group(1)}.{m.group(2)}")) | ||
|
|
||
|
|
||
| @strict | ||
| @dataclass(unsafe_hash=True) | ||
| class TorchNoarch: | ||
| """Versionless Torch framework (noarch variants).""" | ||
|
|
||
| @staticmethod | ||
| def possible_variants() -> list["TorchNoarch"]: | ||
| if has_torch: | ||
| return [TorchNoarch()] | ||
| else: | ||
| return [] | ||
|
|
||
| @property | ||
| def variant_str(self) -> str: | ||
| return "torch" | ||
|
|
||
|
|
||
| @strict | ||
| @dataclass(unsafe_hash=True) | ||
| class Arch: | ||
| """Aarch kernel information.""" | ||
| """Arch kernel information.""" | ||
|
|
||
| backend: Backend | ||
| platform: str | ||
| os: str | ||
| cxx11_abi: bool | None | ||
|
|
||
| @property | ||
| def variant_str(self) -> str: | ||
| if self.cxx11_abi is None: | ||
| return f"{self.backend.variant_str}-{self.platform}-{self.os}" | ||
| else: | ||
| return f"{'cxx11' if self.cxx11_abi else 'cxx98'}-{self.backend.variant_str}-{self.platform}-{self.os}" | ||
| return f"{self.backend.variant_str}-{self.platform}-{self.os}" | ||
|
|
||
| @staticmethod | ||
| def possible_variants() -> list["Arch"]: | ||
| cpu = platform.machine() | ||
| os = platform.system().lower() | ||
|
|
||
| if os == "darwin": | ||
| cpu = "aarch64" if cpu == "arm64" else cpu | ||
| elif os == "windows": | ||
| cpu = "x86_64" if cpu == "AMD64" else cpu | ||
|
|
||
| backend = _backend() | ||
|
|
||
| return [Arch(backend=backend, platform=cpu, os=os)] | ||
|
|
||
| @staticmethod | ||
| def parse(parts: list[str]) -> "Arch": | ||
| # Handle Linux with cxx11 marker. | ||
| if len(parts) == 4: | ||
| # In the future, we want to remove the marker and use cxx11 as | ||
| # the default. We check on cxx98 for this reason. | ||
| cxx11_abi = parts[0] != "cxx98" | ||
| parts = parts[1:] | ||
| elif len(parts) == 3: | ||
| cxx11_abi = None | ||
| else: | ||
| if len(parts) != 3: | ||
| raise ValueError(f"Invalid arch variant parts: {parts!r}") | ||
|
|
||
| backend = parse_backend(parts[0]) | ||
| platform = parts[1] | ||
| os = parts[2] | ||
|
|
||
| return Arch(backend=backend, platform=platform, os=os, cxx11_abi=cxx11_abi) | ||
| return Arch(backend=backend, platform=platform, os=os) | ||
|
|
||
|
|
||
| @strict | ||
|
|
@@ -110,6 +176,13 @@ class Noarch: | |
|
|
||
| backend_name: str | ||
|
|
||
| @staticmethod | ||
| def possible_variants() -> list["Noarch"]: | ||
| backend = _backend() | ||
| noarch_backend_name = "npu" if backend.name == "cann" else backend.name | ||
| names = {noarch_backend_name, "universal"} | ||
| return [Noarch(backend_name=name) for name in sorted(names)] | ||
|
|
||
| @property | ||
| def variant_str(self) -> str: | ||
| return self.backend_name | ||
|
|
@@ -121,37 +194,92 @@ def parse(s: str) -> "Noarch": | |
|
|
||
| @strict | ||
| @dataclass(unsafe_hash=True) | ||
| class Variant: | ||
| """Kernel build variant.""" | ||
| class ArchVariant: | ||
| """Arch kernel build variant.""" | ||
|
|
||
| framework: Torch | TvmFfi | ||
| arch: Arch | Noarch | ||
| arch: Arch | ||
|
|
||
| @staticmethod | ||
| def possible_variants() -> list["ArchVariant"]: | ||
| frameworks: list[Torch | TvmFfi] = ( | ||
| Torch.possible_variants() + TvmFfi.possible_variants() | ||
| ) | ||
| archs = Arch.possible_variants() | ||
| return [ | ||
| ArchVariant(framework=fw, arch=arch) | ||
| for fw, arch in itertools.product(frameworks, archs) | ||
| ] | ||
|
Comment on lines
+204
to
+212
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is so neat! As declarative as it gets. |
||
|
|
||
| @property | ||
| def variant_str(self) -> str: | ||
| return f"{self.framework.variant_str}-{self.arch.variant_str}" | ||
|
|
||
|
|
||
| @strict | ||
| @dataclass(unsafe_hash=True) | ||
| class NoarchVariant: | ||
| """Noarch kernel build variant.""" | ||
|
|
||
| framework: TorchNoarch | ||
| arch: Noarch | ||
|
|
||
| @staticmethod | ||
| def parse(variant_str: str) -> "Variant": | ||
| parts = variant_str.split("-") | ||
|
|
||
| arch: Arch | Noarch | ||
| framework: Torch | TvmFfi | ||
|
|
||
| if parts[0] == "torch": | ||
| # noarch: e.g. "torch-cpu" | ||
| framework = Torch.parse(parts[0]) | ||
| arch = Noarch.parse("-".join(parts[1:])) | ||
| elif parts[0].startswith("torch"): | ||
| framework = Torch.parse(parts[0]) | ||
| arch = Arch.parse(parts[1:]) | ||
| elif parts[0] == "tvm" and parts[1].startswith("ffi"): | ||
| framework = TvmFfi.parse(f"tvm-{parts[1]}") | ||
| arch = Arch.parse(parts[2:]) | ||
| else: | ||
| raise ValueError(f"Unknown framework in variant string: {variant_str!r}") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So, we're removing |
||
| def possible_variants() -> list["NoarchVariant"]: | ||
| frameworks = TorchNoarch.possible_variants() | ||
| archs = Noarch.possible_variants() | ||
| return [ | ||
| NoarchVariant(framework=fw, arch=arch) | ||
| for fw, arch in itertools.product(frameworks, archs) | ||
| ] | ||
|
|
||
| @property | ||
| def variant_str(self) -> str: | ||
| return f"{self.framework.variant_str}-{self.arch.variant_str}" | ||
|
|
||
|
|
||
| return Variant(framework=framework, arch=arch) | ||
| Variant = ArchVariant | NoarchVariant | ||
|
|
||
|
|
||
| def system_variants() -> list[Variant]: | ||
| """Return all possible build variants for the current system. | ||
|
|
||
| Warning: this function should only be used internally (so don't export | ||
| at the top-level) and for informational purposes, such as user | ||
| feedback. When loading kernels, etc. rely what is on disk and | ||
| use `parse_variant` + `resolve_variant`, since this uses our | ||
| priority order, etc.""" | ||
| result: list[Variant] = ( | ||
| ArchVariant.possible_variants() + NoarchVariant.possible_variants() | ||
| ) | ||
| return _sort_variants(result) | ||
|
|
||
|
|
||
| def parse_variant(variant_str: str) -> Variant: | ||
| """Parse a variant string into an ArchVariant or NoarchVariant.""" | ||
| parts = variant_str.split("-") | ||
|
|
||
| if parts[0] == "torch": | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should there be a constraint on the length of |
||
| # noarch: e.g. "torch-cpu" | ||
| return NoarchVariant( | ||
| framework=TorchNoarch(), arch=Noarch.parse("-".join(parts[1:])) | ||
| ) | ||
| elif parts[0].startswith("torch"): | ||
| if len(parts) >= 2 and parts[1] in ("cxx11", "cxx98"): | ||
| framework_str = f"{parts[0]}-{parts[1]}" | ||
| arch_parts = parts[2:] | ||
| else: | ||
| framework_str = parts[0] | ||
| arch_parts = parts[1:] | ||
| return ArchVariant( | ||
| framework=Torch.parse(framework_str), arch=Arch.parse(arch_parts) | ||
| ) | ||
| elif parts[0] == "tvm" and len(parts) >= 2 and parts[1].startswith("ffi"): | ||
| return ArchVariant( | ||
| framework=TvmFfi.parse(f"tvm-{parts[1]}"), arch=Arch.parse(parts[2:]) | ||
| ) | ||
|
Comment on lines
+262
to
+280
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This conditional logic feels a bit rigid, but I guess it's purely heuristics-driven and isn't meant to change? If it changes, then there's something wrong. |
||
| else: | ||
| raise ValueError(f"Unknown framework in variant string: {variant_str!r}") | ||
|
|
||
|
|
||
| def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]: | ||
|
|
@@ -162,10 +290,10 @@ def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]: | |
| item.path.split("/")[-1] for item in tree if isinstance(item, RepoFolder) | ||
| } | ||
|
|
||
| variants = [] | ||
| variants: list[Variant] = [] | ||
| for variant_str in variant_strs: | ||
| try: | ||
| variants.append(Variant.parse(variant_str)) | ||
| variants.append(parse_variant(variant_str)) | ||
| except ValueError: | ||
| logging.warning( | ||
| f"Repository {repo_id} (revision: {revision}) contains invalid build variant variant: {variant_str!r}" | ||
|
|
@@ -181,10 +309,10 @@ def get_variants_local(repo_path: Path) -> list[Variant]: | |
| except Exception: | ||
| return [] | ||
|
|
||
| variants = [] | ||
| variants: list[Variant] = [] | ||
| for variant_str in variant_strs: | ||
| try: | ||
| variants.append(Variant.parse(variant_str)) | ||
| variants.append(parse_variant(variant_str)) | ||
| except ValueError: | ||
| pass | ||
| return variants | ||
|
|
@@ -228,7 +356,7 @@ def resolve_variants( | |
| if has_tvm_ffi: | ||
| import tvm_ffi | ||
|
|
||
| # Parse Torch version and strip patch/tags. | ||
| # Parse tvm-ffi version and strip patch/tags. | ||
| tvm_ffi_version = parse(tvm_ffi.__version__) | ||
| tvm_ffi_version = Version(f"{tvm_ffi_version.major}.{tvm_ffi_version.minor}") | ||
|
|
||
|
|
@@ -275,9 +403,9 @@ def _filter_variants( | |
| tvm_ffi_version: Version | None, | ||
| ) -> list[Variant]: | ||
| """Return only the variants applicable to the current system.""" | ||
| result = [] | ||
| result: list[Variant] = [] | ||
| for v in variants: | ||
| if isinstance(v.arch, Arch): | ||
| if isinstance(v, ArchVariant): | ||
| # Skip non-matching CPU or OS. | ||
| if v.arch.platform != cpu or v.arch.os != os: | ||
| continue | ||
|
|
@@ -286,7 +414,10 @@ def _filter_variants( | |
| if isinstance(v.framework, Torch): | ||
| if v.framework.version != torch_version: | ||
| continue | ||
| if v.arch.cxx11_abi != torch_cxx11_abi: | ||
| if ( | ||
| v.framework.cxx11_abi is not None | ||
| and v.framework.cxx11_abi != torch_cxx11_abi | ||
| ): | ||
| continue | ||
| elif isinstance(v.framework, TvmFfi): | ||
| if v.framework.version != tvm_ffi_version: | ||
|
|
@@ -302,8 +433,7 @@ def _filter_variants( | |
| continue | ||
| elif v.arch.backend.variant_str != selected_backend.variant_str: | ||
| continue | ||
| else: | ||
| assert isinstance(v.arch, Noarch) | ||
| elif isinstance(v, NoarchVariant): | ||
| # Only noarch variants with a matching backend or "universal" | ||
| # are applicable. | ||
| noarch_backend_name = ( | ||
|
|
@@ -330,7 +460,7 @@ def _sort_variants( | |
| """ | ||
|
|
||
| def sort_key(v: Variant) -> tuple: | ||
| if isinstance(v.arch, Arch): | ||
| if isinstance(v, ArchVariant): | ||
| framework_order = 0 if isinstance(v.framework, Torch) else 1 | ||
| if isinstance(v.arch.backend, (CUDA, ROCm, XPU, CANN)): | ||
| # Order by backend version in reverse (higher is better). | ||
|
|
@@ -339,7 +469,7 @@ def sort_key(v: Variant) -> tuple: | |
| backend_order = 0 | ||
| return (framework_order, backend_order) | ||
| else: | ||
| assert isinstance(v.arch, Noarch) | ||
| assert isinstance(v, NoarchVariant) | ||
| universal_order = 1 if v.arch.backend_name == "universal" else 0 | ||
| return (2, universal_order) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Prepare for variants without an ABI tag (see extra variants in tests).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I think a one-line explainer comment would be nice. The bits around
`(cxx11|cxx98)` aren't particularly clear, I think.