Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/flydsl/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors

from .backends import BaseBackend, GPUTarget, compile_backend_name, get_backend, register_backend
from .jit_argument import JitArgumentRegistry, from_dlpack
from .jit_function import jit
from .kernel_function import kernel

__all__ = [
"BaseBackend",
"GPUTarget",
"compile_backend_name",
"from_dlpack",
"JitArgumentRegistry",
"get_backend",
"jit",
"JitArgumentRegistry",
"kernel",
"register_backend",
]
97 changes: 97 additions & 0 deletions python/flydsl/compiler/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors

"""Pluggable GPU compile-backend registry.

Usage::

from flydsl.compiler.backends import get_backend, register_backend

backend = get_backend() # resolve from FLYDSL_COMPILE_BACKEND (default: rocm)
backend = get_backend("rocm") # explicit

register_backend("my_hw", MyBackendFactory) # third-party extension
"""

from __future__ import annotations

from functools import lru_cache
from typing import Dict, Optional, Type

from ...utils import env
from .base import BaseBackend, GPUTarget

_registry: Dict[str, Type[BaseBackend]] = {}


def register_backend(name: str, backend_cls: type, *, force: bool = False) -> None:
"""Register a backend class under *name* (case-insensitive).

``backend_cls`` must be a concrete subclass of ``BaseBackend``.
Raises ``ValueError`` if *name* is already registered, unless
``force=True`` (useful during development / experimentation).
"""
key = name.lower()
if key in _registry and not force:
raise ValueError(
f"Compile backend '{name}' is already registered. "
f"Use force=True to override (not recommended in production)."
)
_registry[key] = backend_cls


def compile_backend_name() -> str:
"""Return the active backend id from env (default ``'rocm'``)."""
return (env.compile.backend or "rocm").lower()


@lru_cache(maxsize=4)
def _make_backend(name: str, arch: str) -> BaseBackend:
"""Internal: create and cache a backend instance for *(name, arch)*.

Both *name* and *arch* must already be resolved (non-empty) so that
the ``lru_cache`` key is deterministic and won't become stale when
environment variables change after the first call.
"""
if name not in _registry:
available = ", ".join(sorted(_registry)) or "(none)"
raise ValueError(
f"Unknown compile backend '{name}'. Registered backends: {available}"
)
backend_cls = _registry[name]
target = backend_cls.make_target(arch)
return backend_cls(target)


def get_backend(name: Optional[str] = None, *, arch: str = "") -> BaseBackend:
"""Resolve a backend instance.

*name* defaults to ``FLYDSL_COMPILE_BACKEND`` env var (or ``'rocm'``).
*arch* overrides the auto-detected architecture when non-empty.
"""
if name is None:
name = compile_backend_name()
name = name.lower()
if not arch:
backend_cls = _registry.get(name)
if backend_cls is None:
available = ", ".join(sorted(_registry)) or "(none)"
raise ValueError(
f"Unknown compile backend '{name}'. Registered backends: {available}"
)
arch = backend_cls.detect_target().arch
return _make_backend(name, arch)


# -- auto-register built-in backends ------------------------------------
from .rocm import RocmBackend # noqa: E402

register_backend("rocm", RocmBackend)

__all__ = [
"BaseBackend",
"GPUTarget",
"compile_backend_name",
"get_backend",
"register_backend",
]
99 changes: 99 additions & 0 deletions python/flydsl/compiler/backends/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class GPUTarget:
"""Immutable description of a GPU compilation target.

Modeled after Triton's GPUTarget — carries just enough info for the
compile pipeline and cache-key logic.
"""

backend: str # e.g. "rocm"
arch: str # e.g. "gfx942", "gfx950"
warp_size: int # 64 for CDNA, 32 for RDNA


class BaseBackend(metaclass=ABCMeta):
"""Abstract compile-backend interface.

Each backend provides:
* target detection and arch defaults,
* MLIR pass-pipeline fragments for lowering Fly IR to device binary,
* gpu.module target attributes,
* native-library patterns for toolchain fingerprinting (cache key),
* runtime shared-library basenames for the JIT ExecutionEngine.
"""

def __init__(self, target: GPUTarget) -> None:
assert self.supports_target(target), (
f"{type(self).__name__} does not support target {target}"
)
self.target = target

# -- target helpers --------------------------------------------------

@staticmethod
@abstractmethod
def supports_target(target: GPUTarget) -> bool:
"""Return True if this backend can compile for *target*."""
...

@staticmethod
@abstractmethod
def detect_target() -> GPUTarget:
"""Auto-detect the current device and return a GPUTarget."""
...

@classmethod
def make_target(cls, arch: str) -> GPUTarget:
"""Build a GPUTarget for an explicit *arch* string.

Defaults to ``detect_target()`` then overrides the arch.
Subclasses should override to compute correct ``warp_size`` etc.
"""
base = cls.detect_target()
return GPUTarget(backend=base.backend, arch=arch, warp_size=base.warp_size)

# -- compile pipeline ------------------------------------------------

@abstractmethod
def pipeline_fragments(self, *, compile_hints: dict) -> List[str]:
"""Ordered list of MLIR PassManager.parse fragments.

``compile_hints`` carries per-kernel knobs such as ``waves_per_eu``
and ``maxnreg`` (from ``CompilationContext.get_compile_hints()``).
"""
...

@abstractmethod
def gpu_module_targets(self) -> List[str]:
"""MLIR target attributes for ``create_gpu_module(..., targets=...)``."""
...

# -- cache / fingerprint ---------------------------------------------

def hash(self) -> str:
"""Return a string uniquely identifying this backend + target.

Used as part of the JIT cache key (analogous to Triton's
``BaseBackend.hash()``). Subclasses may override to include
toolchain version info (e.g. ptxas version).
"""
return f"{self.target}"

@abstractmethod
def native_lib_patterns(self) -> List[str]:
"""Glob patterns (relative to ``_mlir/_mlir_libs/``) whose content
is hashed into the toolchain fingerprint ``_flydsl_key``."""
...

@abstractmethod
def jit_runtime_lib_basenames(self) -> List[str]:
"""Basenames of shared libraries passed to ``ExecutionEngine``."""
...
86 changes: 86 additions & 0 deletions python/flydsl/compiler/backends/rocm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors

from typing import List

from ...runtime.device import get_rocm_arch, is_rdna_arch
from ...utils import env
from .base import BaseBackend, GPUTarget


class RocmBackend(BaseBackend):
"""ROCm / AMDGPU compile backend (HIP runtime, ROCDL lowering)."""

@staticmethod
def supports_target(target: GPUTarget) -> bool:
return target.backend == "rocm"

@staticmethod
def detect_target() -> GPUTarget:
arch = env.compile.arch or get_rocm_arch()
warp_size = 32 if is_rdna_arch(arch) else 64
return GPUTarget(backend="rocm", arch=arch, warp_size=warp_size)

@classmethod
def make_target(cls, arch: str) -> GPUTarget:
warp_size = 32 if is_rdna_arch(arch) else 64
return GPUTarget(backend="rocm", arch=arch, warp_size=warp_size)

# -- compile pipeline ------------------------------------------------

def pipeline_fragments(self, *, compile_hints: dict) -> List[str]:
chip = self.target.arch
waves_per_eu = compile_hints.get("waves_per_eu")
maxnreg = compile_hints.get("maxnreg")

debug_opt = "-g" if env.debug.enable_debug_info else ""
wpe_opt = f" --amdgpu-waves-per-eu={waves_per_eu}" if waves_per_eu else ""
maxnreg_opt = f" --amdgpu-num-vgpr={maxnreg}" if maxnreg else ""
all_opts = f"{debug_opt}{wpe_opt}{maxnreg_opt}".strip()

wave64 = "false" if is_rdna_arch(chip) else "true"
return [
"fly-rewrite-func-signature",
"fly-canonicalize",
"fly-layout-lowering",
"convert-fly-to-rocdl",
"canonicalize",
f"gpu.module(convert-scf-to-cf,cse,"
f"convert-gpu-to-rocdl{{chipset={chip} index-bitwidth=0 runtime=HIP use-bare-ptr-memref-call-conv=true}})",
f"rocdl-attach-target{{O=2 abi=600 chip={chip} correct-sqrt=true daz=false fast=false features= "
f"finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64={wave64}}}",
"convert-scf-to-cf",
"convert-cf-to-llvm",
"gpu-to-llvm{use-bare-pointers-for-host=true use-bare-pointers-for-kernels=true}",
"convert-arith-to-llvm",
"convert-func-to-llvm",
"reconcile-unrealized-casts",
*(
["ensure-debug-info-scope-on-llvm-func{emission-kind=LineTablesOnly}"]
if env.debug.enable_debug_info
else []
),
f'gpu-module-to-binary{{format=fatbin opts="{all_opts}"}}',
]

def gpu_module_targets(self) -> List[str]:
chip = self.target.arch
return [f'#rocdl.target<chip = "{chip}">']

# -- cache / fingerprint ---------------------------------------------

def native_lib_patterns(self) -> List[str]:
return [
"_fly*.so",
"_fly_rocdl*.so",
"libFly*.so",
"libfly_jit_runtime.so",
"libmlir_rocm_runtime.so",
"_mlirRegisterEverything*.so",
]

def jit_runtime_lib_basenames(self) -> List[str]:
return [
"libfly_jit_runtime.so",
"libmlir_c_runner_utils.so",
]
8 changes: 4 additions & 4 deletions python/flydsl/compiler/jit_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

@lru_cache(maxsize=1)
def _resolve_runtime_libs() -> List[str]:
from .backends import get_backend

backend = get_backend()
mlir_libs_dir = Path(__file__).resolve().parent.parent / "_mlir" / "_mlir_libs"
libs = [
mlir_libs_dir / "libfly_jit_runtime.so",
mlir_libs_dir / "libmlir_c_runner_utils.so",
]
libs = [mlir_libs_dir / name for name in backend.jit_runtime_lib_basenames()]
for lib in libs:
if not lib.exists():
raise FileNotFoundError(
Expand Down
Loading
Loading