diff --git a/python/flydsl/compiler/backends/__init__.py b/python/flydsl/compiler/backends/__init__.py index 5e67fd61..cca36ff5 100644 --- a/python/flydsl/compiler/backends/__init__.py +++ b/python/flydsl/compiler/backends/__init__.py @@ -68,6 +68,10 @@ def get_backend(name: Optional[str] = None, *, arch: str = "") -> BaseBackend: *name* defaults to ``FLYDSL_COMPILE_BACKEND`` env var (or ``'rocm'``). *arch* overrides the auto-detected architecture when non-empty. + + Compile/runtime pairing (``FLYDSL_COMPILE_BACKEND`` vs ``FLYDSL_RUNTIME_KIND``) + is validated on the JIT path when a kernel is compiled or loaded for execution + (see :func:`flydsl.runtime.device_runtime.get_device_runtime`), not here. """ if name is None: name = compile_backend_name() diff --git a/python/flydsl/compiler/jit_executor.py b/python/flydsl/compiler/jit_executor.py index 1d3c5ddd..934c0ef7 100644 --- a/python/flydsl/compiler/jit_executor.py +++ b/python/flydsl/compiler/jit_executor.py @@ -85,6 +85,11 @@ def _ensure_engine(self): if self._engine is not None: return + # Validate compile backend vs device runtime before loading ExecutionEngine. + from ..runtime.device_runtime import get_device_runtime + + get_device_runtime() + with ir.Context() as ctx: ctx.load_all_available_dialects() self._module = ir.Module.parse(self._ir_text) diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py index 1be37d85..5b3139ba 100644 --- a/python/flydsl/compiler/jit_function.py +++ b/python/flydsl/compiler/jit_function.py @@ -778,6 +778,11 @@ def __call__(self, *args, **kwargs): compiled_module = MlirCompiler.compile(module, arch=backend.target.arch, func_name=self.func.__name__) + # Compile/runtime pairing before COMPILE_ONLY return (no ExecutionEngine on that path). + from ..runtime.device_runtime import get_device_runtime + + get_device_runtime() + if env.compile.compile_only: print(f"[flydsl] COMPILE_ONLY=1, compilation succeeded (arch={backend.target.arch})") return None diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index 061d4d37..2d09d0c8 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -300,6 +300,13 @@ def __fly_values__(self): class Stream: + """Opaque async queue handle for kernel launch. + + Values may be ``None`` (default queue), a raw pointer, or a framework stream + with a ``cuda_stream`` attribute (e.g. PyTorch); the active + :class:`flydsl.runtime.device_runtime.DeviceRuntime` interprets them. + """ + _is_stream_param = True def __init__(self, value=None): diff --git a/python/flydsl/runtime/__init__.py b/python/flydsl/runtime/__init__.py index 450d0f1d..0ffce4ab 100644 --- a/python/flydsl/runtime/__init__.py +++ b/python/flydsl/runtime/__init__.py @@ -1,3 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors +"""FlyDSL runtime: device runtime, GPU detection helpers, etc.""" + +from .device_runtime import ( + COMPILE_BACKEND_TO_RUNTIME_KIND, + Device, + DeviceRuntime, + Event, + RocmDeviceRuntime, + ensure_compile_runtime_compatible, + get_device_runtime, + register_compile_runtime_mapping, + register_device_runtime, +) + +__all__ = [ + "COMPILE_BACKEND_TO_RUNTIME_KIND", + "Device", + "DeviceRuntime", + "Event", + "RocmDeviceRuntime", + "ensure_compile_runtime_compatible", + "get_device_runtime", + "register_compile_runtime_mapping", + "register_device_runtime", +] diff --git a/python/flydsl/runtime/device_runtime/__init__.py b/python/flydsl/runtime/device_runtime/__init__.py new file mode 100644 index 00000000..c6755c4a --- /dev/null +++ b/python/flydsl/runtime/device_runtime/__init__.py @@ -0,0 +1,148 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Per-process GPU device runtime. + +Exactly one :class:`DeviceRuntime` implementation is active per process. +It must match the selected compile backend (e.g. ``rocm`` compile backend ↔ +``rocm`` runtime / HIP). + +Environment: + +* ``FLYDSL_RUNTIME_KIND`` — selects the built-in runtime implementation + (default: ``rocm``). Must agree with ``FLYDSL_COMPILE_BACKEND`` via + :data:`COMPILE_BACKEND_TO_RUNTIME_KIND` (and optional extension mappings). +""" + +from __future__ import annotations + +from typing import Dict, Optional, Type + +from ...utils import env +from .base import Device, DeviceRuntime, Event +from .rocm import RocmDeviceRuntime + +# Compile-backend id -> device-runtime kind (single string namespace). +COMPILE_BACKEND_TO_RUNTIME_KIND: Dict[str, str] = { + "rocm": "rocm", +} + +_EXTRA_MAPPINGS: Dict[str, str] = {} + +_builtin_runtimes: Dict[str, Type[DeviceRuntime]] = { + "rocm": RocmDeviceRuntime, +} + +_runtime_cls_override: Optional[Type[DeviceRuntime]] = None +_instance: Optional[DeviceRuntime] = None + + +def register_compile_runtime_mapping(compile_backend: str, runtime_kind: str) -> None: + """Map a compile-backend id to a device-runtime *kind*. + + Use when a third-party :func:`flydsl.compiler.backends.register_backend` + targets an existing runtime stack (e.g. custom name → ``rocm``). + """ + _EXTRA_MAPPINGS[compile_backend.strip().lower()] = runtime_kind.strip().lower() + + +def register_device_runtime( + cls: Type[DeviceRuntime], + *, + kind: Optional[str] = None, + force: bool = False, +) -> None: + """Register a custom :class:`DeviceRuntime` class for the whole process. + + Must be called before the first :func:`get_device_runtime` if replacing the + default. Raises if a runtime instance already exists, unless ``force=True``. + + If *kind* is set, the class is also registered under that runtime *kind* for + ``FLYDSL_RUNTIME_KIND`` (and should match the compile-backend mapping). + """ + global _runtime_cls_override, _instance + if _instance is not None and not force: + raise RuntimeError( + "Cannot register a device runtime after get_device_runtime() " + "has been called (unless force=True)." + ) + if _runtime_cls_override is not None and not force: + raise ValueError( + "A custom device runtime class is already registered " + "(use force=True to replace)." + ) + _runtime_cls_override = cls + if kind is not None: + _builtin_runtimes[kind.strip().lower()] = cls + + +def _expected_runtime_kind_for_compile_backend(compile_backend_id: str) -> str: + key = compile_backend_id.strip().lower() + if key in _EXTRA_MAPPINGS: + return _EXTRA_MAPPINGS[key] + if key in COMPILE_BACKEND_TO_RUNTIME_KIND: + return COMPILE_BACKEND_TO_RUNTIME_KIND[key] + raise ValueError( + f"No device-runtime kind mapped for compile backend {compile_backend_id!r}. " + "Register a mapping with register_compile_runtime_mapping()." + ) + + +def _resolve_runtime_class() -> Type[DeviceRuntime]: + if _runtime_cls_override is not None: + return _runtime_cls_override + kind = (env.runtime.kind or "rocm").strip().lower() + cls = _builtin_runtimes.get(kind) + if cls is None: + known = ", ".join(sorted(_builtin_runtimes)) or "(none)" + raise ValueError( + f"Unknown FLYDSL_RUNTIME_KIND={kind!r}. Built-in kinds: {known}" + ) + return cls + + +def ensure_compile_runtime_compatible( + compile_backend_id: str, + *, + runtime: Optional[DeviceRuntime] = None, +) -> None: + """Raise if *compile_backend_id* does not match the active runtime kind.""" + expected = _expected_runtime_kind_for_compile_backend(compile_backend_id) + rt = runtime if runtime is not None else get_device_runtime() + if rt.kind != expected: + raise RuntimeError( + f"Compile backend {compile_backend_id!r} requires device runtime kind " + f"{expected!r}, but the active runtime is {rt.kind!r}. " + f"Align FLYDSL_COMPILE_BACKEND with FLYDSL_RUNTIME_KIND (and extension " + f"mappings), or use a matching pair of register_backend / " + f"register_device_runtime." + ) + + +def _active_compile_backend_id() -> str: + """Mirror :func:`flydsl.compiler.backends.compile_backend_name` without importing ``compiler``.""" + return (env.compile.backend or "rocm").lower() + + +def get_device_runtime() -> DeviceRuntime: + """Return the single process-wide :class:`DeviceRuntime` instance.""" + global _instance + if _instance is None: + cls = _resolve_runtime_class() + _instance = cls() + + ensure_compile_runtime_compatible(_active_compile_backend_id(), runtime=_instance) + return _instance + + +__all__ = [ + "COMPILE_BACKEND_TO_RUNTIME_KIND", + "Device", + "DeviceRuntime", + "Event", + "RocmDeviceRuntime", + "ensure_compile_runtime_compatible", + "get_device_runtime", + "register_compile_runtime_mapping", + "register_device_runtime", +] diff --git a/python/flydsl/runtime/device_runtime/base.py b/python/flydsl/runtime/device_runtime/base.py new file mode 100644 index 00000000..5943776d --- /dev/null +++ b/python/flydsl/runtime/device_runtime/base.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Abstract device runtime : single native GPU stack per process.""" + +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import ClassVar + + +@dataclass(frozen=True) +class Device: + """Logical GPU device (ordinal only in v1; capabilities may extend later).""" + + ordinal: int = 0 + """Device index for the active runtime (e.g. HIP device id).""" + + +class DeviceRuntime(metaclass=ABCMeta): + """Vendor-neutral runtime: one implementation per process (HIP, CUDA, …). + + Stream and Event handles stay opaque at the Python/C boundary; concrete + APIs live in native glue (e.g. ROCm wrappers). + """ + + kind: ClassVar[str] + """Stable runtime identifier (e.g. ``\"rocm\"`` for HIP/ROCm).""" + + @abstractmethod + def device_count(self) -> int: + """Number of visible devices for this runtime.""" + + def default_device(self) -> Device: + """Default device for launch when none is specified.""" + return Device(ordinal=0) + + +class Event: + """Placeholder for future opaque event handles. + + Kernel launch paths may synchronize via streams; explicit events can be + added without changing the compile-backend layer. + """ + + __slots__ = () diff --git a/python/flydsl/runtime/device_runtime/rocm.py b/python/flydsl/runtime/device_runtime/rocm.py new file mode 100644 index 00000000..2de9eda4 --- /dev/null +++ b/python/flydsl/runtime/device_runtime/rocm.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""ROCm / HIP device runtime (default FlyDSL GPU stack).""" + +from __future__ import annotations + +from typing import ClassVar, Optional + +from .base import DeviceRuntime + + +class RocmDeviceRuntime(DeviceRuntime): + """HIP-based runtime; matches compile backend ``rocm``. + + ``device_count()`` reports 0 when no GPU is visible to PyTorch (no fake device). + """ + + kind: ClassVar[str] = "rocm" + + def __init__(self) -> None: + self._torch_device_count: Optional[int] = None + + def _lazy_device_count(self) -> int: + if self._torch_device_count is not None: + return self._torch_device_count + try: + import torch + + # ROCm builds still use the torch.cuda Python API for device visibility. + if torch.cuda.is_available(): + n = int(torch.cuda.device_count()) + else: + n = 0 + except Exception: + n = 0 + self._torch_device_count = n + return self._torch_device_count + + def device_count(self) -> int: + return self._lazy_device_count() diff --git a/python/flydsl/utils/env.py b/python/flydsl/utils/env.py index fefe7997..9689a7e8 100644 --- a/python/flydsl/utils/env.py +++ b/python/flydsl/utils/env.py @@ -255,6 +255,10 @@ class RuntimeEnvManager(EnvManager): env_prefix = "RUNTIME" + kind = OptStr( + "rocm", + description="Device runtime kind (must match FLYDSL_COMPILE_BACKEND; e.g. rocm for HIP)", + ) cache_dir = OptStr(str(Path.home() / ".flydsl" / "cache"), description="Directory for caching compiled kernels") enable_cache = OptBool(True, description="Enable kernel caching") diff --git a/tests/unit/test_device_runtime.py b/tests/unit/test_device_runtime.py new file mode 100644 index 00000000..edebe0f6 --- /dev/null +++ b/tests/unit/test_device_runtime.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""device runtime registry and compile-backend pairing.""" + +import pytest + +import flydsl.runtime.device_runtime as dr + + +@pytest.fixture(autouse=True) +def _reset_device_runtime_singleton(): + """Each test starts without a cached DeviceRuntime instance.""" + dr._instance = None + dr._runtime_cls_override = None + dr._EXTRA_MAPPINGS.clear() + yield + dr._instance = None + dr._runtime_cls_override = None + dr._EXTRA_MAPPINGS.clear() + + +class _FakeCudaRuntime(dr.DeviceRuntime): + kind = "cuda" + + def device_count(self) -> int: + return 1 + + +def test_rocm_runtime_kind_matches_compile_backend(monkeypatch): + monkeypatch.delenv("FLYDSL_RUNTIME_KIND", raising=False) + monkeypatch.setenv("FLYDSL_COMPILE_BACKEND", "rocm") + rt = dr.get_device_runtime() + assert rt.kind == "rocm" + dr.ensure_compile_runtime_compatible("rocm", runtime=rt) + + +def test_ensure_mismatch_raises(): + bad = _FakeCudaRuntime() + with pytest.raises(RuntimeError, match="requires device runtime kind"): + dr.ensure_compile_runtime_compatible("rocm", runtime=bad) + + +def test_unknown_runtime_kind_env(monkeypatch): + monkeypatch.setenv("FLYDSL_RUNTIME_KIND", "not_a_real_kind") + monkeypatch.setenv("FLYDSL_COMPILE_BACKEND", "rocm") + with pytest.raises(ValueError, match="Unknown FLYDSL_RUNTIME_KIND"): + dr.get_device_runtime() + + +def test_register_compile_runtime_mapping(): + dr.register_compile_runtime_mapping("foo", "rocm") + try: + dr.ensure_compile_runtime_compatible("foo", runtime=dr.RocmDeviceRuntime()) + finally: + dr._EXTRA_MAPPINGS.pop("foo", None)