diff --git a/backend/backends/base.py b/backend/backends/base.py index c566af10..9168b1fe 100644 --- a/backend/backends/base.py +++ b/backend/backends/base.py @@ -77,6 +77,36 @@ def is_model_cached( return False +def _detect_iris_igpu() -> bool: + """ + Detect if system has Intel Iris integrated GPU (Windows). + + Iris iGPU is typically found on: + - Intel i5/i7 12th-14th gen (Alder Lake, Raptor Lake) with Iris Xe Graphics + - Intel Arc A-series mobile discrete GPUs also work with this path + + Returns True if Iris/Intel iGPU detected, False otherwise. + """ + if platform.system() != "Windows": + return False + + try: + import wmi + wmi_obj = wmi.WMI() + for item in wmi_obj.Win32_VideoController(): + name = item.Name or "" + # Match Intel Iris, UHD, Arc Graphics names + if any(intel_gfx in name for intel_gfx in ["Iris", "UHD Graphics", "Arc", "Intel Arc"]): + logger.info(f"Detected Intel iGPU: {name}") + return True + except (ImportError, Exception) as e: + logger.debug(f"Could not detect Iris iGPU via WMI: {e}") + # Fallback: just try DirectML and log what's available + pass + + return False + + def get_torch_device( *, allow_xpu: bool = False, @@ -92,6 +122,9 @@ def get_torch_device( allow_directml: Check for DirectML (Windows) support. allow_mps: Allow MPS (Apple Silicon). If False, MPS falls back to CPU. force_cpu_on_mac: Force CPU on macOS regardless of GPU availability. + + Priority: CUDA > XPU > DirectML > MPS > CPU + DirectML on Windows covers Intel iGPU (Iris/UHD), AMD iGPU, Arc discrete. """ if force_cpu_on_mac and platform.system() == "Darwin": return "cpu" @@ -106,6 +139,7 @@ def get_torch_device( import intel_extension_for_pytorch # noqa: F401 if hasattr(torch, "xpu") and torch.xpu.is_available(): + logger.info("Using Intel XPU device") return "xpu" except ImportError: pass @@ -114,15 +148,26 @@ def get_torch_device( try: import torch_directml - if torch_directml.device_count() > 0: - return torch_directml.device(0) + device_count = torch_directml.device_count() + if device_count > 0: + device = torch_directml.device(0) + iris_detected = _detect_iris_igpu() + if iris_detected: + logger.info("Using DirectML device (Intel Iris iGPU detected)") + else: + logger.info("Using DirectML device (Windows GPU acceleration via DirectML)") + return device except ImportError: - pass + logger.debug("torch_directml not installed, falling back to CPU or MPS") + except Exception as e: + logger.warning(f"DirectML initialization failed: {e}, falling back to CPU or MPS") if allow_mps: if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + logger.info("Using MPS (Apple Metal Performance Shaders)") return "mps" + logger.info("No GPU detected, using CPU") return "cpu" diff --git a/backend/backends/pytorch_backend.py b/backend/backends/pytorch_backend.py index f8ae79b8..28391f0a 100644 --- a/backend/backends/pytorch_backend.py +++ b/backend/backends/pytorch_backend.py @@ -33,7 +33,7 @@ def __init__(self, model_size: str = "1.7B"): self._current_model_size = None def _get_device(self) -> str: - """Get the best available device.""" + """Get the best available device (CUDA > XPU > DirectML > CPU).""" return get_torch_device(allow_xpu=True, allow_directml=True) def is_loaded(self) -> bool: @@ -255,7 +255,7 @@ def __init__(self, model_size: str = "base"): self.device = self._get_device() def _get_device(self) -> str: - """Get the best available device.""" + """Get the best available device (CUDA > XPU > DirectML > CPU).""" return get_torch_device(allow_xpu=True, allow_directml=True) def is_loaded(self) -> bool: diff --git a/backend/requirements.txt b/backend/requirements.txt index caafc0e7..d68a8003 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -14,6 +14,10 @@ accelerate>=0.26.0 huggingface_hub>=0.20.0 qwen-tts>=0.0.5 +# DirectML support for Windows Intel iGPU / integrated GPUs (Iris, UHD, Arc) +# Latest dev version (1.13.0 stable not yet released); use latest 0.2.x dev +torch-directml>=0.2.0 ; platform_system == "Windows" + # LuxTTS (voice cloning engine) # piper-phonemize needs custom index (no PyPI wheels) --find-links https://k2-fsa.github.io/icefall/piper_phonemize.html diff --git a/backend/tests/test_directml_iris.py b/backend/tests/test_directml_iris.py new file mode 100644 index 00000000..99fc6926 --- /dev/null +++ b/backend/tests/test_directml_iris.py @@ -0,0 +1,179 @@ +""" +Test DirectML device detection and Iris iGPU support on Windows. + +Run with: pytest backend/tests/test_directml_iris.py -v -s +""" + +import platform +import logging +import pytest + +logger = logging.getLogger(__name__) + + +@pytest.mark.skipif(platform.system() != "Windows", reason="DirectML tests only on Windows") +class TestDirectMLDetection: + """Test DirectML device availability and Iris iGPU detection.""" + + def test_directml_import(self): + """Test torch_directml can be imported.""" + try: + import torch_directml + assert torch_directml is not None + logger.info("✓ torch_directml imported successfully") + except ImportError as e: + pytest.skip(f"torch_directml not installed: {e}") + + def test_directml_device_count(self): + """Test DirectML detects at least one device.""" + try: + import torch_directml + device_count = torch_directml.device_count() + assert device_count > 0, f"DirectML device_count returned {device_count}, expected > 0" + logger.info(f"✓ DirectML detected {device_count} device(s)") + except ImportError: + pytest.skip("torch_directml not installed") + + def test_directml_device_creation(self): + """Test creating a DirectML device object.""" + try: + import torch_directml + if torch_directml.device_count() > 0: + device = torch_directml.device(0) + assert device is not None + logger.info(f"✓ DirectML device created: {device}") + except ImportError: + pytest.skip("torch_directml not installed") + + def test_get_torch_device_directml(self): + """Test get_torch_device returns DirectML on Windows with iGPU.""" + from ..backends.base import get_torch_device + import torch + + device = get_torch_device(allow_directml=True) + logger.info(f"Selected device: {device}") + + # On Windows with iGPU and torch_directml installed, should use DirectML + try: + import torch_directml + if torch_directml.device_count() > 0: + # Should be DirectML device, not CPU + assert str(device) != "cpu", f"Expected DirectML but got {device}" + logger.info(f"✓ DirectML device selected: {device}") + except ImportError: + logger.info("torch_directml not installed, may fall back to CPU") + + def test_iris_igpu_detection(self): + """Test Iris iGPU detection via WMI.""" + from ..backends.base import _detect_iris_igpu + + try: + import wmi + has_iris = _detect_iris_igpu() + logger.info(f"Iris iGPU detected: {has_iris}") + except ImportError: + logger.info("wmi module not available, skipping Iris detection test") + + +@pytest.mark.skipif(platform.system() != "Windows", reason="DirectML tests only on Windows") +class TestDirectMLTorchTensor: + """Test basic torch tensor operations on DirectML device.""" + + def test_torch_tensor_on_directml(self): + """Test creating and operating on tensors with DirectML.""" + try: + import torch + import torch_directml + + if torch_directml.device_count() == 0: + pytest.skip("No DirectML devices available") + + device = torch_directml.device(0) + x = torch.randn(3, 3, device=device) + y = torch.randn(3, 3, device=device) + z = torch.mm(x, y) + + assert z.shape == (3, 3) + logger.info(f"✓ Tensor operation successful on {device}") + logger.info(f" Result shape: {z.shape}") + except ImportError: + pytest.skip("torch_directml not installed") + + def test_directml_memory_management(self): + """Test DirectML memory can be freed properly.""" + try: + import torch + import torch_directml + + if torch_directml.device_count() == 0: + pytest.skip("No DirectML devices available") + + device = torch_directml.device(0) + # Create and delete tensors to check memory cleanup + for _ in range(5): + x = torch.randn(1000, 1000, device=device) + del x + + logger.info("✓ DirectML memory management OK") + except ImportError: + pytest.skip("torch_directml not installed") + + +@pytest.mark.skipif(platform.system() != "Windows", reason="Model tests only on Windows") +@pytest.mark.asyncio +class TestWhisperOnDirectML: + """Test Whisper (STT) model on DirectML device.""" + + async def test_whisper_model_loads_on_directml(self): + """Test Whisper model can load on DirectML.""" + try: + import torch_directml + if torch_directml.device_count() == 0: + pytest.skip("No DirectML devices available") + except ImportError: + pytest.skip("torch_directml not installed") + + from ..backends.pytorch_backend import PyTorchSTTBackend + + backend = PyTorchSTTBackend(model_size="base") + assert backend.device != "cpu", f"Expected GPU device, got {backend.device}" + logger.info(f"✓ Whisper backend using device: {backend.device}") + + # Try to load the model (this will download if needed) + try: + await backend.load_model_async("base") + assert backend.is_loaded() + logger.info("✓ Whisper model loaded successfully on DirectML") + backend.unload_model() + except (TimeoutError, ConnectionError, OSError) as e: + pytest.skip(f"Environment/network limitation during model load: {e}") + + +@pytest.mark.skipif(platform.system() != "Windows", reason="Model tests only on Windows") +@pytest.mark.asyncio +class TestQwenTTSOnDirectML: + """Test Qwen TTS model on DirectML device.""" + + async def test_qwen_tts_loads_on_directml(self): + """Test Qwen TTS model can load on DirectML.""" + try: + import torch_directml + if torch_directml.device_count() == 0: + pytest.skip("No DirectML devices available") + except ImportError: + pytest.skip("torch_directml not installed") + + from ..backends.pytorch_backend import PyTorchTTSBackend + + backend = PyTorchTTSBackend(model_size="0.6B") + assert backend.device != "cpu", f"Expected GPU device, got {backend.device}" + logger.info(f"✓ Qwen TTS backend using device: {backend.device}") + + # Try to load the model (this will download if needed) + try: + await backend.load_model_async("0.6B") + assert backend.is_loaded() + logger.info("✓ Qwen TTS model loaded successfully on DirectML") + backend.unload_model() + except (TimeoutError, ConnectionError, OSError) as e: + pytest.skip(f"Environment/network limitation during model load: {e}")