1+ import itertools
12import logging
23import platform
34import re
1617 XPU ,
1718 Backend ,
1819 ROCm ,
20+ _backend ,
1921 _select_backend ,
2022 parse_backend ,
2123)
2830
2931@dataclass (unsafe_hash = True )
3032class Torch :
31- _VARIANT_REGEX : ClassVar [ re . Pattern ] = re . compile ( r"torch(\d+?)(\d+)" )
33+ """Versioned Torch framework (arch variants)."""
3234
33- version : Version | None
35+ _VARIANT_REGEX : ClassVar [re .Pattern ] = re .compile (
36+ r"torch(\d+?)(\d+)(?:-(cxx11|cxx98))?"
37+ )
38+
39+ version : Version
40+ cxx11_abi : bool | None
41+
42+ @staticmethod
43+ def possible_variants () -> list ["Torch" ]:
44+ if has_torch :
45+ import torch
46+
47+ torch_version = parse (torch .__version__ )
48+ torch_version = Version (f"{ torch_version .major } .{ torch_version .minor } " )
49+
50+ os_ = platform .system ().lower ()
51+ if os_ == "linux" :
52+ cxx11_abi = torch .compiled_with_cxx11_abi ()
53+ return [
54+ Torch (version = torch_version , cxx11_abi = cxx11_abi ),
55+ Torch (version = torch_version , cxx11_abi = None ),
56+ ]
57+ else :
58+ return [Torch (version = torch_version , cxx11_abi = None )]
59+ else :
60+ return []
3461
3562 @property
3663 def variant_str (self ) -> str :
37- if self .version is None :
38- return "torch"
39- return f"torch{ self .version .major } { self .version .minor } "
64+ base = f"torch{ self .version .major } { self .version .minor } "
65+ if self .cxx11_abi is None :
66+ return base
67+ return f"{ base } -{ 'cxx11' if self .cxx11_abi else 'cxx98' } "
4068
4169 @staticmethod
4270 def parse (s : str ) -> "Torch" :
43- if s == "torch" :
44- return Torch (version = None )
4571 m = Torch ._VARIANT_REGEX .fullmatch (s )
4672 if not m :
4773 raise ValueError (f"Invalid Torch variant string: { s !r} " )
48- return Torch (version = Version (f"{ m .group (1 )} .{ m .group (2 )} " ))
74+ version = Version (f"{ m .group (1 )} .{ m .group (2 )} " )
75+ abi_str = m .group (3 )
76+ if abi_str is None :
77+ cxx11_abi = None
78+ else :
79+ cxx11_abi = abi_str != "cxx98"
80+ return Torch (version = version , cxx11_abi = cxx11_abi )
4981
5082
5183@dataclass (unsafe_hash = True )
5284class TvmFfi :
85+ """Versioned tvm-ffi framework (arch variants)."""
86+
5387 _VARIANT_REGEX : ClassVar [re .Pattern ] = re .compile (r"tvm-ffi(\d+?)(\d+)" )
5488
5589 version : Version
5690
91+ @staticmethod
92+ def possible_variants () -> list ["TvmFfi" ]:
93+ if has_tvm_ffi :
94+ import tvm_ffi
95+
96+ tvm_ffi_version = parse (tvm_ffi .__version__ )
97+ tvm_ffi_version = Version (
98+ f"{ tvm_ffi_version .major } .{ tvm_ffi_version .minor } "
99+ )
100+ return [TvmFfi (version = tvm_ffi_version )]
101+ else :
102+ return []
103+
57104 @property
58105 def variant_str (self ) -> str :
59106 return f"tvm-ffi{ self .version .major } { self .version .minor } "
@@ -66,41 +113,60 @@ def parse(s: str) -> "TvmFfi":
66113 return TvmFfi (version = Version (f"{ m .group (1 )} .{ m .group (2 )} " ))
67114
68115
116+ @strict
117+ @dataclass (unsafe_hash = True )
118+ class TorchNoarch :
119+ """Versionless Torch framework (noarch variants)."""
120+
121+ @staticmethod
122+ def possible_variants () -> list ["TorchNoarch" ]:
123+ if has_torch :
124+ return [TorchNoarch ()]
125+ else :
126+ return []
127+
128+ @property
129+ def variant_str (self ) -> str :
130+ return "torch"
131+
132+
69133@strict
70134@dataclass (unsafe_hash = True )
71135class Arch :
72- """Aarch kernel information."""
136+ """Arch kernel information."""
73137
74138 backend : Backend
75139 platform : str
76140 os : str
77- cxx11_abi : bool | None
78141
79142 @property
80143 def variant_str (self ) -> str :
81- if self .cxx11_abi is None :
82- return f"{ self .backend .variant_str } -{ self .platform } -{ self .os } "
83- else :
84- return f"{ 'cxx11' if self .cxx11_abi else 'cxx98' } -{ self .backend .variant_str } -{ self .platform } -{ self .os } "
144+ return f"{ self .backend .variant_str } -{ self .platform } -{ self .os } "
145+
146+ @staticmethod
147+ def possible_variants () -> list ["Arch" ]:
148+ cpu = platform .machine ()
149+ os = platform .system ().lower ()
150+
151+ if os == "darwin" :
152+ cpu = "aarch64" if cpu == "arm64" else cpu
153+ elif os == "windows" :
154+ cpu = "x86_64" if cpu == "AMD64" else cpu
155+
156+ backend = _backend ()
157+
158+ return [Arch (backend = backend , platform = cpu , os = os )]
85159
86160 @staticmethod
87161 def parse (parts : list [str ]) -> "Arch" :
88- # Handle Linux with cxx11 marker.
89- if len (parts ) == 4 :
90- # In the future, we want to remove the marker and use cxx11 as
91- # the default. We check on cxx98 for this reason.
92- cxx11_abi = parts [0 ] != "cxx98"
93- parts = parts [1 :]
94- elif len (parts ) == 3 :
95- cxx11_abi = None
96- else :
162+ if len (parts ) != 3 :
97163 raise ValueError (f"Invalid arch variant parts: { parts !r} " )
98164
99165 backend = parse_backend (parts [0 ])
100166 platform = parts [1 ]
101167 os = parts [2 ]
102168
103- return Arch (backend = backend , platform = platform , os = os , cxx11_abi = cxx11_abi )
169+ return Arch (backend = backend , platform = platform , os = os )
104170
105171
106172@strict
@@ -110,6 +176,13 @@ class Noarch:
110176
111177 backend_name : str
112178
179+ @staticmethod
180+ def possible_variants () -> list ["Noarch" ]:
181+ backend = _backend ()
182+ noarch_backend_name = "npu" if backend .name == "cann" else backend .name
183+ names = {noarch_backend_name , "universal" }
184+ return [Noarch (backend_name = name ) for name in sorted (names )]
185+
113186 @property
114187 def variant_str (self ) -> str :
115188 return self .backend_name
@@ -121,37 +194,92 @@ def parse(s: str) -> "Noarch":
121194
122195@strict
123196@dataclass (unsafe_hash = True )
124- class Variant :
125- """Kernel build variant."""
197+ class ArchVariant :
198+ """Arch kernel build variant."""
126199
127200 framework : Torch | TvmFfi
128- arch : Arch | Noarch
201+ arch : Arch
202+
203+ @staticmethod
204+ def possible_variants () -> list ["ArchVariant" ]:
205+ frameworks : list [Torch | TvmFfi ] = (
206+ Torch .possible_variants () + TvmFfi .possible_variants ()
207+ )
208+ archs = Arch .possible_variants ()
209+ return [
210+ ArchVariant (framework = fw , arch = arch )
211+ for fw , arch in itertools .product (frameworks , archs )
212+ ]
129213
130214 @property
131215 def variant_str (self ) -> str :
132216 return f"{ self .framework .variant_str } -{ self .arch .variant_str } "
133217
218+
219+ @strict
220+ @dataclass (unsafe_hash = True )
221+ class NoarchVariant :
222+ """Noarch kernel build variant."""
223+
224+ framework : TorchNoarch
225+ arch : Noarch
226+
134227 @staticmethod
135- def parse (variant_str : str ) -> "Variant" :
136- parts = variant_str .split ("-" )
137-
138- arch : Arch | Noarch
139- framework : Torch | TvmFfi
140-
141- if parts [0 ] == "torch" :
142- # noarch: e.g. "torch-cpu"
143- framework = Torch .parse (parts [0 ])
144- arch = Noarch .parse ("-" .join (parts [1 :]))
145- elif parts [0 ].startswith ("torch" ):
146- framework = Torch .parse (parts [0 ])
147- arch = Arch .parse (parts [1 :])
148- elif parts [0 ] == "tvm" and parts [1 ].startswith ("ffi" ):
149- framework = TvmFfi .parse (f"tvm-{ parts [1 ]} " )
150- arch = Arch .parse (parts [2 :])
151- else :
152- raise ValueError (f"Unknown framework in variant string: { variant_str !r} " )
228+ def possible_variants () -> list ["NoarchVariant" ]:
229+ frameworks = TorchNoarch .possible_variants ()
230+ archs = Noarch .possible_variants ()
231+ return [
232+ NoarchVariant (framework = fw , arch = arch )
233+ for fw , arch in itertools .product (frameworks , archs )
234+ ]
235+
236+ @property
237+ def variant_str (self ) -> str :
238+ return f"{ self .framework .variant_str } -{ self .arch .variant_str } "
239+
153240
154- return Variant (framework = framework , arch = arch )
241+ Variant = ArchVariant | NoarchVariant
242+
243+
244+ def system_variants () -> list [Variant ]:
245+ """Return all possible build variants for the current system.
246+
247+ Warning: this function should only be used internally (so don't export
248+ at the top-level) and for informational purposes, such as user
249+ feedback. When loading kernels, etc. rely what is on disk and
250+ use `parse_variant` + `resolve_variant`, since this uses our
251+ priority order, etc."""
252+ result : list [Variant ] = (
253+ ArchVariant .possible_variants () + NoarchVariant .possible_variants ()
254+ )
255+ return _sort_variants (result )
256+
257+
258+ def parse_variant (variant_str : str ) -> Variant :
259+ """Parse a variant string into an ArchVariant or NoarchVariant."""
260+ parts = variant_str .split ("-" )
261+
262+ if parts [0 ] == "torch" :
263+ # noarch: e.g. "torch-cpu"
264+ return NoarchVariant (
265+ framework = TorchNoarch (), arch = Noarch .parse ("-" .join (parts [1 :]))
266+ )
267+ elif parts [0 ].startswith ("torch" ):
268+ if len (parts ) >= 2 and parts [1 ] in ("cxx11" , "cxx98" ):
269+ framework_str = f"{ parts [0 ]} -{ parts [1 ]} "
270+ arch_parts = parts [2 :]
271+ else :
272+ framework_str = parts [0 ]
273+ arch_parts = parts [1 :]
274+ return ArchVariant (
275+ framework = Torch .parse (framework_str ), arch = Arch .parse (arch_parts )
276+ )
277+ elif parts [0 ] == "tvm" and len (parts ) >= 2 and parts [1 ].startswith ("ffi" ):
278+ return ArchVariant (
279+ framework = TvmFfi .parse (f"tvm-{ parts [1 ]} " ), arch = Arch .parse (parts [2 :])
280+ )
281+ else :
282+ raise ValueError (f"Unknown framework in variant string: { variant_str !r} " )
155283
156284
157285def get_variants (api : HfApi , * , repo_id : str , revision : str ) -> list [Variant ]:
@@ -162,10 +290,10 @@ def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]:
162290 item .path .split ("/" )[- 1 ] for item in tree if isinstance (item , RepoFolder )
163291 }
164292
165- variants = []
293+ variants : list [ Variant ] = []
166294 for variant_str in variant_strs :
167295 try :
168- variants .append (Variant . parse (variant_str ))
296+ variants .append (parse_variant (variant_str ))
169297 except ValueError :
170298 logging .warning (
171299 f"Repository { repo_id } (revision: { revision } ) contains invalid build variant variant: { variant_str !r} "
@@ -181,10 +309,10 @@ def get_variants_local(repo_path: Path) -> list[Variant]:
181309 except Exception :
182310 return []
183311
184- variants = []
312+ variants : list [ Variant ] = []
185313 for variant_str in variant_strs :
186314 try :
187- variants .append (Variant . parse (variant_str ))
315+ variants .append (parse_variant (variant_str ))
188316 except ValueError :
189317 pass
190318 return variants
@@ -228,7 +356,7 @@ def resolve_variants(
228356 if has_tvm_ffi :
229357 import tvm_ffi
230358
231- # Parse Torch version and strip patch/tags.
359+ # Parse tvm-ffi version and strip patch/tags.
232360 tvm_ffi_version = parse (tvm_ffi .__version__ )
233361 tvm_ffi_version = Version (f"{ tvm_ffi_version .major } .{ tvm_ffi_version .minor } " )
234362
@@ -275,9 +403,9 @@ def _filter_variants(
275403 tvm_ffi_version : Version | None ,
276404) -> list [Variant ]:
277405 """Return only the variants applicable to the current system."""
278- result = []
406+ result : list [ Variant ] = []
279407 for v in variants :
280- if isinstance (v . arch , Arch ):
408+ if isinstance (v , ArchVariant ):
281409 # Skip non-matching CPU or OS.
282410 if v .arch .platform != cpu or v .arch .os != os :
283411 continue
@@ -286,7 +414,10 @@ def _filter_variants(
286414 if isinstance (v .framework , Torch ):
287415 if v .framework .version != torch_version :
288416 continue
289- if v .arch .cxx11_abi != torch_cxx11_abi :
417+ if (
418+ v .framework .cxx11_abi is not None
419+ and v .framework .cxx11_abi != torch_cxx11_abi
420+ ):
290421 continue
291422 elif isinstance (v .framework , TvmFfi ):
292423 if v .framework .version != tvm_ffi_version :
@@ -302,8 +433,7 @@ def _filter_variants(
302433 continue
303434 elif v .arch .backend .variant_str != selected_backend .variant_str :
304435 continue
305- else :
306- assert isinstance (v .arch , Noarch )
436+ elif isinstance (v , NoarchVariant ):
307437 # Only noarch variants with a matching backend or "universal"
308438 # are applicable.
309439 noarch_backend_name = (
@@ -330,7 +460,7 @@ def _sort_variants(
330460 """
331461
332462 def sort_key (v : Variant ) -> tuple :
333- if isinstance (v . arch , Arch ):
463+ if isinstance (v , ArchVariant ):
334464 framework_order = 0 if isinstance (v .framework , Torch ) else 1
335465 if isinstance (v .arch .backend , (CUDA , ROCm , XPU , CANN )):
336466 # Order by backend version in reverse (higher is better).
@@ -339,7 +469,7 @@ def sort_key(v: Variant) -> tuple:
339469 backend_order = 0
340470 return (framework_order , backend_order )
341471 else :
342- assert isinstance (v . arch , Noarch )
472+ assert isinstance (v , NoarchVariant )
343473 universal_order = 1 if v .arch .backend_name == "universal" else 0
344474 return (2 , universal_order )
345475
0 commit comments