Refactor: extract pluggable compile-backend system #276
Conversation
Move ROCm-specific compile logic (pipeline fragments, GPU module targets, runtime libs, cache fingerprinting) out of jit_function.py into a pluggable backend registry (flydsl.compiler.backends).

- Add BaseBackend ABC with GPUTarget dataclass (modeled after Triton)
- Add RocmBackend as the concrete implementation
- Add registry with register_backend/get_backend for extensibility
- Eliminate triple backend creation in compile path
- Add BaseBackend.hash() for cache key (Triton-style)
- Add supports_target assertion in BaseBackend.__init__
- Add FLYDSL_COMPILE_BACKEND env config (default: rocm)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
fsx950223
left a comment
Review Summary
Clean refactoring that moves ROCm-specific logic into a pluggable backend registry. The abstraction layer is well-modeled after Triton's GPUTarget/BaseBackend pattern. A few issues to address before merging.
Architecture: Approve
The split is logical:
- BaseBackend ABC defines the right abstraction boundary (pipeline fragments, targets, runtime libs, fingerprinting)
- RocmBackend is a clean 1:1 lift of the existing logic
- Registry pattern (register_backend/get_backend) enables future RDNA or CUDA backends without modifying core code
- GPUTarget dataclass is a good Triton-compatible value type
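For reference, a minimal sketch of how such a registry could look. The names (BaseBackend, register_backend, get_backend, RocmBackend) follow the PR, but the bodies are illustrative, not the actual implementation:

```python
from abc import ABC, abstractmethod

# Module-level table mapping backend names to backend classes.
_BACKENDS: dict = {}

class BaseBackend(ABC):
    @abstractmethod
    def hash(self) -> str:
        """Backend fingerprint folded into the compile cache key."""

def register_backend(name: str, cls: type) -> None:
    _BACKENDS[name] = cls

def get_backend(name: str = "rocm") -> type:
    try:
        return _BACKENDS[name]
    except KeyError:
        raise ValueError(f"unknown compile backend: {name!r}") from None

class RocmBackend(BaseBackend):
    def hash(self) -> str:
        # Illustrative fingerprint; the real one would include arch, etc.
        return "rocm-v1"

# Registering at import time is the side effect the PR's
# flydsl.compiler.backends.__init__ performs via its bottom-of-file import.
register_backend("rocm", RocmBackend)
```

A new backend (e.g. a future CUDA one) would only need to subclass BaseBackend and call register_backend, with no changes to core code.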
Issue 1: _make_backend is @lru_cache but get_backend() arch defaults change at runtime
```python
@lru_cache(maxsize=4)
def _make_backend(name: str, arch: str) -> BaseBackend:
```

When arch="" (the default), detect_target() reads env.compile.arch or calls get_rocm_arch(). But lru_cache caches on (name, arch) — so _make_backend("rocm", "") is cached forever with whatever arch was first detected. If a user later sets FLYDSL_COMPILE_ARCH=gfx950, the cached backend still returns the original arch.
Suggestion: either (a) resolve arch before caching (pass the resolved arch string, never ""), or (b) don't cache when arch="" and let the caller cache explicitly.
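Option (a) could be sketched like this. detect_arch() here is a hypothetical stand-in for the env.compile.arch / get_rocm_arch() resolution in the PR:

```python
import os
from functools import lru_cache

def detect_arch() -> str:
    # Stand-in for env.compile.arch / get_rocm_arch(); illustrative only.
    return os.environ.get("FLYDSL_COMPILE_ARCH", "gfx942")

@lru_cache(maxsize=4)
def _make_backend(name: str, arch: str):
    # The cache key is now always a concrete (name, arch) pair, never "".
    return (name, arch)  # a real implementation would construct a backend

def get_backend(name: str = "rocm", arch: str = ""):
    # Resolve the default *before* the cached call, so changing
    # FLYDSL_COMPILE_ARCH later produces a new cache entry.
    return _make_backend(name, arch or detect_arch())
```

With this, setting FLYDSL_COMPILE_ARCH=gfx950 after the first call simply hits a different cache slot instead of returning the stale backend.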
Issue 2: _flydsl_key() is @lru_cache(maxsize=1) but now depends on get_backend()
```python
@lru_cache(maxsize=1)
def _flydsl_key() -> str:
    ...
    backend = get_backend()
    ...
    key = f"flydsl:{flydsl.__version__}:{backend.hash()}-" + "-".join(contents)
```

_flydsl_key() is cached once and never recomputed. If the backend or arch changes (via env var), the cache key becomes stale. This was a pre-existing issue (the old code also read env inside lru_cache), but adding backend.hash() makes it more visible.
This is acceptable for now (backend/arch don't change within a process in practice), but worth a comment.
Issue 3: _resolve_runtime_libs() in jit_executor.py also uses @lru_cache(maxsize=1)
```python
@lru_cache(maxsize=1)
def _resolve_runtime_libs() -> List[str]:
    backend = get_backend()
    ...
```

Same pattern: if the backend changes mid-process, the runtime lib list is stale. Fine for now, but consider documenting the invariant that the backend is process-global and immutable after first use.
Issue 4: MlirCompiler.compile() signature change: chip → arch
```diff
- def compile(cls, module: ir.Module, *, chip: str = None, func_name: str = "") -> ir.Module:
+ def compile(cls, module: ir.Module, *, arch: str = "", func_name: str = "") -> ir.Module:
```

Good rename, but this is a breaking change for any external callers using MlirCompiler.compile(module, chip="gfx942"). If there are any downstream users, consider keeping chip as a deprecated alias for one release:
```python
def compile(cls, module, *, arch="", chip=None, func_name=""):
    if chip is not None:
        import warnings
        warnings.warn("chip= is deprecated, use arch=", DeprecationWarning)
        arch = arch or chip
    ...
```

If MlirCompiler is purely internal, this is fine as-is.
Issue 5: _pipeline_fragments signature inconsistency
```diff
- def _pipeline_fragments(*, chip: str) -> list:
+ def _pipeline_fragments(backend) -> list:
```

The method keeps its @staticmethod decorator from the old code and now takes backend as a positional argument, which works correctly. But since it uses neither cls nor self and its only caller is compile(), consider making it a plain module-level function instead.
Minor nits
- base.py line 36: the assert for supports_target — in production, asserts can be disabled with -O. Consider if not ...: raise TypeError(...).
- __init__.py uses a bottom-of-file import (from .rocm import RocmBackend), which is fine but could be clearer as a top-level import with a comment explaining the registration side-effect.
- The GPUTarget.warp_size field is defined but never read by any code in this PR. That's fine for the interface, but note it's currently unused.
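The assert-to-TypeError suggestion from the first nit might look like this. The GPUTarget fields and the supports_target signature are assumed from the Triton-style interface described above, not copied from the PR:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class GPUTarget:
    # Fields assumed from the Triton-style interface; illustrative only.
    backend: str
    arch: str
    warp_size: int

class BaseBackend:
    def __init__(self, target: GPUTarget):
        # Explicit check instead of `assert`: still enforced under `python -O`.
        if not self.supports_target(target):
            raise TypeError(
                f"{type(self).__name__} does not support target {target!r}"
            )
        self.target = target

    @staticmethod
    def supports_target(target: GPUTarget) -> bool:
        raise NotImplementedError

class RocmBackend(BaseBackend):
    @staticmethod
    def supports_target(target: GPUTarget) -> bool:
        return target.backend == "hip"
```

Unlike assert, the TypeError carries a message that names the offending backend and target, which also makes misconfiguration easier to diagnose.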
Overall: Good refactoring direction. The main concern is the lru_cache + dynamic arch interaction (Issue 1). The rest are minor.