Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions backend/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,36 @@ def is_model_cached(
return False


def _detect_iris_igpu() -> bool:
"""
Detect if system has Intel Iris integrated GPU (Windows).

Iris iGPU is typically found on:
- Intel i5/i7 12th-14th gen (Alder Lake, Raptor Lake) with Iris Xe Graphics
- Intel Arc A-series mobile discrete GPUs also work with this path

Returns True if Iris/Intel iGPU detected, False otherwise.
"""
if platform.system() != "Windows":
return False

try:
import wmi
wmi_obj = wmi.WMI()
for item in wmi_obj.Win32_VideoController():
name = item.Name or ""
# Match Intel Iris, UHD, Arc Graphics names
if any(intel_gfx in name for intel_gfx in ["Iris", "UHD Graphics", "Arc", "Intel Arc"]):
logger.info(f"Detected Intel iGPU: {name}")
return True
except (ImportError, Exception) as e:
logger.debug(f"Could not detect Iris iGPU via WMI: {e}")
# Fallback: just try DirectML and log what's available
pass

return False


def get_torch_device(
*,
allow_xpu: bool = False,
Expand All @@ -92,6 +122,9 @@ def get_torch_device(
allow_directml: Check for DirectML (Windows) support.
allow_mps: Allow MPS (Apple Silicon). If False, MPS falls back to CPU.
force_cpu_on_mac: Force CPU on macOS regardless of GPU availability.

Priority: CUDA > XPU > DirectML > MPS > CPU
DirectML on Windows covers Intel iGPU (Iris/UHD), AMD iGPU, Arc discrete.
"""
if force_cpu_on_mac and platform.system() == "Darwin":
return "cpu"
Expand All @@ -106,6 +139,7 @@ def get_torch_device(
import intel_extension_for_pytorch # noqa: F401

if hasattr(torch, "xpu") and torch.xpu.is_available():
logger.info("Using Intel XPU device")
return "xpu"
except ImportError:
pass
Expand All @@ -114,15 +148,26 @@ def get_torch_device(
try:
import torch_directml

if torch_directml.device_count() > 0:
return torch_directml.device(0)
device_count = torch_directml.device_count()
if device_count > 0:
device = torch_directml.device(0)
iris_detected = _detect_iris_igpu()
if iris_detected:
logger.info("Using DirectML device (Intel Iris iGPU detected)")
else:
logger.info("Using DirectML device (Windows GPU acceleration via DirectML)")
return device
except ImportError:
pass
logger.debug("torch_directml not installed, falling back to CPU or MPS")
except Exception as e:
logger.warning(f"DirectML initialization failed: {e}, falling back to CPU or MPS")

if allow_mps:
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
logger.info("Using MPS (Apple Metal Performance Shaders)")
return "mps"

logger.info("No GPU detected, using CPU")
return "cpu"


Expand Down
4 changes: 2 additions & 2 deletions backend/backends/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, model_size: str = "1.7B"):
self._current_model_size = None

def _get_device(self) -> str:
    """Get the best available device (CUDA > XPU > DirectML > CPU).

    Delegates to ``get_torch_device`` with XPU and DirectML probing enabled.

    NOTE(review): the DirectML branch of ``get_torch_device`` returns a
    ``torch_directml`` device object rather than a string, so the ``str``
    return annotation only holds for the other outcomes — confirm that
    callers do not rely on string-only behaviour.
    """
    return get_torch_device(allow_xpu=True, allow_directml=True)

def is_loaded(self) -> bool:
Expand Down Expand Up @@ -255,7 +255,7 @@ def __init__(self, model_size: str = "base"):
self.device = self._get_device()

def _get_device(self) -> str:
    """Get the best available device (CUDA > XPU > DirectML > CPU).

    Delegates to ``get_torch_device`` with XPU and DirectML probing enabled.

    NOTE(review): the DirectML branch of ``get_torch_device`` returns a
    ``torch_directml`` device object rather than a string, so the ``str``
    return annotation only holds for the other outcomes — confirm that
    callers do not rely on string-only behaviour.
    """
    return get_torch_device(allow_xpu=True, allow_directml=True)

def is_loaded(self) -> bool:
Expand Down
4 changes: 4 additions & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ accelerate>=0.26.0
huggingface_hub>=0.20.0
qwen-tts>=0.0.5

# DirectML support for Windows Intel iGPU / integrated GPUs (Iris, UHD, Arc)
# Latest dev version (1.13.0 stable not yet released); use latest 0.2.x dev
torch-directml>=0.2.0 ; platform_system == "Windows"

# LuxTTS (voice cloning engine)
# piper-phonemize needs custom index (no PyPI wheels)
--find-links https://k2-fsa.github.io/icefall/piper_phonemize.html
Expand Down
179 changes: 179 additions & 0 deletions backend/tests/test_directml_iris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""
Test DirectML device detection and Iris iGPU support on Windows.

Run with: pytest backend/tests/test_directml_iris.py -v -s
"""

import platform
import logging
import pytest

logger = logging.getLogger(__name__)


@pytest.mark.skipif(platform.system() != "Windows", reason="DirectML tests only on Windows")
class TestDirectMLDetection:
    """Test DirectML device availability and Iris iGPU detection.

    Every test that needs ``torch_directml`` skips (rather than passes
    vacuously) when the package or a DirectML device is unavailable.
    """

    def test_directml_import(self):
        """Test torch_directml can be imported."""
        torch_directml = pytest.importorskip(
            "torch_directml", reason="torch_directml not installed"
        )
        assert torch_directml is not None
        logger.info("✓ torch_directml imported successfully")

    def test_directml_device_count(self):
        """Test DirectML detects at least one device."""
        torch_directml = pytest.importorskip(
            "torch_directml", reason="torch_directml not installed"
        )
        device_count = torch_directml.device_count()
        assert device_count > 0, f"DirectML device_count returned {device_count}, expected > 0"
        logger.info(f"✓ DirectML detected {device_count} device(s)")

    def test_directml_device_creation(self):
        """Test creating a DirectML device object."""
        torch_directml = pytest.importorskip(
            "torch_directml", reason="torch_directml not installed"
        )
        # Without a device there is nothing to create; report a skip instead
        # of a green test that checked nothing.
        if torch_directml.device_count() == 0:
            pytest.skip("No DirectML devices available")

        device = torch_directml.device(0)
        assert device is not None
        logger.info(f"✓ DirectML device created: {device}")

    def test_get_torch_device_directml(self):
        """Test get_torch_device returns a non-CPU device when DirectML has one."""
        from ..backends.base import get_torch_device

        device = get_torch_device(allow_directml=True)
        logger.info(f"Selected device: {device}")

        try:
            import torch_directml
        except ImportError:
            # Device selection was still exercised above; without DirectML
            # there is no expectation about which device was chosen.
            logger.info("torch_directml not installed, may fall back to CPU")
            return

        if torch_directml.device_count() > 0:
            # A DirectML device is available, so CPU must not be selected.
            # str() is used because the selected device may be a device
            # object rather than a plain string.
            assert str(device) != "cpu", f"Expected DirectML but got {device}"
            logger.info(f"✓ DirectML device selected: {device}")

    def test_iris_igpu_detection(self):
        """Test Iris iGPU detection via WMI runs and yields a boolean."""
        from ..backends.base import _detect_iris_igpu

        pytest.importorskip("wmi", reason="wmi module not available")

        has_iris = _detect_iris_igpu()
        # The hardware differs per machine, so only the contract is checked.
        assert isinstance(has_iris, bool)
        logger.info(f"Iris iGPU detected: {has_iris}")


@pytest.mark.skipif(platform.system() != "Windows", reason="DirectML tests only on Windows")
class TestDirectMLTorchTensor:
    """Test basic torch tensor operations on DirectML device."""

    def test_torch_tensor_on_directml(self):
        """Test creating and multiplying tensors on a DirectML device."""
        # Only the imports may skip the test; an ImportError raised later by
        # the code under test must surface as a failure.
        torch = pytest.importorskip("torch", reason="torch not installed")
        torch_directml = pytest.importorskip(
            "torch_directml", reason="torch_directml not installed"
        )

        if torch_directml.device_count() == 0:
            pytest.skip("No DirectML devices available")

        device = torch_directml.device(0)
        x = torch.randn(3, 3, device=device)
        y = torch.randn(3, 3, device=device)
        z = torch.mm(x, y)

        assert z.shape == (3, 3)
        logger.info(f"✓ Tensor operation successful on {device}")
        logger.info(f" Result shape: {z.shape}")

    def test_directml_memory_management(self):
        """Test tensors can be repeatedly allocated and released on DirectML."""
        torch = pytest.importorskip("torch", reason="torch not installed")
        torch_directml = pytest.importorskip(
            "torch_directml", reason="torch_directml not installed"
        )

        if torch_directml.device_count() == 0:
            pytest.skip("No DirectML devices available")

        device = torch_directml.device(0)
        # Allocate and drop a tensor several times; the test fails if any
        # allocation raises or yields a tensor of the wrong shape.
        for _ in range(5):
            x = torch.randn(1000, 1000, device=device)
            assert x.shape == (1000, 1000)
            del x

        logger.info("✓ DirectML memory management OK")


@pytest.mark.skipif(platform.system() != "Windows", reason="Model tests only on Windows")
@pytest.mark.asyncio
class TestWhisperOnDirectML:
    """Test Whisper (STT) model on DirectML device."""

    async def test_whisper_model_loads_on_directml(self):
        """Test Whisper model can load on DirectML."""
        torch_directml = pytest.importorskip(
            "torch_directml", reason="torch_directml not installed"
        )
        if torch_directml.device_count() == 0:
            pytest.skip("No DirectML devices available")

        from ..backends.pytorch_backend import PyTorchSTTBackend

        backend = PyTorchSTTBackend(model_size="base")
        # Compare via str(): the backend device may be a device object, and
        # an object-vs-string comparison would pass regardless of its value.
        assert str(backend.device) != "cpu", f"Expected GPU device, got {backend.device}"
        logger.info(f"✓ Whisper backend using device: {backend.device}")

        # Try to load the model (this will download if needed). Only the
        # load itself may turn into a skip on environment/network problems.
        try:
            await backend.load_model_async("base")
        except (TimeoutError, ConnectionError, OSError) as e:
            pytest.skip(f"Environment/network limitation during model load: {e}")

        try:
            assert backend.is_loaded()
            logger.info("✓ Whisper model loaded successfully on DirectML")
        finally:
            # Release the model even when the assertion above fails.
            backend.unload_model()


@pytest.mark.skipif(platform.system() != "Windows", reason="Model tests only on Windows")
@pytest.mark.asyncio
class TestQwenTTSOnDirectML:
    """Test Qwen TTS model on DirectML device."""

    async def test_qwen_tts_loads_on_directml(self):
        """Test Qwen TTS model can load on DirectML."""
        torch_directml = pytest.importorskip(
            "torch_directml", reason="torch_directml not installed"
        )
        if torch_directml.device_count() == 0:
            pytest.skip("No DirectML devices available")

        from ..backends.pytorch_backend import PyTorchTTSBackend

        backend = PyTorchTTSBackend(model_size="0.6B")
        # Compare via str(): the backend device may be a device object, and
        # an object-vs-string comparison would pass regardless of its value.
        assert str(backend.device) != "cpu", f"Expected GPU device, got {backend.device}"
        logger.info(f"✓ Qwen TTS backend using device: {backend.device}")

        # Try to load the model (this will download if needed). Only the
        # load itself may turn into a skip on environment/network problems.
        try:
            await backend.load_model_async("0.6B")
        except (TimeoutError, ConnectionError, OSError) as e:
            pytest.skip(f"Environment/network limitation during model load: {e}")

        try:
            assert backend.is_loaded()
            logger.info("✓ Qwen TTS model loaded successfully on DirectML")
        finally:
            # Release the model even when the assertion above fails.
            backend.unload_model()