Suggestion Description
RFC: Pluggable GPU compile backend (Python layer)
| | |
| --- | --- |
| **Status** | Draft |
| **Author** | Peter Han |
| **Created** | 2026-03-18 |
Summary
Introduce a pluggable GPU compile backend abstraction in the Python compiler so that lowering pipelines, `gpu.module` targets, JIT shared-library resolution, and disk-cache fingerprints are backend-driven instead of hard-coded for ROCm. The built-in `rocm` backend preserves current behavior (Fly → ROCDL → fatbin, HIP runtime).
Motivation
- **Multiple GPU stacks** — FlyDSL today assumes AMD ROCm / AMDGPU in `jit_function.py` and `jit_executor.py`. Adding another stack (e.g. CUDA/NVVM, or another MLIR path) would scatter conditionals and duplicate cache-key logic.
- **Correct caching** — Disk and in-process JIT caches must not mix artifacts across backends or across forced architecture overrides. Keys and native `.so` globs should reflect the active backend.
- **Extensibility** — Downstream projects or optional packages should be able to register a backend before `@jit` without forking core compiler code.
Non-goals (this change set)
- Implementing a second production backend (e.g. CUDA); only the hook + ROCm implementation is in scope.
- Changing C++/MLIR pass registration beyond what the existing ROCm pipeline already uses.
User-visible behavior
Environment
| Variable | Default | Meaning |
| --- | --- | --- |
| `FLYDSL_COMPILE_BACKEND` | `rocm` | Backend id (case-insensitive). Unknown values raise `ValueError` with a list of registered ids. |
| `FLYDSL_COMPILE_ARCH` | (empty) | Unchanged: optional override of the target arch string (e.g. `gfx942` for ROCm). |
`env.compile.backend` mirrors `FLYDSL_COMPILE_BACKEND` (see `flydsl.utils.env.CompileEnvManager`).
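The env-driven lookup can be sketched as follows. The helper name `compile_backend_name` comes from this RFC's public API; its body (default, trimming, lowercasing) is an assumption consistent with the table above:

```python
import os

_DEFAULT_BACKEND = "rocm"


def compile_backend_name() -> str:
    """Active backend id from FLYDSL_COMPILE_BACKEND (case-insensitive, default rocm)."""
    raw = os.environ.get("FLYDSL_COMPILE_BACKEND", "")
    return raw.strip().lower() or _DEFAULT_BACKEND
```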
Public API (flydsl.compiler)
New exports:
- `GpuCompileBackend` — protocol describing backend capabilities.
- `RocmBackend` — default ROCm / ROCDL implementation.
- `compile_backend_name()` — active backend id from env.
- `get_backend(name=None)` — resolve backend (explicit `name` or env).
- `register_backend(name, factory)` — register a factory `() -> GpuCompileBackend` (raises if `name` is already registered).
MlirCompiler.compile
- Accepts optional `arch=` as an alias for legacy `chip=`; mutually exclusive (`ValueError` if both set).
- Pass pipeline and default arch come from `get_backend()` instead of static ROCm-only methods.
Design
1. `GpuCompileBackend` protocol (`compiler/backends/base.py`)
Backends implement:
| Method / attribute | Role |
| --- | --- |
| `name: str` | Stable id (e.g. `"rocm"`). |
| `default_arch()` | Used when `FLYDSL_COMPILE_ARCH` is unset. |
| `pipeline_fragments(arch=...)` | Ordered `PassManager.parse` fragments for `builtin.module`. |
| `gpu_module_targets(arch=...)` | Targets for `create_gpu_module(..., targets=...)`. |
| `cache_key_native_lib_patterns()` | Globs under `flydsl/_mlir/_mlir_libs` hashed into the toolchain fingerprint. |
| `jit_runtime_lib_basenames()` | Basenames passed to `ExecutionEngine` as `shared_libs`. |
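The table above maps naturally onto a `typing.Protocol`. A sketch — the member names are from this RFC, while exact signatures (keyword-only `arch`, `Sequence[str]` returns) are assumptions:

```python
from typing import Protocol, Sequence


class GpuCompileBackend(Protocol):
    """Capabilities a compile backend exposes (names from this RFC)."""

    name: str  # stable id, e.g. "rocm"

    def default_arch(self) -> str:
        """Arch used when FLYDSL_COMPILE_ARCH is unset."""
        ...

    def pipeline_fragments(self, *, arch: str) -> Sequence[str]:
        """Ordered PassManager.parse fragments for builtin.module."""
        ...

    def gpu_module_targets(self, *, arch: str) -> Sequence[str]:
        """Targets for create_gpu_module(..., targets=...)."""
        ...

    def cache_key_native_lib_patterns(self) -> Sequence[str]:
        """Globs under flydsl/_mlir/_mlir_libs hashed into the toolchain fingerprint."""
        ...

    def jit_runtime_lib_basenames(self) -> Sequence[str]:
        """Basenames handed to ExecutionEngine as shared_libs."""
        ...
```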
2. Registry (`compiler/backends/__init__.py`)
- Module import registers `rocm` → `RocmBackend`.
- `register_backend` stores a `Callable[[], GpuCompileBackend]` so each resolution gets a fresh instance (simple, stateless backends today).
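A minimal sketch of these registry semantics. The duplicate-id and unknown-id `ValueError` behavior and the factory-per-resolution design come from this RFC; internals and error wording are assumptions:

```python
import os
from typing import Callable, Dict, Optional

# Hypothetical module-level registry: backend id -> factory.
_REGISTRY: Dict[str, Callable[[], object]] = {}


def register_backend(name: str, factory: Callable[[], object]) -> None:
    """Register a backend factory; re-registering an existing id is rejected."""
    key = name.lower()
    if key in _REGISTRY:
        raise ValueError(f"backend {key!r} is already registered")
    _REGISTRY[key] = factory


def get_backend(name: Optional[str] = None) -> object:
    """Resolve by explicit name or FLYDSL_COMPILE_BACKEND; fresh instance each call."""
    key = (name or os.environ.get("FLYDSL_COMPILE_BACKEND", "rocm")).lower()
    try:
        factory = _REGISTRY[key]
    except KeyError:
        raise ValueError(
            f"unknown compile backend {key!r}; registered: {sorted(_REGISTRY)}"
        ) from None
    return factory()
```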
3. ROCm backend (`compiler/backends/rocm.py`)
Moves the previous `MlirCompiler._pipeline_fragments` logic here, including RDNA wave64 handling via `is_rdna_arch`, and the same ROCDL / `gpu-module-to-binary{format=fatbin}` sequence.
4. Cache and executor partitioning
- `_flydsl_key`: inner cache keyed by the `compile_backend` string; native lib patterns come from `get_backend(name=...)`.
- `_jit_function_cache_key`: adds `flydsl_jit_compile_ctx:<backend>:<arch_or_empty>`.
- `JitFunction` call-state key: leading tuple `("__flydsl_compile__", backend, arch_override)`.
- `JitFunction._ensure_cache_manager`: if `(backend, arch)` changes in-process, clears `manager_key`, `cache_manager`, `_mem_cache`, and `_call_state_cache` so switching env mid-process does not reuse stale state.
- `_resolve_runtime_libs`: `lru_cache(maxsize=16)` on `(compile_backend,)`; library list from `backend.jit_runtime_lib_basenames()`.
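The two key shapes above are cheap to derive; a sketch under the assumption that the formats are literal (the helper names here are illustrative, not part of the proposed API):

```python
from typing import Optional, Tuple


def jit_compile_ctx_token(backend: str, arch_override: str = "") -> str:
    """Token mixed into _jit_function_cache_key: flydsl_jit_compile_ctx:<backend>:<arch_or_empty>."""
    return f"flydsl_jit_compile_ctx:{backend}:{arch_override}"


def call_state_prefix(backend: str, arch_override: Optional[str] = None) -> Tuple:
    """Leading tuple of the JitFunction call-state key."""
    return ("__flydsl_compile__", backend, arch_override)
```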
Alternatives considered
- Subclass `MlirCompiler` per backend — Rejected: JIT, executor, and cache paths would still need a single discovery mechanism; an env-based registry is smaller and matches “plugin” use cases.
- Single global backend set at import time only — Rejected: documented behavior allows reading env when compiling; in-process invalidation handles backend/arch changes safely.
- Encode backend only in `_flydsl_key` — Partially done, but explicit call-state and manager reset avoids subtle bugs when keys are reused across subsystems.
Migration / compatibility
- Default remains ROCm; existing users need no env changes.
- Callers using only `chip=` on `MlirCompiler.compile` are unchanged; `arch=` is additive.
- Extensions must call `register_backend` before the first compile with that backend id.
Documentation & follow-ups
- Update `docs/installation.rst` and `docs/api/compiler.rst` with `FLYDSL_COMPILE_BACKEND` and `register_backend`.
Open questions
- Should `register_backend` allow replacing an existing id behind a debug flag for experiments?
- Should `FLYDSL_COMPILE_BACKEND` be validated eagerly at import vs lazily at first compile (current: lazy via `get_backend`)?
- Do we want backend-specific runtime env (separate from compile) in the same registry pattern?