Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 188 additions & 58 deletions kernels/src/kernels/variants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import logging
import platform
import re
Expand All @@ -16,6 +17,7 @@
XPU,
Backend,
ROCm,
_backend,
_select_backend,
parse_backend,
)
Expand All @@ -28,32 +30,77 @@

@dataclass(unsafe_hash=True)
class Torch:
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"torch(\d+?)(\d+)")
"""Versioned Torch framework (arch variants)."""

version: Version | None
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(
r"torch(\d+?)(\d+)(?:-(cxx11|cxx98))?"
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prepare for variants without an ABI tag (see extra variants in tests).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I think one liner explainer comment would be nice. Bits around (cxx11|cxx98) aren't particularly clear I think.

)

version: Version
cxx11_abi: bool | None

@staticmethod
def possible_variants() -> list["Torch"]:
if has_torch:
import torch

torch_version = parse(torch.__version__)
torch_version = Version(f"{torch_version.major}.{torch_version.minor}")

os_ = platform.system().lower()
if os_ == "linux":
cxx11_abi = torch.compiled_with_cxx11_abi()
return [
Torch(version=torch_version, cxx11_abi=cxx11_abi),
Torch(version=torch_version, cxx11_abi=None),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why include Torch(version=torch_version, cxx11_abi=None) here as well?

]
else:
return [Torch(version=torch_version, cxx11_abi=None)]
else:
return []

@property
def variant_str(self) -> str:
if self.version is None:
return "torch"
return f"torch{self.version.major}{self.version.minor}"
base = f"torch{self.version.major}{self.version.minor}"
if self.cxx11_abi is None:
return base
return f"{base}-{'cxx11' if self.cxx11_abi else 'cxx98'}"

@staticmethod
def parse(s: str) -> "Torch":
if s == "torch":
return Torch(version=None)
m = Torch._VARIANT_REGEX.fullmatch(s)
if not m:
raise ValueError(f"Invalid Torch variant string: {s!r}")
return Torch(version=Version(f"{m.group(1)}.{m.group(2)}"))
version = Version(f"{m.group(1)}.{m.group(2)}")
abi_str = m.group(3)
if abi_str is None:
cxx11_abi = None
else:
cxx11_abi = abi_str != "cxx98"
return Torch(version=version, cxx11_abi=cxx11_abi)


@dataclass(unsafe_hash=True)
class TvmFfi:
"""Versioned tvm-ffi framework (arch variants)."""

_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"tvm-ffi(\d+?)(\d+)")

version: Version

@staticmethod
def possible_variants() -> list["TvmFfi"]:
if has_tvm_ffi:
import tvm_ffi

tvm_ffi_version = parse(tvm_ffi.__version__)
tvm_ffi_version = Version(
f"{tvm_ffi_version.major}.{tvm_ffi_version.minor}"
)
return [TvmFfi(version=tvm_ffi_version)]
else:
return []

@property
def variant_str(self) -> str:
return f"tvm-ffi{self.version.major}{self.version.minor}"
Expand All @@ -66,41 +113,60 @@ def parse(s: str) -> "TvmFfi":
return TvmFfi(version=Version(f"{m.group(1)}.{m.group(2)}"))


@strict
@dataclass(unsafe_hash=True)
class TorchNoarch:
"""Versionless Torch framework (noarch variants)."""

@staticmethod
def possible_variants() -> list["TorchNoarch"]:
if has_torch:
return [TorchNoarch()]
else:
return []

@property
def variant_str(self) -> str:
return "torch"


@strict
@dataclass(unsafe_hash=True)
class Arch:
"""Aarch kernel information."""
"""Arch kernel information."""

backend: Backend
platform: str
os: str
cxx11_abi: bool | None

@property
def variant_str(self) -> str:
if self.cxx11_abi is None:
return f"{self.backend.variant_str}-{self.platform}-{self.os}"
else:
return f"{'cxx11' if self.cxx11_abi else 'cxx98'}-{self.backend.variant_str}-{self.platform}-{self.os}"
return f"{self.backend.variant_str}-{self.platform}-{self.os}"

@staticmethod
def possible_variants() -> list["Arch"]:
cpu = platform.machine()
os = platform.system().lower()

if os == "darwin":
cpu = "aarch64" if cpu == "arm64" else cpu
elif os == "windows":
cpu = "x86_64" if cpu == "AMD64" else cpu

backend = _backend()

return [Arch(backend=backend, platform=cpu, os=os)]

@staticmethod
def parse(parts: list[str]) -> "Arch":
# Handle Linux with cxx11 marker.
if len(parts) == 4:
# In the future, we want to remove the marker and use cxx11 as
# the default. We check on cxx98 for this reason.
cxx11_abi = parts[0] != "cxx98"
parts = parts[1:]
elif len(parts) == 3:
cxx11_abi = None
else:
if len(parts) != 3:
raise ValueError(f"Invalid arch variant parts: {parts!r}")

backend = parse_backend(parts[0])
platform = parts[1]
os = parts[2]

return Arch(backend=backend, platform=platform, os=os, cxx11_abi=cxx11_abi)
return Arch(backend=backend, platform=platform, os=os)


@strict
Expand All @@ -110,6 +176,13 @@ class Noarch:

backend_name: str

@staticmethod
def possible_variants() -> list["Noarch"]:
backend = _backend()
noarch_backend_name = "npu" if backend.name == "cann" else backend.name
names = {noarch_backend_name, "universal"}
return [Noarch(backend_name=name) for name in sorted(names)]

@property
def variant_str(self) -> str:
return self.backend_name
Expand All @@ -121,37 +194,92 @@ def parse(s: str) -> "Noarch":

@strict
@dataclass(unsafe_hash=True)
class Variant:
"""Kernel build variant."""
class ArchVariant:
"""Arch kernel build variant."""

framework: Torch | TvmFfi
arch: Arch | Noarch
arch: Arch

@staticmethod
def possible_variants() -> list["ArchVariant"]:
frameworks: list[Torch | TvmFfi] = (
Torch.possible_variants() + TvmFfi.possible_variants()
)
archs = Arch.possible_variants()
return [
ArchVariant(framework=fw, arch=arch)
for fw, arch in itertools.product(frameworks, archs)
]
Comment on lines +204 to +212
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is so neat! As declarative as it gets.


@property
def variant_str(self) -> str:
return f"{self.framework.variant_str}-{self.arch.variant_str}"


@strict
@dataclass(unsafe_hash=True)
class NoarchVariant:
"""Noarch kernel build variant."""

framework: TorchNoarch
arch: Noarch

@staticmethod
def parse(variant_str: str) -> "Variant":
parts = variant_str.split("-")

arch: Arch | Noarch
framework: Torch | TvmFfi

if parts[0] == "torch":
# noarch: e.g. "torch-cpu"
framework = Torch.parse(parts[0])
arch = Noarch.parse("-".join(parts[1:]))
elif parts[0].startswith("torch"):
framework = Torch.parse(parts[0])
arch = Arch.parse(parts[1:])
elif parts[0] == "tvm" and parts[1].startswith("ffi"):
framework = TvmFfi.parse(f"tvm-{parts[1]}")
arch = Arch.parse(parts[2:])
else:
raise ValueError(f"Unknown framework in variant string: {variant_str!r}")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, we're removing parse() from the variant dataclasses and keeping a central parse_variant() utility?

def possible_variants() -> list["NoarchVariant"]:
frameworks = TorchNoarch.possible_variants()
archs = Noarch.possible_variants()
return [
NoarchVariant(framework=fw, arch=arch)
for fw, arch in itertools.product(frameworks, archs)
]

@property
def variant_str(self) -> str:
return f"{self.framework.variant_str}-{self.arch.variant_str}"


return Variant(framework=framework, arch=arch)
Variant = ArchVariant | NoarchVariant


def system_variants() -> list[Variant]:
"""Return all possible build variants for the current system.

Warning: this function should only be used internally (so don't export
at the top-level) and for informational purposes, such as user
feedback. When loading kernels, etc. rely what is on disk and
use `parse_variant` + `resolve_variant`, since this uses our
priority order, etc."""
result: list[Variant] = (
ArchVariant.possible_variants() + NoarchVariant.possible_variants()
)
return _sort_variants(result)


def parse_variant(variant_str: str) -> Variant:
"""Parse a variant string into an ArchVariant or NoarchVariant."""
parts = variant_str.split("-")

if parts[0] == "torch":
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be a constraint on the length of parts in case of NoarchVariant because ArchVariant will also have parts[0] == "torch"?

# noarch: e.g. "torch-cpu"
return NoarchVariant(
framework=TorchNoarch(), arch=Noarch.parse("-".join(parts[1:]))
)
elif parts[0].startswith("torch"):
if len(parts) >= 2 and parts[1] in ("cxx11", "cxx98"):
framework_str = f"{parts[0]}-{parts[1]}"
arch_parts = parts[2:]
else:
framework_str = parts[0]
arch_parts = parts[1:]
return ArchVariant(
framework=Torch.parse(framework_str), arch=Arch.parse(arch_parts)
)
elif parts[0] == "tvm" and len(parts) >= 2 and parts[1].startswith("ffi"):
return ArchVariant(
framework=TvmFfi.parse(f"tvm-{parts[1]}"), arch=Arch.parse(parts[2:])
)
Comment on lines +262 to +280
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This conditional logic feels a bit rigid, but I guess it's purely heuristics-driven and isn't meant to change? If it changes then there's something wrong.

else:
raise ValueError(f"Unknown framework in variant string: {variant_str!r}")


def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]:
Expand All @@ -162,10 +290,10 @@ def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]:
item.path.split("/")[-1] for item in tree if isinstance(item, RepoFolder)
}

variants = []
variants: list[Variant] = []
for variant_str in variant_strs:
try:
variants.append(Variant.parse(variant_str))
variants.append(parse_variant(variant_str))
except ValueError:
logging.warning(
f"Repository {repo_id} (revision: {revision}) contains invalid build variant variant: {variant_str!r}"
Expand All @@ -181,10 +309,10 @@ def get_variants_local(repo_path: Path) -> list[Variant]:
except Exception:
return []

variants = []
variants: list[Variant] = []
for variant_str in variant_strs:
try:
variants.append(Variant.parse(variant_str))
variants.append(parse_variant(variant_str))
except ValueError:
pass
return variants
Expand Down Expand Up @@ -228,7 +356,7 @@ def resolve_variants(
if has_tvm_ffi:
import tvm_ffi

# Parse Torch version and strip patch/tags.
# Parse tvm-ffi version and strip patch/tags.
tvm_ffi_version = parse(tvm_ffi.__version__)
tvm_ffi_version = Version(f"{tvm_ffi_version.major}.{tvm_ffi_version.minor}")

Expand Down Expand Up @@ -275,9 +403,9 @@ def _filter_variants(
tvm_ffi_version: Version | None,
) -> list[Variant]:
"""Return only the variants applicable to the current system."""
result = []
result: list[Variant] = []
for v in variants:
if isinstance(v.arch, Arch):
if isinstance(v, ArchVariant):
# Skip non-matching CPU or OS.
if v.arch.platform != cpu or v.arch.os != os:
continue
Expand All @@ -286,7 +414,10 @@ def _filter_variants(
if isinstance(v.framework, Torch):
if v.framework.version != torch_version:
continue
if v.arch.cxx11_abi != torch_cxx11_abi:
if (
v.framework.cxx11_abi is not None
and v.framework.cxx11_abi != torch_cxx11_abi
):
continue
elif isinstance(v.framework, TvmFfi):
if v.framework.version != tvm_ffi_version:
Expand All @@ -302,8 +433,7 @@ def _filter_variants(
continue
elif v.arch.backend.variant_str != selected_backend.variant_str:
continue
else:
assert isinstance(v.arch, Noarch)
elif isinstance(v, NoarchVariant):
# Only noarch variants with a matching backend or "universal"
# are applicable.
noarch_backend_name = (
Expand All @@ -330,7 +460,7 @@ def _sort_variants(
"""

def sort_key(v: Variant) -> tuple:
if isinstance(v.arch, Arch):
if isinstance(v, ArchVariant):
framework_order = 0 if isinstance(v.framework, Torch) else 1
if isinstance(v.arch.backend, (CUDA, ROCm, XPU, CANN)):
# Order by backend version in reverse (higher is better).
Expand All @@ -339,7 +469,7 @@ def sort_key(v: Variant) -> tuple:
backend_order = 0
return (framework_order, backend_order)
else:
assert isinstance(v.arch, Noarch)
assert isinstance(v, NoarchVariant)
universal_order = 1 if v.arch.backend_name == "universal" else 0
return (2, universal_order)

Expand Down
Loading
Loading