Skip to content

Commit c53cde0

Browse files
committed
kernels: add a function system_variants
This function lists all the variants that are available on the current system.
1 parent 3def540 commit c53cde0

2 files changed

Lines changed: 221 additions & 65 deletions

File tree

kernels/src/kernels/variants.py

Lines changed: 188 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import logging
23
import platform
34
import re
@@ -16,6 +17,7 @@
1617
XPU,
1718
Backend,
1819
ROCm,
20+
_backend,
1921
_select_backend,
2022
parse_backend,
2123
)
@@ -28,32 +30,77 @@
2830

2931
@dataclass(unsafe_hash=True)
3032
class Torch:
31-
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"torch(\d+?)(\d+)")
33+
"""Versioned Torch framework (arch variants)."""
3234

33-
version: Version | None
35+
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(
36+
r"torch(\d+?)(\d+)(?:-(cxx11|cxx98))?"
37+
)
38+
39+
version: Version
40+
cxx11_abi: bool | None
41+
42+
@staticmethod
43+
def possible_variants() -> list["Torch"]:
44+
if has_torch:
45+
import torch
46+
47+
torch_version = parse(torch.__version__)
48+
torch_version = Version(f"{torch_version.major}.{torch_version.minor}")
49+
50+
os_ = platform.system().lower()
51+
if os_ == "linux":
52+
cxx11_abi = torch.compiled_with_cxx11_abi()
53+
return [
54+
Torch(version=torch_version, cxx11_abi=cxx11_abi),
55+
Torch(version=torch_version, cxx11_abi=None),
56+
]
57+
else:
58+
return [Torch(version=torch_version, cxx11_abi=None)]
59+
else:
60+
return []
3461

3562
@property
3663
def variant_str(self) -> str:
37-
if self.version is None:
38-
return "torch"
39-
return f"torch{self.version.major}{self.version.minor}"
64+
base = f"torch{self.version.major}{self.version.minor}"
65+
if self.cxx11_abi is None:
66+
return base
67+
return f"{base}-{'cxx11' if self.cxx11_abi else 'cxx98'}"
4068

4169
@staticmethod
4270
def parse(s: str) -> "Torch":
43-
if s == "torch":
44-
return Torch(version=None)
4571
m = Torch._VARIANT_REGEX.fullmatch(s)
4672
if not m:
4773
raise ValueError(f"Invalid Torch variant string: {s!r}")
48-
return Torch(version=Version(f"{m.group(1)}.{m.group(2)}"))
74+
version = Version(f"{m.group(1)}.{m.group(2)}")
75+
abi_str = m.group(3)
76+
if abi_str is None:
77+
cxx11_abi = None
78+
else:
79+
cxx11_abi = abi_str != "cxx98"
80+
return Torch(version=version, cxx11_abi=cxx11_abi)
4981

5082

5183
@dataclass(unsafe_hash=True)
5284
class TvmFfi:
85+
"""Versioned tvm-ffi framework (arch variants)."""
86+
5387
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"tvm-ffi(\d+?)(\d+)")
5488

5589
version: Version
5690

91+
@staticmethod
92+
def possible_variants() -> list["TvmFfi"]:
93+
if has_tvm_ffi:
94+
import tvm_ffi
95+
96+
tvm_ffi_version = parse(tvm_ffi.__version__)
97+
tvm_ffi_version = Version(
98+
f"{tvm_ffi_version.major}.{tvm_ffi_version.minor}"
99+
)
100+
return [TvmFfi(version=tvm_ffi_version)]
101+
else:
102+
return []
103+
57104
@property
58105
def variant_str(self) -> str:
59106
return f"tvm-ffi{self.version.major}{self.version.minor}"
@@ -66,41 +113,60 @@ def parse(s: str) -> "TvmFfi":
66113
return TvmFfi(version=Version(f"{m.group(1)}.{m.group(2)}"))
67114

68115

116+
@strict
117+
@dataclass(unsafe_hash=True)
118+
class TorchNoarch:
119+
"""Versionless Torch framework (noarch variants)."""
120+
121+
@staticmethod
122+
def possible_variants() -> list["TorchNoarch"]:
123+
if has_torch:
124+
return [TorchNoarch()]
125+
else:
126+
return []
127+
128+
@property
129+
def variant_str(self) -> str:
130+
return "torch"
131+
132+
69133
@strict
70134
@dataclass(unsafe_hash=True)
71135
class Arch:
72-
"""Aarch kernel information."""
136+
"""Arch kernel information."""
73137

74138
backend: Backend
75139
platform: str
76140
os: str
77-
cxx11_abi: bool | None
78141

79142
@property
80143
def variant_str(self) -> str:
81-
if self.cxx11_abi is None:
82-
return f"{self.backend.variant_str}-{self.platform}-{self.os}"
83-
else:
84-
return f"{'cxx11' if self.cxx11_abi else 'cxx98'}-{self.backend.variant_str}-{self.platform}-{self.os}"
144+
return f"{self.backend.variant_str}-{self.platform}-{self.os}"
145+
146+
@staticmethod
147+
def possible_variants() -> list["Arch"]:
148+
cpu = platform.machine()
149+
os = platform.system().lower()
150+
151+
if os == "darwin":
152+
cpu = "aarch64" if cpu == "arm64" else cpu
153+
elif os == "windows":
154+
cpu = "x86_64" if cpu == "AMD64" else cpu
155+
156+
backend = _backend()
157+
158+
return [Arch(backend=backend, platform=cpu, os=os)]
85159

86160
@staticmethod
87161
def parse(parts: list[str]) -> "Arch":
88-
# Handle Linux with cxx11 marker.
89-
if len(parts) == 4:
90-
# In the future, we want to remove the marker and use cxx11 as
91-
# the default. We check on cxx98 for this reason.
92-
cxx11_abi = parts[0] != "cxx98"
93-
parts = parts[1:]
94-
elif len(parts) == 3:
95-
cxx11_abi = None
96-
else:
162+
if len(parts) != 3:
97163
raise ValueError(f"Invalid arch variant parts: {parts!r}")
98164

99165
backend = parse_backend(parts[0])
100166
platform = parts[1]
101167
os = parts[2]
102168

103-
return Arch(backend=backend, platform=platform, os=os, cxx11_abi=cxx11_abi)
169+
return Arch(backend=backend, platform=platform, os=os)
104170

105171

106172
@strict
@@ -110,6 +176,13 @@ class Noarch:
110176

111177
backend_name: str
112178

179+
@staticmethod
180+
def possible_variants() -> list["Noarch"]:
181+
backend = _backend()
182+
noarch_backend_name = "npu" if backend.name == "cann" else backend.name
183+
names = {noarch_backend_name, "universal"}
184+
return [Noarch(backend_name=name) for name in sorted(names)]
185+
113186
@property
114187
def variant_str(self) -> str:
115188
return self.backend_name
@@ -121,37 +194,92 @@ def parse(s: str) -> "Noarch":
121194

122195
@strict
123196
@dataclass(unsafe_hash=True)
124-
class Variant:
125-
"""Kernel build variant."""
197+
class ArchVariant:
198+
"""Arch kernel build variant."""
126199

127200
framework: Torch | TvmFfi
128-
arch: Arch | Noarch
201+
arch: Arch
202+
203+
@staticmethod
204+
def possible_variants() -> list["ArchVariant"]:
205+
frameworks: list[Torch | TvmFfi] = (
206+
Torch.possible_variants() + TvmFfi.possible_variants()
207+
)
208+
archs = Arch.possible_variants()
209+
return [
210+
ArchVariant(framework=fw, arch=arch)
211+
for fw, arch in itertools.product(frameworks, archs)
212+
]
129213

130214
@property
131215
def variant_str(self) -> str:
132216
return f"{self.framework.variant_str}-{self.arch.variant_str}"
133217

218+
219+
@strict
220+
@dataclass(unsafe_hash=True)
221+
class NoarchVariant:
222+
"""Noarch kernel build variant."""
223+
224+
framework: TorchNoarch
225+
arch: Noarch
226+
134227
@staticmethod
135-
def parse(variant_str: str) -> "Variant":
136-
parts = variant_str.split("-")
137-
138-
arch: Arch | Noarch
139-
framework: Torch | TvmFfi
140-
141-
if parts[0] == "torch":
142-
# noarch: e.g. "torch-cpu"
143-
framework = Torch.parse(parts[0])
144-
arch = Noarch.parse("-".join(parts[1:]))
145-
elif parts[0].startswith("torch"):
146-
framework = Torch.parse(parts[0])
147-
arch = Arch.parse(parts[1:])
148-
elif parts[0] == "tvm" and parts[1].startswith("ffi"):
149-
framework = TvmFfi.parse(f"tvm-{parts[1]}")
150-
arch = Arch.parse(parts[2:])
151-
else:
152-
raise ValueError(f"Unknown framework in variant string: {variant_str!r}")
228+
def possible_variants() -> list["NoarchVariant"]:
229+
frameworks = TorchNoarch.possible_variants()
230+
archs = Noarch.possible_variants()
231+
return [
232+
NoarchVariant(framework=fw, arch=arch)
233+
for fw, arch in itertools.product(frameworks, archs)
234+
]
235+
236+
@property
237+
def variant_str(self) -> str:
238+
return f"{self.framework.variant_str}-{self.arch.variant_str}"
239+
153240

154-
return Variant(framework=framework, arch=arch)
241+
Variant = ArchVariant | NoarchVariant
242+
243+
244+
def system_variants() -> list[Variant]:
245+
"""Return all possible build variants for the current system.
246+
247+
Warning: this function should only be used internally (so don't export
248+
at the top-level) and for informational purposes, such as user
249+
feedback. When loading kernels, etc. rely what is on disk and
250+
use `parse_variant` + `resolve_variant`, since this uses our
251+
priority order, etc."""
252+
result: list[Variant] = (
253+
ArchVariant.possible_variants() + NoarchVariant.possible_variants()
254+
)
255+
return _sort_variants(result)
256+
257+
258+
def parse_variant(variant_str: str) -> Variant:
259+
"""Parse a variant string into an ArchVariant or NoarchVariant."""
260+
parts = variant_str.split("-")
261+
262+
if parts[0] == "torch":
263+
# noarch: e.g. "torch-cpu"
264+
return NoarchVariant(
265+
framework=TorchNoarch(), arch=Noarch.parse("-".join(parts[1:]))
266+
)
267+
elif parts[0].startswith("torch"):
268+
if len(parts) >= 2 and parts[1] in ("cxx11", "cxx98"):
269+
framework_str = f"{parts[0]}-{parts[1]}"
270+
arch_parts = parts[2:]
271+
else:
272+
framework_str = parts[0]
273+
arch_parts = parts[1:]
274+
return ArchVariant(
275+
framework=Torch.parse(framework_str), arch=Arch.parse(arch_parts)
276+
)
277+
elif parts[0] == "tvm" and len(parts) >= 2 and parts[1].startswith("ffi"):
278+
return ArchVariant(
279+
framework=TvmFfi.parse(f"tvm-{parts[1]}"), arch=Arch.parse(parts[2:])
280+
)
281+
else:
282+
raise ValueError(f"Unknown framework in variant string: {variant_str!r}")
155283

156284

157285
def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]:
@@ -162,10 +290,10 @@ def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]:
162290
item.path.split("/")[-1] for item in tree if isinstance(item, RepoFolder)
163291
}
164292

165-
variants = []
293+
variants: list[Variant] = []
166294
for variant_str in variant_strs:
167295
try:
168-
variants.append(Variant.parse(variant_str))
296+
variants.append(parse_variant(variant_str))
169297
except ValueError:
170298
logging.warning(
171299
f"Repository {repo_id} (revision: {revision}) contains invalid build variant variant: {variant_str!r}"
@@ -181,10 +309,10 @@ def get_variants_local(repo_path: Path) -> list[Variant]:
181309
except Exception:
182310
return []
183311

184-
variants = []
312+
variants: list[Variant] = []
185313
for variant_str in variant_strs:
186314
try:
187-
variants.append(Variant.parse(variant_str))
315+
variants.append(parse_variant(variant_str))
188316
except ValueError:
189317
pass
190318
return variants
@@ -228,7 +356,7 @@ def resolve_variants(
228356
if has_tvm_ffi:
229357
import tvm_ffi
230358

231-
# Parse Torch version and strip patch/tags.
359+
# Parse tvm-ffi version and strip patch/tags.
232360
tvm_ffi_version = parse(tvm_ffi.__version__)
233361
tvm_ffi_version = Version(f"{tvm_ffi_version.major}.{tvm_ffi_version.minor}")
234362

@@ -275,9 +403,9 @@ def _filter_variants(
275403
tvm_ffi_version: Version | None,
276404
) -> list[Variant]:
277405
"""Return only the variants applicable to the current system."""
278-
result = []
406+
result: list[Variant] = []
279407
for v in variants:
280-
if isinstance(v.arch, Arch):
408+
if isinstance(v, ArchVariant):
281409
# Skip non-matching CPU or OS.
282410
if v.arch.platform != cpu or v.arch.os != os:
283411
continue
@@ -286,7 +414,10 @@ def _filter_variants(
286414
if isinstance(v.framework, Torch):
287415
if v.framework.version != torch_version:
288416
continue
289-
if v.arch.cxx11_abi != torch_cxx11_abi:
417+
if (
418+
v.framework.cxx11_abi is not None
419+
and v.framework.cxx11_abi != torch_cxx11_abi
420+
):
290421
continue
291422
elif isinstance(v.framework, TvmFfi):
292423
if v.framework.version != tvm_ffi_version:
@@ -302,8 +433,7 @@ def _filter_variants(
302433
continue
303434
elif v.arch.backend.variant_str != selected_backend.variant_str:
304435
continue
305-
else:
306-
assert isinstance(v.arch, Noarch)
436+
elif isinstance(v, NoarchVariant):
307437
# Only noarch variants with a matching backend or "universal"
308438
# are applicable.
309439
noarch_backend_name = (
@@ -330,7 +460,7 @@ def _sort_variants(
330460
"""
331461

332462
def sort_key(v: Variant) -> tuple:
333-
if isinstance(v.arch, Arch):
463+
if isinstance(v, ArchVariant):
334464
framework_order = 0 if isinstance(v.framework, Torch) else 1
335465
if isinstance(v.arch.backend, (CUDA, ROCm, XPU, CANN)):
336466
# Order by backend version in reverse (higher is better).
@@ -339,7 +469,7 @@ def sort_key(v: Variant) -> tuple:
339469
backend_order = 0
340470
return (framework_order, backend_order)
341471
else:
342-
assert isinstance(v.arch, Noarch)
472+
assert isinstance(v, NoarchVariant)
343473
universal_order = 1 if v.arch.backend_name == "universal" else 0
344474
return (2, universal_order)
345475

0 commit comments

Comments
 (0)