feat(runtime): implement device runtime layer (Python) #277
Open
Peter9606 wants to merge 1 commit into ROCm:main from Deep-Spark:fujun.han/pluggable-compile-backend
+342 −0
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,28 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2025 FlyDSL Project Contributors | ||
|
|
||
| """FlyDSL runtime: device runtime, GPU detection helpers, etc.""" | ||
|
|
||
| from .device_runtime import ( | ||
| COMPILE_BACKEND_TO_RUNTIME_KIND, | ||
| Device, | ||
| DeviceRuntime, | ||
| Event, | ||
| RocmDeviceRuntime, | ||
| ensure_compile_runtime_compatible, | ||
| get_device_runtime, | ||
| register_compile_runtime_mapping, | ||
| register_device_runtime, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "COMPILE_BACKEND_TO_RUNTIME_KIND", | ||
| "Device", | ||
| "DeviceRuntime", | ||
| "Event", | ||
| "RocmDeviceRuntime", | ||
| "ensure_compile_runtime_compatible", | ||
| "get_device_runtime", | ||
| "register_compile_runtime_mapping", | ||
| "register_device_runtime", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2025 FlyDSL Project Contributors | ||
|
|
||
| """Per-process GPU device runtime. | ||
|
|
||
| Exactly one :class:`DeviceRuntime` implementation is active per process. | ||
| It must match the selected compile backend (e.g. ``rocm`` compile backend ↔ | ||
| ``rocm`` runtime / HIP). | ||
|
|
||
| Environment: | ||
|
|
||
| * ``FLYDSL_RUNTIME_KIND`` — selects the built-in runtime implementation | ||
| (default: ``rocm``). Must agree with ``FLYDSL_COMPILE_BACKEND`` via | ||
| :data:`COMPILE_BACKEND_TO_RUNTIME_KIND` (and optional extension mappings). | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Dict, Optional, Type | ||
|
|
||
| from ...utils import env | ||
| from .base import Device, DeviceRuntime, Event | ||
| from .rocm import RocmDeviceRuntime | ||
|
|
||
| # Compile-backend id -> device-runtime kind (single string namespace). | ||
| COMPILE_BACKEND_TO_RUNTIME_KIND: Dict[str, str] = { | ||
| "rocm": "rocm", | ||
| } | ||
|
|
||
| _EXTRA_MAPPINGS: Dict[str, str] = {} | ||
|
|
||
| _builtin_runtimes: Dict[str, Type[DeviceRuntime]] = { | ||
| "rocm": RocmDeviceRuntime, | ||
| } | ||
|
|
||
| _runtime_cls_override: Optional[Type[DeviceRuntime]] = None | ||
| _instance: Optional[DeviceRuntime] = None | ||
|
|
||
|
|
||
| def register_compile_runtime_mapping(compile_backend: str, runtime_kind: str) -> None: | ||
| """Map a compile-backend id to a device-runtime *kind*. | ||
|
|
||
| Use when a third-party :func:`flydsl.compiler.backends.register_backend` | ||
| targets an existing runtime stack (e.g. custom name → ``rocm``). | ||
| """ | ||
| _EXTRA_MAPPINGS[compile_backend.strip().lower()] = runtime_kind.strip().lower() | ||
|
|
||
|
|
||
| def register_device_runtime( | ||
| cls: Type[DeviceRuntime], | ||
| *, | ||
| kind: Optional[str] = None, | ||
| force: bool = False, | ||
| ) -> None: | ||
| """Register a custom :class:`DeviceRuntime` class for the whole process. | ||
|
|
||
| Must be called before the first :func:`get_device_runtime` if replacing the | ||
| default. Raises if a runtime instance already exists, unless ``force=True``. | ||
|
|
||
| If *kind* is set, the class is also registered under that runtime *kind* for | ||
| ``FLYDSL_RUNTIME_KIND`` (and should match the compile-backend mapping). | ||
| """ | ||
| global _runtime_cls_override, _instance | ||
| if _instance is not None and not force: | ||
| raise RuntimeError( | ||
| "Cannot register a device runtime after get_device_runtime() " | ||
| "has been called (unless force=True)." | ||
| ) | ||
| if _runtime_cls_override is not None and not force: | ||
| raise ValueError( | ||
| "A custom device runtime class is already registered " | ||
| "(use force=True to replace)." | ||
| ) | ||
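| _runtime_cls_override = cls | ||
| # Drop any cached instance (only reachable with force=True) so the new class takes effect. | ||
| _instance = None | ||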
| if kind is not None: | ||
| _builtin_runtimes[kind.strip().lower()] = cls | ||
|
|
||
|
|
||
| def _expected_runtime_kind_for_compile_backend(compile_backend_id: str) -> str: | ||
| key = compile_backend_id.strip().lower() | ||
| if key in _EXTRA_MAPPINGS: | ||
| return _EXTRA_MAPPINGS[key] | ||
| if key in COMPILE_BACKEND_TO_RUNTIME_KIND: | ||
| return COMPILE_BACKEND_TO_RUNTIME_KIND[key] | ||
| raise ValueError( | ||
| f"No device-runtime kind mapped for compile backend {compile_backend_id!r}. " | ||
| "Register a mapping with register_compile_runtime_mapping()." | ||
| ) | ||
|
|
||
|
|
||
| def _resolve_runtime_class() -> Type[DeviceRuntime]: | ||
| if _runtime_cls_override is not None: | ||
| return _runtime_cls_override | ||
| kind = (env.runtime.kind or "rocm").strip().lower() | ||
| cls = _builtin_runtimes.get(kind) | ||
| if cls is None: | ||
| known = ", ".join(sorted(_builtin_runtimes)) or "(none)" | ||
| raise ValueError( | ||
| f"Unknown FLYDSL_RUNTIME_KIND={kind!r}. Built-in kinds: {known}" | ||
| ) | ||
| return cls | ||
|
|
||
|
|
||
| def ensure_compile_runtime_compatible( | ||
| compile_backend_id: str, | ||
| *, | ||
| runtime: Optional[DeviceRuntime] = None, | ||
| ) -> None: | ||
| """Raise if *compile_backend_id* does not match the active runtime kind.""" | ||
| expected = _expected_runtime_kind_for_compile_backend(compile_backend_id) | ||
| rt = runtime if runtime is not None else get_device_runtime() | ||
| if rt.kind != expected: | ||
| raise RuntimeError( | ||
| f"Compile backend {compile_backend_id!r} requires device runtime kind " | ||
| f"{expected!r}, but the active runtime is {rt.kind!r}. " | ||
| f"Align FLYDSL_COMPILE_BACKEND with FLYDSL_RUNTIME_KIND (and extension " | ||
| f"mappings), or use a matching pair of register_backend / " | ||
| f"register_device_runtime." | ||
| ) | ||
|
|
||
|
|
||
| def _active_compile_backend_id() -> str: | ||
| """Mirror :func:`flydsl.compiler.backends.compile_backend_name` without importing ``compiler``.""" | ||
| return (env.compile.backend or "rocm").lower() | ||
|
|
||
|
|
||
| def get_device_runtime() -> DeviceRuntime: | ||
| """Return the single process-wide :class:`DeviceRuntime` instance.""" | ||
| global _instance | ||
| if _instance is None: | ||
| cls = _resolve_runtime_class() | ||
| _instance = cls() | ||
|
|
||
| ensure_compile_runtime_compatible(_active_compile_backend_id(), runtime=_instance) | ||
| return _instance | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "COMPILE_BACKEND_TO_RUNTIME_KIND", | ||
| "Device", | ||
| "DeviceRuntime", | ||
| "Event", | ||
| "RocmDeviceRuntime", | ||
| "ensure_compile_runtime_compatible", | ||
| "get_device_runtime", | ||
| "register_compile_runtime_mapping", | ||
| "register_device_runtime", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2025 FlyDSL Project Contributors | ||
|
|
||
| """Abstract device runtime: single native GPU stack per process.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from abc import ABCMeta, abstractmethod | ||
| from dataclasses import dataclass | ||
| from typing import ClassVar | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class Device: | ||
| """Logical GPU device (ordinal only in v1; capabilities may extend later).""" | ||
|
|
||
| ordinal: int = 0 | ||
| """Device index for the active runtime (e.g. HIP device id).""" | ||
|
|
||
|
|
||
| class DeviceRuntime(metaclass=ABCMeta): | ||
| """Vendor-neutral runtime: one implementation per process (HIP, CUDA, …). | ||
|
|
||
| Stream and Event handles stay opaque at the Python/C boundary; concrete | ||
| APIs live in native glue (e.g. ROCm wrappers). | ||
| """ | ||
|
|
||
| kind: ClassVar[str] | ||
| """Stable runtime identifier (e.g. ``"rocm"`` for HIP/ROCm).""" | ||
|
|
||
| @abstractmethod | ||
| def device_count(self) -> int: | ||
| """Number of visible devices for this runtime.""" | ||
|
|
||
| def default_device(self) -> Device: | ||
| """Default device for launch when none is specified.""" | ||
| return Device(ordinal=0) | ||
|
|
||
|
|
||
| class Event: | ||
| """Placeholder for future opaque event handles. | ||
|
|
||
| Kernel launch paths may synchronize via streams; explicit events can be | ||
| added without changing the compile-backend layer. | ||
| """ | ||
|
|
||
| __slots__ = () |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2025 FlyDSL Project Contributors | ||
|
|
||
| """ROCm / HIP device runtime (default FlyDSL GPU stack).""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import ClassVar, Optional | ||
|
|
||
| from .base import DeviceRuntime | ||
|
|
||
|
|
||
| class RocmDeviceRuntime(DeviceRuntime): | ||
| """HIP-based runtime; matches compile backend ``rocm``. | ||
|
|
||
| ``device_count()`` reports 0 when no GPU is visible to PyTorch (no fake device). | ||
| """ | ||
|
|
||
| kind: ClassVar[str] = "rocm" | ||
|
|
||
| def __init__(self) -> None: | ||
| self._torch_device_count: Optional[int] = None | ||
|
|
||
| def _lazy_device_count(self) -> int: | ||
| if self._torch_device_count is not None: | ||
| return self._torch_device_count | ||
| try: | ||
| import torch | ||
|
|
||
| # ROCm builds still use the torch.cuda Python API for device visibility. | ||
| if torch.cuda.is_available(): | ||
| n = int(torch.cuda.device_count()) | ||
| else: | ||
| n = 0 | ||
| except Exception: | ||
| n = 0 | ||
| self._torch_device_count = n | ||
| return self._torch_device_count | ||
|
|
||
| def device_count(self) -> int: | ||
| return self._lazy_device_count() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2025 FlyDSL Project Contributors | ||
|
|
||
| """Tests for the device runtime registry and compile-backend pairing.""" | ||
|
|
||
| import pytest | ||
|
|
||
| import flydsl.runtime.device_runtime as dr | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def _reset_device_runtime_singleton(): | ||
| """Each test starts without a cached DeviceRuntime instance.""" | ||
| dr._instance = None | ||
| dr._runtime_cls_override = None | ||
| dr._EXTRA_MAPPINGS.clear() | ||
| yield | ||
| dr._instance = None | ||
| dr._runtime_cls_override = None | ||
| dr._EXTRA_MAPPINGS.clear() | ||
|
|
||
|
|
||
| class _FakeCudaRuntime(dr.DeviceRuntime): | ||
| kind = "cuda" | ||
|
|
||
| def device_count(self) -> int: | ||
| return 1 | ||
|
|
||
|
|
||
| def test_rocm_runtime_kind_matches_compile_backend(monkeypatch): | ||
| monkeypatch.delenv("FLYDSL_RUNTIME_KIND", raising=False) | ||
| monkeypatch.setenv("FLYDSL_COMPILE_BACKEND", "rocm") | ||
| rt = dr.get_device_runtime() | ||
| assert rt.kind == "rocm" | ||
| dr.ensure_compile_runtime_compatible("rocm", runtime=rt) | ||
|
|
||
|
|
||
| def test_ensure_mismatch_raises(): | ||
| bad = _FakeCudaRuntime() | ||
| with pytest.raises(RuntimeError, match="requires device runtime kind"): | ||
| dr.ensure_compile_runtime_compatible("rocm", runtime=bad) | ||
|
|
||
|
|
||
| def test_unknown_runtime_kind_env(monkeypatch): | ||
| monkeypatch.setenv("FLYDSL_RUNTIME_KIND", "not_a_real_kind") | ||
| monkeypatch.setenv("FLYDSL_COMPILE_BACKEND", "rocm") | ||
| with pytest.raises(ValueError, match="Unknown FLYDSL_RUNTIME_KIND"): | ||
| dr.get_device_runtime() | ||
|
|
||
|
|
||
| def test_register_compile_runtime_mapping(): | ||
| dr.register_compile_runtime_mapping("foo", "rocm") | ||
| try: | ||
| dr.ensure_compile_runtime_compatible("foo", runtime=dr.RocmDeviceRuntime()) | ||
| finally: | ||
| dr._EXTRA_MAPPINGS.pop("foo", None) |
So will we add the JIT runner functions here once this is done?
RocmDeviceRuntime in this PR is intentionally minimal: it exists only so that we have a process-wide runtime kind (rocm) that can be validated against FLYDSL_COMPILE_BACKEND.
device_count() is a small probe via PyTorch for now; we are not planning to add JIT runner / ExecutionEngine glue here. That stays in jit_executor / compiler, as it does today.
Follow-up work (per the RFC) could add stream/event helpers on DeviceRuntime, but would still not duplicate the JIT launch path in this module.
We can change device_count() to return 0 when no CUDA/ROCm device is visible, instead of max(n, 1), so that the result never implies a fake device.
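To make the last point concrete, here is a minimal, self-contained sketch of the intended behaviour. It does not import FlyDSL: `SketchRuntime`, `BACKEND_TO_KIND`, `check_pairing`, and the injectable `probe` callable are hypothetical stand-ins for `RocmDeviceRuntime`, `COMPILE_BACKEND_TO_RUNTIME_KIND`, `ensure_compile_runtime_compatible`, and the PyTorch visibility query, chosen only for illustration.

```python
from typing import Callable, ClassVar, Dict, Optional


class SketchRuntime:
    """Hypothetical stand-in for RocmDeviceRuntime (not the PR's class)."""

    kind: ClassVar[str] = "rocm"

    def __init__(self, probe: Callable[[], int]) -> None:
        # `probe` is an assumption: it replaces the torch.cuda query so the
        # sketch runs without PyTorch or a GPU.
        self._probe = probe
        self._cached: Optional[int] = None

    def device_count(self) -> int:
        if self._cached is None:
            try:
                n = int(self._probe())
            except Exception:
                # A failing probe (e.g. PyTorch missing) means no visible device.
                n = 0
            # Report the real count; no max(n, 1) fake device.
            self._cached = n
        return self._cached


# Compile-backend id -> runtime kind, mirroring the shape of the PR's mapping.
BACKEND_TO_KIND: Dict[str, str] = {"rocm": "rocm"}


def check_pairing(compile_backend: str, runtime: SketchRuntime) -> None:
    """Raise if the compile backend does not map to the runtime's kind."""
    expected = BACKEND_TO_KIND[compile_backend.strip().lower()]
    if runtime.kind != expected:
        raise RuntimeError(
            f"Compile backend {compile_backend!r} requires runtime kind "
            f"{expected!r}, but the active runtime is {runtime.kind!r}."
        )


def _missing_torch() -> int:
    raise ImportError("torch is not installed")


no_gpu = SketchRuntime(lambda: 0)
broken = SketchRuntime(_missing_torch)
two_gpus = SketchRuntime(lambda: 2)
check_pairing(" ROCm ", no_gpu)  # pairing is valid even with zero devices
```

Under these assumptions, the pairing check and the device probe stay independent: a process with no visible GPU still gets a valid `rocm` runtime kind, and callers that need a device can test `device_count() == 0` explicitly.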