
[Feature]: RFC: Pluggable GPU compile backend (Python layer) #253


RFC: Pluggable GPU compile backend (Python layer)

Status: Draft
Author: Peter Han
Created: 2026-03-18

Summary

Introduce a pluggable GPU compile backend abstraction in the Python compiler so that lowering pipelines, gpu.module targets, JIT shared-library resolution, and disk-cache fingerprints are backend-driven instead of hard-coded for ROCm. The built-in rocm backend preserves current behavior (Fly → ROCDL → fatbin, HIP runtime).

Motivation

  1. Multiple GPU stacks — FlyDSL today assumes AMD ROCm / AMDGPU in jit_function.py and jit_executor.py. Adding another stack (e.g. CUDA/NVVM, or another MLIR path) would scatter conditionals and duplicate cache-key logic.
  2. Correct caching — Disk and in-process JIT caches must not mix artifacts across backends or across forced architecture overrides. Keys and native .so globs should reflect the active backend.
  3. Extensions — Downstream projects or optional packages should be able to register a backend before @jit without forking core compiler code.

Non-goals (this change set)

  • Implementing a second production backend (e.g. CUDA); only the hook + ROCm implementation is in scope.
  • Changing C++/MLIR pass registration beyond what the existing ROCm pipeline already uses.

User-visible behavior

Environment

  • FLYDSL_COMPILE_BACKEND (default: rocm): backend id, matched case-insensitively. Unknown values raise ValueError listing the registered ids.
  • FLYDSL_COMPILE_ARCH (default: empty): unchanged; optional override of the target arch string (e.g. gfx942 for ROCm).
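The environment lookup might reduce to a sketch like the following; only the variable name and the rocm default come from this RFC, the implementation details are assumptions:

```python
import os


def compile_backend_name() -> str:
    # Read the backend id from the environment, defaulting to "rocm";
    # matching is case-insensitive, so normalize to lowercase.
    return os.environ.get("FLYDSL_COMPILE_BACKEND", "rocm").strip().lower()


os.environ["FLYDSL_COMPILE_BACKEND"] = "ROCm"
print(compile_backend_name())  # rocm
```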

env.compile.backend mirrors FLYDSL_COMPILE_BACKEND (see flydsl.utils.env.CompileEnvManager).

Public API (flydsl.compiler)

New exports:

  • GpuCompileBackend — protocol describing backend capabilities.
  • RocmBackend — default ROCm / ROCDL implementation.
  • compile_backend_name() — active backend id from env.
  • get_backend(name=None) — resolve backend (explicit name or env).
  • register_backend(name, factory) — register a factory () -> GpuCompileBackend (raises if name already registered).

MlirCompiler.compile

  • Accepts optional arch= as an alias for legacy chip=; mutually exclusive (ValueError if both set).
  • Pass pipeline and default arch come from get_backend() instead of static ROCm-only methods.
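The chip=/arch= aliasing rule can be sketched as a small validation helper; the name resolve_target_arch and the gfx942 fallback here are illustrative, not part of MlirCompiler:

```python
def resolve_target_arch(chip=None, arch=None, default_arch="gfx942"):
    # arch= is an alias for the legacy chip=; passing both is an error.
    if chip is not None and arch is not None:
        raise ValueError("chip= and arch= are mutually exclusive")
    # Fall back to the backend-provided default when neither is given.
    if arch is not None:
        return arch
    if chip is not None:
        return chip
    return default_arch
```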

Design

1. GpuCompileBackend protocol (compiler/backends/base.py)

Backends implement:

  • name (str): stable id (e.g. "rocm").
  • default_arch(): used when FLYDSL_COMPILE_ARCH is unset.
  • pipeline_fragments(arch=...): ordered PassManager.parse fragments for builtin.module.
  • gpu_module_targets(arch=...): targets for create_gpu_module(..., targets=...).
  • cache_key_native_lib_patterns(): globs under flydsl/_mlir/_mlir_libs hashed into the toolchain fingerprint.
  • jit_runtime_lib_basenames(): basenames passed to ExecutionEngine as shared_libs.
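The protocol in compiler/backends/base.py could look like the sketch below; the method names follow the table above, while the exact signatures and return types are guesses:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class GpuCompileBackend(Protocol):
    name: str  # stable backend id, e.g. "rocm"

    def default_arch(self) -> str:
        """Arch used when FLYDSL_COMPILE_ARCH is unset."""
        ...

    def pipeline_fragments(self, arch: str) -> list:
        """Ordered PassManager.parse fragments for builtin.module."""
        ...

    def gpu_module_targets(self, arch: str) -> list:
        """Targets for create_gpu_module(..., targets=...)."""
        ...

    def cache_key_native_lib_patterns(self) -> list:
        """Globs hashed into the toolchain fingerprint."""
        ...

    def jit_runtime_lib_basenames(self) -> list:
        """Shared-library basenames handed to ExecutionEngine."""
        ...
```

Using typing.Protocol keeps backends structurally typed: an extension never has to import a base class from core, it only has to provide the six members.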

2. Registry (compiler/backends/__init__.py)

  • Module import registers "rocm" → RocmBackend.
  • register_backend stores Callable[[], GpuCompileBackend] so each resolution gets a fresh instance (simple, stateless backends today).
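A minimal registry sketch consistent with the two bullets above; _REGISTRY and the error-message wording are assumptions, while factory-per-resolution, the duplicate-id error, and the env fallback come from the RFC:

```python
import os
from typing import Callable

_REGISTRY: dict = {}


def register_backend(name: str, factory: Callable) -> None:
    # Store the factory, not an instance: each resolution builds a fresh,
    # stateless backend object. Registering an existing id is an error.
    key = name.lower()
    if key in _REGISTRY:
        raise ValueError(f"backend {name!r} already registered")
    _REGISTRY[key] = factory


def get_backend(name=None):
    # An explicit name wins; otherwise read FLYDSL_COMPILE_BACKEND.
    key = (name or os.environ.get("FLYDSL_COMPILE_BACKEND", "rocm")).lower()
    try:
        return _REGISTRY[key]()
    except KeyError:
        raise ValueError(
            f"unknown compile backend {key!r}; registered: {sorted(_REGISTRY)}"
        ) from None
```

An extension would call register_backend("mybackend", MyBackend) at import time, before the first @jit compile that names it.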

3. ROCm backend (compiler/backends/rocm.py)

Moves the previous MlirCompiler._pipeline_fragments logic here, including RDNA wave64 handling via is_rdna_arch, and the same ROCDL / gpu-module-to-binary{format=fatbin} sequence.
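The wave-size handling might reduce to a predicate of this shape; the arch prefixes and the rocdl_options helper are illustrative assumptions, not the authoritative RDNA list or the real lowering options:

```python
def is_rdna_arch(arch: str) -> bool:
    # RDNA families (e.g. gfx10xx/gfx11xx) run wave32 natively; the real
    # predicate may cover additional generations.
    return arch.startswith(("gfx10", "gfx11"))


def rocdl_options(arch: str) -> str:
    # Hypothetical: non-RDNA (CDNA/GCN) parts force wave64 when lowering
    # through the ROCDL path.
    wave64 = "true" if not is_rdna_arch(arch) else "false"
    return f"chipset={arch} wave64={wave64}"
```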

4. Cache and executor partitioning

  • _flydsl_key: Inner cache keyed by compile_backend string; native lib patterns come from get_backend(name=...).
  • _jit_function_cache_key: Adds flydsl_jit_compile_ctx:<backend>:<arch_or_empty>.
  • JitFunction call-state key: Leading tuple ("__flydsl_compile__", backend, arch_override).
  • JitFunction._ensure_cache_manager: If (backend, arch) changes in-process, clears manager_key, cache_manager, _mem_cache, _call_state_cache so switching env mid-process does not reuse stale state.
  • _resolve_runtime_libs: lru_cache(maxsize=16) on (compile_backend,); library list from backend.jit_runtime_lib_basenames().
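The key partitioning above can be sketched as follows; the string and tuple layouts mirror the bullets, but helper names and the placeholder library basename are assumptions:

```python
from functools import lru_cache


def jit_compile_ctx(backend: str, arch_override=None) -> str:
    # Component mixed into _jit_function_cache_key.
    return f"flydsl_jit_compile_ctx:{backend}:{arch_override or ''}"


def call_state_prefix(backend: str, arch_override=None) -> tuple:
    # Leading tuple of the JitFunction call-state key.
    return ("__flydsl_compile__", backend, arch_override)


@lru_cache(maxsize=16)
def resolve_runtime_libs(compile_backend: str) -> tuple:
    # Cached per backend id; the real version asks the backend for its
    # jit_runtime_lib_basenames() and locates them on disk. The basename
    # below is a placeholder.
    return (f"lib{compile_backend}_runtime.so",)
```

Because the backend id and arch override appear in every layer (disk key, call-state key, runtime-lib cache), flipping FLYDSL_COMPILE_BACKEND mid-process can never hand one backend another backend's artifacts.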

Alternatives considered

  1. Subclass MlirCompiler per backend — Rejected: JIT, executor, and cache paths would still need a single discovery mechanism; env-based registry is smaller and matches “plugin” use cases.
  2. Single global backend set at import time only — Rejected: documented behavior allows reading env when compiling; in-process invalidation handles backend/arch changes safely.
  3. Encode backend only in _flydsl_key — Partially done, but explicit call-state and manager reset avoids subtle bugs when keys are reused across subsystems.

Migration / compatibility

  • Default remains ROCm; existing users need no env changes.
  • Callers using only chip= on MlirCompiler.compile are unchanged; arch= is additive.
  • Extensions must call register_backend before first compile with that backend id.

Documentation & follow-ups

  • Update docs/installation.rst / docs/api/compiler.rst with FLYDSL_COMPILE_BACKEND and register_backend.
  • Add a short “Writing a compile backend” section (minimal protocol checklist + example stub).
  • When a second backend lands, add CI matrix entry and integration tests.

Open questions

  1. Should register_backend allow replacing an existing id behind a debug flag for experiments?
  2. Should FLYDSL_COMPILE_BACKEND be validated eagerly at import vs lazily at first compile (current: lazy via get_backend)?
  3. Do we want backend-specific runtime env (separate from compile) in the same registry pattern?
