Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/flydsl/compiler/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def get_backend(name: Optional[str] = None, *, arch: str = "") -> BaseBackend:

*name* defaults to ``FLYDSL_COMPILE_BACKEND`` env var (or ``'rocm'``).
*arch* overrides the auto-detected architecture when non-empty.

Compile/runtime pairing (``FLYDSL_COMPILE_BACKEND`` vs ``FLYDSL_RUNTIME_KIND``)
is validated on the JIT path when a kernel is compiled or loaded for execution
(see :func:`flydsl.runtime.device_runtime.get_device_runtime`), not here.
"""
if name is None:
name = compile_backend_name()
Expand Down
5 changes: 5 additions & 0 deletions python/flydsl/compiler/jit_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def _ensure_engine(self):
if self._engine is not None:
return

# Validate compile backend vs device runtime before loading ExecutionEngine.
from ..runtime.device_runtime import get_device_runtime

get_device_runtime()

with ir.Context() as ctx:
ctx.load_all_available_dialects()
self._module = ir.Module.parse(self._ir_text)
Expand Down
5 changes: 5 additions & 0 deletions python/flydsl/compiler/jit_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,11 @@ def __call__(self, *args, **kwargs):

compiled_module = MlirCompiler.compile(module, arch=backend.target.arch, func_name=self.func.__name__)

# Compile/runtime pairing before COMPILE_ONLY return (no ExecutionEngine on that path).
from ..runtime.device_runtime import get_device_runtime

get_device_runtime()

if env.compile.compile_only:
print(f"[flydsl] COMPILE_ONLY=1, compilation succeeded (arch={backend.target.arch})")
return None
Expand Down
7 changes: 7 additions & 0 deletions python/flydsl/expr/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,13 @@ def __fly_values__(self):


class Stream:
"""Opaque async queue handle for kernel launch.

Values may be ``None`` (default queue), a raw pointer, or a framework stream
with a ``cuda_stream`` attribute (e.g. PyTorch); the active
:class:`flydsl.runtime.device_runtime.DeviceRuntime` interprets them.
"""

_is_stream_param = True

def __init__(self, value=None):
Expand Down
25 changes: 25 additions & 0 deletions python/flydsl/runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors

"""FlyDSL runtime: device runtime, GPU detection helpers, etc."""

from .device_runtime import (
COMPILE_BACKEND_TO_RUNTIME_KIND,
Device,
DeviceRuntime,
Event,
RocmDeviceRuntime,
ensure_compile_runtime_compatible,
get_device_runtime,
register_compile_runtime_mapping,
register_device_runtime,
)

# Public names re-exported from ``.device_runtime`` (same set as the import
# list above), so ``from flydsl.runtime import *`` exposes exactly these.
__all__ = [
    "COMPILE_BACKEND_TO_RUNTIME_KIND",
    "Device",
    "DeviceRuntime",
    "Event",
    "RocmDeviceRuntime",
    "ensure_compile_runtime_compatible",
    "get_device_runtime",
    "register_compile_runtime_mapping",
    "register_device_runtime",
]
148 changes: 148 additions & 0 deletions python/flydsl/runtime/device_runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors

"""Per-process GPU device runtime.

Exactly one :class:`DeviceRuntime` implementation is active per process.
It must match the selected compile backend (e.g. ``rocm`` compile backend ↔
``rocm`` runtime / HIP).

Environment:

* ``FLYDSL_RUNTIME_KIND`` — selects the built-in runtime implementation
(default: ``rocm``). Must agree with ``FLYDSL_COMPILE_BACKEND`` via
:data:`COMPILE_BACKEND_TO_RUNTIME_KIND` (and optional extension mappings).
"""

from __future__ import annotations

from typing import Dict, Optional, Type

from ...utils import env
from .base import Device, DeviceRuntime, Event
from .rocm import RocmDeviceRuntime

# Compile-backend id -> device-runtime kind (single string namespace).
# Built-in pairings; _expected_runtime_kind_for_compile_backend() consults this
# table only after the extension mappings below.
COMPILE_BACKEND_TO_RUNTIME_KIND: Dict[str, str] = {
    "rocm": "rocm",
}

# Extension pairings added through register_compile_runtime_mapping(). Keys and
# values are stored stripped and lower-cased and take precedence over the
# built-in table above.
_EXTRA_MAPPINGS: Dict[str, str] = {}

# Runtime kind -> implementation class, looked up with FLYDSL_RUNTIME_KIND.
# register_device_runtime(kind=...) may add or replace entries.
_builtin_runtimes: Dict[str, Type[DeviceRuntime]] = {
    "rocm": RocmDeviceRuntime,
}

# Class installed by register_device_runtime(); when set, it is used instead of
# the FLYDSL_RUNTIME_KIND lookup.
_runtime_cls_override: Optional[Type[DeviceRuntime]] = None
# Process-wide runtime instance, created on the first get_device_runtime() call.
_instance: Optional[DeviceRuntime] = None


def register_compile_runtime_mapping(compile_backend: str, runtime_kind: str) -> None:
    """Associate a compile-backend id with the device-runtime *kind* it needs.

    Intended for third-party backends added through
    :func:`flydsl.compiler.backends.register_backend` that run on an existing
    runtime stack (for example a custom backend name that targets ``rocm``).

    Both strings are stored stripped and lower-cased; registering the same
    backend id again overwrites the earlier entry.
    """
    # Normalise the value before the key, matching assignment evaluation order.
    kind_value = runtime_kind.strip().lower()
    backend_key = compile_backend.strip().lower()
    _EXTRA_MAPPINGS[backend_key] = kind_value


def register_device_runtime(
    cls: Type[DeviceRuntime],
    *,
    kind: Optional[str] = None,
    force: bool = False,
) -> None:
    """Register a custom :class:`DeviceRuntime` class for the whole process.

    The registered class takes precedence over ``FLYDSL_RUNTIME_KIND`` when the
    process-wide runtime is created. Call this before the first
    :func:`get_device_runtime` when replacing the default.

    *cls* is the runtime class to instantiate. If *kind* is set, the class is
    also registered under that runtime *kind* (stripped and lower-cased) for
    ``FLYDSL_RUNTIME_KIND``; it should match the compile-backend mapping.

    With ``force=True`` an existing registration and an already-created runtime
    instance are both replaced: the cached instance is discarded so that the
    next :func:`get_device_runtime` call builds one from *cls*.

    Raises :class:`RuntimeError` if a runtime instance already exists and
    *force* is false, and :class:`ValueError` if a custom class is already
    registered and *force* is false.
    """
    global _runtime_cls_override, _instance
    if not force:
        if _instance is not None:
            raise RuntimeError(
                "Cannot register a device runtime after get_device_runtime() "
                "has been called (unless force=True)."
            )
        if _runtime_cls_override is not None:
            raise ValueError(
                "A custom device runtime class is already registered "
                "(use force=True to replace)."
            )
    _runtime_cls_override = cls
    if kind is not None:
        _builtin_runtimes[kind.strip().lower()] = cls
    # Drop any cached singleton. Without this a forced re-registration would be
    # silently ignored, because get_device_runtime() only instantiates a class
    # when no instance is cached. On the non-forced path the instance is
    # already None here, so this is a no-op.
    _instance = None


def _expected_runtime_kind_for_compile_backend(compile_backend_id: str) -> str:
    """Return the runtime kind paired with *compile_backend_id*.

    The id is stripped and lower-cased, then looked up first in the extension
    mappings and then in :data:`COMPILE_BACKEND_TO_RUNTIME_KIND`.

    Raises :class:`ValueError` when neither table has an entry for the id.
    """
    normalized = compile_backend_id.strip().lower()
    # Extension mappings win over the built-in table.
    for table in (_EXTRA_MAPPINGS, COMPILE_BACKEND_TO_RUNTIME_KIND):
        if normalized in table:
            return table[normalized]
    raise ValueError(
        f"No device-runtime kind mapped for compile backend {compile_backend_id!r}. "
        "Register a mapping with register_compile_runtime_mapping()."
    )


def _resolve_runtime_class() -> Type[DeviceRuntime]:
    """Pick the :class:`DeviceRuntime` class to instantiate for this process.

    A class installed through :func:`register_device_runtime` is returned
    as-is. Otherwise the kind from ``FLYDSL_RUNTIME_KIND`` (``rocm`` when
    unset or empty) is stripped, lower-cased and looked up among the
    registered kinds.

    Raises :class:`ValueError` when the kind is not registered.
    """
    override = _runtime_cls_override
    if override is not None:
        return override

    kind = (env.runtime.kind or "rocm").strip().lower()
    runtime_cls = _builtin_runtimes.get(kind)
    if runtime_cls is not None:
        return runtime_cls

    # List the registered kinds so that a typo is easy to spot.
    known = ", ".join(sorted(_builtin_runtimes)) or "(none)"
    raise ValueError(
        f"Unknown FLYDSL_RUNTIME_KIND={kind!r}. Built-in kinds: {known}"
    )


def ensure_compile_runtime_compatible(
    compile_backend_id: str,
    *,
    runtime: Optional[DeviceRuntime] = None,
) -> None:
    """Check that *compile_backend_id* is paired with the given runtime.

    When *runtime* is omitted the process-wide instance from
    :func:`get_device_runtime` is checked. The expected kind is resolved
    first, so an unmapped backend id raises :class:`ValueError` before any
    runtime is touched.

    Raises :class:`RuntimeError` when the runtime's ``kind`` differs from the
    kind mapped to the compile backend.
    """
    expected_kind = _expected_runtime_kind_for_compile_backend(compile_backend_id)
    if runtime is None:
        active = get_device_runtime()
    else:
        active = runtime

    if active.kind != expected_kind:
        message = (
            f"Compile backend {compile_backend_id!r} requires device runtime kind "
            f"{expected_kind!r}, but the active runtime is {active.kind!r}. "
            f"Align FLYDSL_COMPILE_BACKEND with FLYDSL_RUNTIME_KIND (and extension "
            f"mappings), or use a matching pair of register_backend / "
            f"register_device_runtime."
        )
        raise RuntimeError(message)


def _active_compile_backend_id() -> str:
    """Return the configured compile-backend id, lower-cased.

    Mirrors :func:`flydsl.compiler.backends.compile_backend_name` without
    importing ``compiler``; falls back to ``rocm`` when the setting is unset
    or empty.
    """
    configured = env.compile.backend
    backend_id = configured if configured else "rocm"
    return backend_id.lower()


def get_device_runtime() -> DeviceRuntime:
    """Return the single process-wide :class:`DeviceRuntime` instance.

    The instance is created on first use from the class chosen by
    ``_resolve_runtime_class()``. Every call, not only the first, re-checks
    that the runtime matches the active compile backend and raises if the
    pairing is invalid.
    """
    global _instance
    runtime = _instance
    if runtime is None:
        runtime_cls = _resolve_runtime_class()
        runtime = runtime_cls()
        _instance = runtime

    ensure_compile_runtime_compatible(_active_compile_backend_id(), runtime=runtime)
    return runtime


# Public API of this package; the underscore-prefixed registries and helpers
# above are intentionally left out.
__all__ = [
    "COMPILE_BACKEND_TO_RUNTIME_KIND",
    "Device",
    "DeviceRuntime",
    "Event",
    "RocmDeviceRuntime",
    "ensure_compile_runtime_compatible",
    "get_device_runtime",
    "register_compile_runtime_mapping",
    "register_device_runtime",
]
47 changes: 47 additions & 0 deletions python/flydsl/runtime/device_runtime/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors

"""Abstract device runtime: a single native GPU stack per process."""

from __future__ import annotations

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import ClassVar


@dataclass(frozen=True)
class Device:
    """Immutable handle naming one logical GPU device.

    Only the ordinal is carried in v1; capability fields may be added later.
    """

    # Device index for the active runtime (e.g. HIP device id).
    ordinal: int = 0


class DeviceRuntime(metaclass=ABCMeta):
    """Vendor-neutral runtime interface: one implementation per process (HIP, CUDA, …).

    Stream and Event handles remain opaque at the Python/C boundary; the
    concrete APIs live in native glue (e.g. ROCm wrappers).
    """

    # Stable runtime identifier (e.g. "rocm" for HIP/ROCm). Declared here
    # without a value; concrete runtimes assign it.
    kind: ClassVar[str]

    @abstractmethod
    def device_count(self) -> int:
        """Return the number of devices visible to this runtime."""

    def default_device(self) -> Device:
        """Return the device used for launches that do not name one."""
        fallback_ordinal = 0
        return Device(ordinal=fallback_ordinal)


class Event:
    """Placeholder for future opaque event handles.

    Kernel launch paths may synchronize via streams; explicit events can be
    added without changing the compile-backend layer.
    """

    # No per-instance attributes yet: the empty slots tuple keeps instances
    # free of a ``__dict__`` until real handle fields are introduced.
    __slots__ = ()
41 changes: 41 additions & 0 deletions python/flydsl/runtime/device_runtime/rocm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors

"""ROCm / HIP device runtime (default FlyDSL GPU stack)."""

from __future__ import annotations

from typing import ClassVar, Optional

from .base import DeviceRuntime


class RocmDeviceRuntime(DeviceRuntime):
    """HIP-based runtime paired with the ``rocm`` compile backend.

    ``device_count()`` reports 0 when no GPU is visible to PyTorch (no fake device).
    """

    kind: ClassVar[str] = "rocm"

    def __init__(self) -> None:
        # Memoised result of the PyTorch probe; None until first queried.
        self._torch_device_count: Optional[int] = None

    def _lazy_device_count(self) -> int:
        """Probe PyTorch once for the visible device count and cache the answer.

        Any failure while importing or querying PyTorch is treated as
        "no devices" and yields 0; that result is cached as well.
        """
        cached = self._torch_device_count
        if cached is None:
            try:
                import torch

                # ROCm builds still use the torch.cuda Python API for device visibility.
                cached = int(torch.cuda.device_count()) if torch.cuda.is_available() else 0
            except Exception:
                # Deliberate best effort: a missing or broken PyTorch means no devices.
                cached = 0
            self._torch_device_count = cached
        return cached

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we add those Jit runner funcs here after done?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RocmDeviceRuntime in this PR is intentionally minimal — it only exists so we have a process-wide runtime kind (rocm) and can validate it against FLYDSL_COMPILE_BACKEND.
device_count() is a small probe via PyTorch for now; we are not planning to add JIT runner / ExecutionEngine glue here. That stays in jit_executor / compiler as today.
Follow-up work (per RFC) could add stream/event helpers on DeviceRuntime, but still not duplicate the JIT launch path into this module.
We can change device_count() to return 0 when no CUDA/ROCm device is visible instead of max(n, 1) to make it clear.

def device_count(self) -> int:
    """Return the device count from the cached PyTorch probe (``_lazy_device_count``)."""
    return self._lazy_device_count()
4 changes: 4 additions & 0 deletions python/flydsl/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ class RuntimeEnvManager(EnvManager):

env_prefix = "RUNTIME"

kind = OptStr(
"rocm",
description="Device runtime kind (must match FLYDSL_COMPILE_BACKEND; e.g. rocm for HIP)",
)
cache_dir = OptStr(str(Path.home() / ".flydsl" / "cache"), description="Directory for caching compiled kernels")
enable_cache = OptBool(True, description="Enable kernel caching")

Expand Down
56 changes: 56 additions & 0 deletions tests/unit/test_device_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors

"""device runtime registry and compile-backend pairing."""

import pytest

import flydsl.runtime.device_runtime as dr


@pytest.fixture(autouse=True)
def _reset_device_runtime_singleton():
    """Give every test a clean registry: no cached runtime, override or extra mapping."""

    def _wipe_registry_state():
        dr._instance = None
        dr._runtime_cls_override = None
        dr._EXTRA_MAPPINGS.clear()

    # Reset before the test so that earlier state cannot leak in, and again
    # afterwards so that this test's state cannot leak out.
    _wipe_registry_state()
    yield
    _wipe_registry_state()


class _FakeCudaRuntime(dr.DeviceRuntime):
    """Stub runtime whose kind is ``cuda``, used to provoke a pairing mismatch."""

    kind = "cuda"

    def device_count(self) -> int:
        # Constant stub value; no real device probing happens here.
        return 1


def test_rocm_runtime_kind_matches_compile_backend(monkeypatch):
    """With no runtime kind set, the runtime is ``rocm`` and pairs with the ``rocm`` backend."""
    # Remove any inherited runtime kind so that the configured default is used.
    monkeypatch.delenv("FLYDSL_RUNTIME_KIND", raising=False)
    monkeypatch.setenv("FLYDSL_COMPILE_BACKEND", "rocm")

    runtime = dr.get_device_runtime()

    assert runtime.kind == "rocm"
    # The explicit check must not raise for the matching pair.
    dr.ensure_compile_runtime_compatible("rocm", runtime=runtime)


def test_ensure_mismatch_raises():
    """A ``cuda`` runtime paired with the ``rocm`` backend is rejected."""
    mismatched_runtime = _FakeCudaRuntime()
    expected_fragment = "requires device runtime kind"
    with pytest.raises(RuntimeError, match=expected_fragment):
        dr.ensure_compile_runtime_compatible("rocm", runtime=mismatched_runtime)


def test_unknown_runtime_kind_env(monkeypatch):
    """An unregistered ``FLYDSL_RUNTIME_KIND`` makes runtime creation fail."""
    overrides = (
        ("FLYDSL_RUNTIME_KIND", "not_a_real_kind"),
        ("FLYDSL_COMPILE_BACKEND", "rocm"),
    )
    for variable, value in overrides:
        monkeypatch.setenv(variable, value)

    with pytest.raises(ValueError, match="Unknown FLYDSL_RUNTIME_KIND"):
        dr.get_device_runtime()


def test_register_compile_runtime_mapping():
    """A custom backend id mapped to ``rocm`` is accepted by the ROCm runtime."""
    backend_alias = "foo"
    dr.register_compile_runtime_mapping(backend_alias, "rocm")
    try:
        rocm_runtime = dr.RocmDeviceRuntime()
        dr.ensure_compile_runtime_compatible(backend_alias, runtime=rocm_runtime)
    finally:
        # Remove the mapping explicitly, whether or not the check passed.
        dr._EXTRA_MAPPINGS.pop(backend_alias, None)