diff --git a/python/flydsl/compiler/__init__.py b/python/flydsl/compiler/__init__.py index 6ebaecc5..eef90c97 100644 --- a/python/flydsl/compiler/__init__.py +++ b/python/flydsl/compiler/__init__.py @@ -1,13 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors +from .backends import BaseBackend, GPUTarget, compile_backend_name, get_backend, register_backend from .jit_argument import JitArgumentRegistry, from_dlpack from .jit_function import jit from .kernel_function import kernel __all__ = [ + "BaseBackend", + "GPUTarget", + "compile_backend_name", "from_dlpack", - "JitArgumentRegistry", + "get_backend", "jit", + "JitArgumentRegistry", "kernel", + "register_backend", ] diff --git a/python/flydsl/compiler/backends/__init__.py b/python/flydsl/compiler/backends/__init__.py new file mode 100644 index 00000000..5e67fd61 --- /dev/null +++ b/python/flydsl/compiler/backends/__init__.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Pluggable GPU compile-backend registry. + +Usage:: + + from flydsl.compiler.backends import get_backend, register_backend + + backend = get_backend() # resolve from FLYDSL_COMPILE_BACKEND (default: rocm) + backend = get_backend("rocm") # explicit + + register_backend("my_hw", MyBackendFactory) # third-party extension +""" + +from __future__ import annotations + +from functools import lru_cache +from typing import Dict, Optional, Type + +from ...utils import env +from .base import BaseBackend, GPUTarget + +_registry: Dict[str, Type[BaseBackend]] = {} + + +def register_backend(name: str, backend_cls: type, *, force: bool = False) -> None: + """Register a backend class under *name* (case-insensitive). + + ``backend_cls`` must be a concrete subclass of ``BaseBackend``. + Raises ``ValueError`` if *name* is already registered, unless + ``force=True`` (useful during development / experimentation). + """ + key = name.lower() + if key in _registry and not force: + raise ValueError( + f"Compile backend '{name}' is already registered. " + f"Use force=True to override (not recommended in production)." + ) + _registry[key] = backend_cls + + +def compile_backend_name() -> str: + """Return the active backend id from env (default ``'rocm'``).""" + return (env.compile.backend or "rocm").lower() + + +@lru_cache(maxsize=4) +def _make_backend(name: str, arch: str) -> BaseBackend: + """Internal: create and cache a backend instance for *(name, arch)*. + + Both *name* and *arch* must already be resolved (non-empty) so that + the ``lru_cache`` key is deterministic and won't become stale when + environment variables change after the first call. + """ + if name not in _registry: + available = ", ".join(sorted(_registry)) or "(none)" + raise ValueError( + f"Unknown compile backend '{name}'. Registered backends: {available}" + ) + backend_cls = _registry[name] + target = backend_cls.make_target(arch) + return backend_cls(target) + + +def get_backend(name: Optional[str] = None, *, arch: str = "") -> BaseBackend: + """Resolve a backend instance. + + *name* defaults to ``FLYDSL_COMPILE_BACKEND`` env var (or ``'rocm'``). + *arch* overrides the auto-detected architecture when non-empty. + """ + if name is None: + name = compile_backend_name() + name = name.lower() + if not arch: + backend_cls = _registry.get(name) + if backend_cls is None: + available = ", ".join(sorted(_registry)) or "(none)" + raise ValueError( + f"Unknown compile backend '{name}'. Registered backends: {available}" + ) + arch = backend_cls.detect_target().arch + return _make_backend(name, arch) + + +# -- auto-register built-in backends ------------------------------------ +from .rocm import RocmBackend # noqa: E402 + +register_backend("rocm", RocmBackend) + +__all__ = [ + "BaseBackend", + "GPUTarget", + "compile_backend_name", + "get_backend", + "register_backend", +] diff --git a/python/flydsl/compiler/backends/base.py b/python/flydsl/compiler/backends/base.py new file mode 100644 index 00000000..b4969a4e --- /dev/null +++ b/python/flydsl/compiler/backends/base.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import List + + +@dataclass(frozen=True) +class GPUTarget: + """Immutable description of a GPU compilation target. + + Modeled after Triton's GPUTarget — carries just enough info for the + compile pipeline and cache-key logic. + """ + + backend: str # e.g. "rocm" + arch: str # e.g. "gfx942", "gfx950" + warp_size: int # 64 for CDNA, 32 for RDNA + + +class BaseBackend(metaclass=ABCMeta): + """Abstract compile-backend interface. + + Each backend provides: + * target detection and arch defaults, + * MLIR pass-pipeline fragments for lowering Fly IR to device binary, + * gpu.module target attributes, + * native-library patterns for toolchain fingerprinting (cache key), + * runtime shared-library basenames for the JIT ExecutionEngine. + """ + + def __init__(self, target: GPUTarget) -> None: + assert self.supports_target(target), ( + f"{type(self).__name__} does not support target {target}" + ) + self.target = target + + # -- target helpers -------------------------------------------------- + + @staticmethod + @abstractmethod + def supports_target(target: GPUTarget) -> bool: + """Return True if this backend can compile for *target*.""" + ... + + @staticmethod + @abstractmethod + def detect_target() -> GPUTarget: + """Auto-detect the current device and return a GPUTarget.""" + ... + + @classmethod + def make_target(cls, arch: str) -> GPUTarget: + """Build a GPUTarget for an explicit *arch* string. + + Defaults to ``detect_target()`` then overrides the arch. + Subclasses should override to compute correct ``warp_size`` etc. + """ + base = cls.detect_target() + return GPUTarget(backend=base.backend, arch=arch, warp_size=base.warp_size) + + # -- compile pipeline ------------------------------------------------ + + @abstractmethod + def pipeline_fragments(self, *, compile_hints: dict) -> List[str]: + """Ordered list of MLIR PassManager.parse fragments. + + ``compile_hints`` carries per-kernel knobs such as ``waves_per_eu`` + and ``maxnreg`` (from ``CompilationContext.get_compile_hints()``). + """ + ... + + @abstractmethod + def gpu_module_targets(self) -> List[str]: + """MLIR target attributes for ``create_gpu_module(..., targets=...)``.""" + ... + + # -- cache / fingerprint --------------------------------------------- + + def hash(self) -> str: + """Return a string uniquely identifying this backend + target. + + Used as part of the JIT cache key (analogous to Triton's + ``BaseBackend.hash()``). Subclasses may override to include + toolchain version info (e.g. ptxas version). + """ + return f"{self.target}" + + @abstractmethod + def native_lib_patterns(self) -> List[str]: + """Glob patterns (relative to ``_mlir/_mlir_libs/``) whose content + is hashed into the toolchain fingerprint ``_flydsl_key``.""" + ... + + @abstractmethod + def jit_runtime_lib_basenames(self) -> List[str]: + """Basenames of shared libraries passed to ``ExecutionEngine``.""" + ... diff --git a/python/flydsl/compiler/backends/rocm.py b/python/flydsl/compiler/backends/rocm.py new file mode 100644 index 00000000..b83e074a --- /dev/null +++ b/python/flydsl/compiler/backends/rocm.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +from typing import List + +from ...runtime.device import get_rocm_arch, is_rdna_arch +from ...utils import env +from .base import BaseBackend, GPUTarget + + +class RocmBackend(BaseBackend): + """ROCm / AMDGPU compile backend (HIP runtime, ROCDL lowering).""" + + @staticmethod + def supports_target(target: GPUTarget) -> bool: + return target.backend == "rocm" + + @staticmethod + def detect_target() -> GPUTarget: + arch = env.compile.arch or get_rocm_arch() + warp_size = 32 if is_rdna_arch(arch) else 64 + return GPUTarget(backend="rocm", arch=arch, warp_size=warp_size) + + @classmethod + def make_target(cls, arch: str) -> GPUTarget: + warp_size = 32 if is_rdna_arch(arch) else 64 + return GPUTarget(backend="rocm", arch=arch, warp_size=warp_size) + + # -- compile pipeline ------------------------------------------------ + + def pipeline_fragments(self, *, compile_hints: dict) -> List[str]: + chip = self.target.arch + waves_per_eu = compile_hints.get("waves_per_eu") + maxnreg = compile_hints.get("maxnreg") + + debug_opt = "-g" if env.debug.enable_debug_info else "" + wpe_opt = f" --amdgpu-waves-per-eu={waves_per_eu}" if waves_per_eu else "" + maxnreg_opt = f" --amdgpu-num-vgpr={maxnreg}" if maxnreg else "" + all_opts = f"{debug_opt}{wpe_opt}{maxnreg_opt}".strip() + + wave64 = "false" if is_rdna_arch(chip) else "true" + return [ + "fly-rewrite-func-signature", + "fly-canonicalize", + "fly-layout-lowering", + "convert-fly-to-rocdl", + "canonicalize", + f"gpu.module(convert-scf-to-cf,cse," + f"convert-gpu-to-rocdl{{chipset={chip} index-bitwidth=0 runtime=HIP use-bare-ptr-memref-call-conv=true}})", + f"rocdl-attach-target{{O=2 abi=600 chip={chip} correct-sqrt=true daz=false fast=false features= " + f"finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64={wave64}}}", + "convert-scf-to-cf", + "convert-cf-to-llvm", + "gpu-to-llvm{use-bare-pointers-for-host=true use-bare-pointers-for-kernels=true}", + "convert-arith-to-llvm", + "convert-func-to-llvm", + "reconcile-unrealized-casts", + *( + ["ensure-debug-info-scope-on-llvm-func{emission-kind=LineTablesOnly}"] + if env.debug.enable_debug_info + else [] + ), + f'gpu-module-to-binary{{format=fatbin opts="{all_opts}"}}', + ] + + def gpu_module_targets(self) -> List[str]: + chip = self.target.arch + return [f'#rocdl.target'] + + # -- cache / fingerprint --------------------------------------------- + + def native_lib_patterns(self) -> List[str]: + return [ + "_fly*.so", + "_fly_rocdl*.so", + "libFly*.so", + "libfly_jit_runtime.so", + "libmlir_rocm_runtime.so", + "_mlirRegisterEverything*.so", + ] + + def jit_runtime_lib_basenames(self) -> List[str]: + return [ + "libfly_jit_runtime.so", + "libmlir_c_runner_utils.so", + ] diff --git a/python/flydsl/compiler/jit_executor.py b/python/flydsl/compiler/jit_executor.py index f5e67d23..1d3c5ddd 100644 --- a/python/flydsl/compiler/jit_executor.py +++ b/python/flydsl/compiler/jit_executor.py @@ -14,11 +14,11 @@ @lru_cache(maxsize=1) def _resolve_runtime_libs() -> List[str]: + from .backends import get_backend + + backend = get_backend() mlir_libs_dir = Path(__file__).resolve().parent.parent / "_mlir" / "_mlir_libs" - libs = [ - mlir_libs_dir / "libfly_jit_runtime.so", - mlir_libs_dir / "libmlir_c_runner_utils.so", - ] + libs = [mlir_libs_dir / name for name in backend.jit_runtime_lib_basenames()] for lib in libs: if not lib.exists(): raise FileNotFoundError( diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py index b89db496..6a94ea09 100644 --- a/python/flydsl/compiler/jit_function.py +++ b/python/flydsl/compiler/jit_function.py @@ -17,9 +17,9 @@ from .._mlir.dialects import func from .._mlir.passmanager import PassManager from ..expr.typing import Stream -from ..runtime.device import get_rocm_arch, is_rdna_arch from ..utils import env, log from .ast_rewriter import ASTRewriter +from .backends import get_backend from .jit_argument import convert_to_jit_arguments from .jit_executor import CompiledArtifact from .kernel_function import ( @@ -81,17 +81,10 @@ def _flydsl_key() -> str: contents.append(hashlib.sha256(f.read()).hexdigest()) # 2) Hash native shared libraries (C++ passes, runtime wrappers, bindings). + backend = get_backend() mlir_libs_dir = flydsl_root / "_mlir" / "_mlir_libs" if mlir_libs_dir.is_dir(): - so_patterns = [ - "_fly*.so", - "_fly_rocdl*.so", - "libFly*.so", - "libfly_jit_runtime.so", - "libmlir_rocm_runtime.so", - "_mlirRegisterEverything*.so", - ] - for pattern in so_patterns: + for pattern in backend.native_lib_patterns(): for so_file in sorted(mlir_libs_dir.glob(pattern)): h = hashlib.sha256() with open(so_file, "rb") as f: @@ -102,7 +95,7 @@ def _flydsl_key() -> str: h.update(chunk) contents.append(h.hexdigest()) - key = f"flydsl:{flydsl.__version__}-" + "-".join(contents) + key = f"flydsl:{flydsl.__version__}:{backend.hash()}-" + "-".join(contents) log().debug(f"flydsl_key: {hashlib.sha256(key.encode()).hexdigest()[:16]}") return key @@ -351,50 +344,23 @@ def _sanitize_path_component(s: str) -> str: return _re.sub(r"[^A-Za-z0-9_.-]+", "_", s) if s else "unknown" -class MlirCompiler: - @staticmethod - def _pipeline_fragments(*, chip: str) -> list: - from .kernel_function import CompilationContext - hints = CompilationContext.get_compile_hints() - waves_per_eu = hints.get('waves_per_eu') - maxnreg = hints.get('maxnreg') - - # Build compiler option flags for gpu-module-to-binary - debug_opt = "-g" if env.debug.enable_debug_info else "" - wpe_opt = f" --amdgpu-waves-per-eu={waves_per_eu}" if waves_per_eu else "" - maxnreg_opt = f" --amdgpu-num-vgpr={maxnreg}" if maxnreg else "" - all_opts = f"{debug_opt}{wpe_opt}{maxnreg_opt}".strip() - - wave64 = "false" if is_rdna_arch(chip) else "true" - return [ - "fly-rewrite-func-signature", - "fly-canonicalize", - "fly-layout-lowering", - "convert-fly-to-rocdl", - "canonicalize", - f"gpu.module(convert-scf-to-cf,cse," - f"convert-gpu-to-rocdl{{chipset={chip} index-bitwidth=0 runtime=HIP use-bare-ptr-memref-call-conv=true}})", - f"rocdl-attach-target{{O=2 abi=600 chip={chip} correct-sqrt=true daz=false fast=false features= " - f"finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64={wave64}}}", - "convert-scf-to-cf", - "convert-cf-to-llvm", - "gpu-to-llvm{use-bare-pointers-for-host=true use-bare-pointers-for-kernels=true}", - "convert-arith-to-llvm", - "convert-func-to-llvm", - "reconcile-unrealized-casts", - *((['ensure-debug-info-scope-on-llvm-func{emission-kind=LineTablesOnly}'] if env.debug.enable_debug_info else [])), - f'gpu-module-to-binary{{format=fatbin opts="{all_opts}"}}', - ] +def _pipeline_fragments(backend) -> list: + """Return the MLIR pass-pipeline fragments for *backend*.""" + from .kernel_function import CompilationContext + + hints = CompilationContext.get_compile_hints() + return backend.pipeline_fragments(compile_hints=hints) + +class MlirCompiler: @classmethod - def compile(cls, module: ir.Module, *, chip: str = None, func_name: str = "") -> ir.Module: + def compile(cls, module: ir.Module, *, arch: str = "", func_name: str = "") -> ir.Module: module.operation.verify() - if chip is None: - chip = env.compile.arch or get_rocm_arch() + backend = get_backend(arch=arch) module = ir.Module.parse(module.operation.get_asm(enable_debug_info=env.debug.enable_debug_info)) - fragments = cls._pipeline_fragments(chip=chip) + fragments = _pipeline_fragments(backend) if env.debug.print_origin_ir: log().info(f"Origin IR: \n{module}") @@ -783,8 +749,8 @@ def __call__(self, *args, **kwargs): func_tracker = FuncLocationTracker(self.func) with ir.InsertionPoint(module.body), loc: - chip = env.compile.arch or get_rocm_arch() - gpu_module = create_gpu_module("kernels", targets=[f'#rocdl.target']) + backend = get_backend() + gpu_module = create_gpu_module("kernels", targets=backend.gpu_module_targets()) func_op = func.FuncOp(self.func.__name__, (ir_types, [])) func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() @@ -808,10 +774,10 @@ def __call__(self, *args, **kwargs): original_ir = module.operation.get_asm(enable_debug_info=True) - compiled_module = MlirCompiler.compile(module, chip=chip, func_name=self.func.__name__) + compiled_module = MlirCompiler.compile(module, arch=backend.target.arch, func_name=self.func.__name__) if env.compile.compile_only: - print(f"[flydsl] COMPILE_ONLY=1, compilation succeeded (arch={chip})") + print(f"[flydsl] COMPILE_ONLY=1, compilation succeeded (arch={backend.target.arch})") return None compiled_func = CompiledArtifact( diff --git a/python/flydsl/utils/env.py b/python/flydsl/utils/env.py index 3cb7d60d..fefe7997 100644 --- a/python/flydsl/utils/env.py +++ b/python/flydsl/utils/env.py @@ -224,6 +224,7 @@ class CompileEnvManager(EnvManager): opt_level = OptInt(2, min_value=0, max_value=3, description="Optimization level") compile_only = OptBool(False, env_var="COMPILE_ONLY", description="Only compile without execution, useful for verifying compilation without a GPU") arch = OptStr("", env_var="ARCH", description="Override target GPU architecture (e.g. gfx942, gfx950)") + backend = OptStr("rocm", description="GPU compile backend id (e.g. rocm)") class DebugEnvManager(EnvManager):