From 7984cd4caee187b6a1aaa97254c5c4147a6dd7b2 Mon Sep 17 00:00:00 2001 From: Will Anderson Date: Thu, 14 May 2026 11:52:20 -0500 Subject: [PATCH] Add Metal JIT warmup, cache inspect.signature, and add 'metal' health label --- backend/backends/mlx_backend.py | 60 +++++++++++++++++++-------------- backend/routes/health.py | 22 ++++++------ 2 files changed, 44 insertions(+), 38 deletions(-) diff --git a/backend/backends/mlx_backend.py b/backend/backends/mlx_backend.py index 9692e59b..28e1a733 100644 --- a/backend/backends/mlx_backend.py +++ b/backend/backends/mlx_backend.py @@ -2,24 +2,24 @@ MLX backend implementation for TTS and STT using mlx-audio. """ -from typing import Optional, List, Tuple import asyncio import logging -import numpy as np from pathlib import Path +import numpy as np + logger = logging.getLogger(__name__) # PATCH: Import and apply offline patch BEFORE any huggingface_hub usage # This prevents mlx_audio from making network requests when models are cached -from ..utils.hf_offline_patch import patch_huggingface_hub_offline, ensure_original_qwen_config_cached +from ..utils.hf_offline_patch import ensure_original_qwen_config_cached, patch_huggingface_hub_offline patch_huggingface_hub_offline() ensure_original_qwen_config_cached() -from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS -from .base import is_model_cached, combine_voice_prompts as _combine_voice_prompts, model_load_progress -from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt +from ..utils.cache import cache_voice_prompt, get_cache_key, get_cached_voice_prompt +from . import LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS +from .base import combine_voice_prompts as _combine_voice_prompts, is_model_cached, model_load_progress class MLXTTSBackend: @@ -63,7 +63,7 @@ def _is_model_cached(self, model_size: str) -> bool: weight_extensions=(".safetensors", ".bin", ".npz"), ) - async def load_model_async(self, model_size: Optional[str] = None): + async def load_model_async(self, model_size: str | None = None): """ Lazy load the MLX TTS model. @@ -100,6 +100,19 @@ def _load_model_sync(self, model_size: str): self.model = load(model_path) + import inspect + + self._supports_ref_audio = "ref_audio" in inspect.signature(self.model.generate).parameters + + # Warm up Metal JIT kernels — first inference compiles shaders, shift cost to load time + try: + logger.info("Warming up Metal kernels...") + for _ in self.model.generate("Hello.", lang_code="english"): + break # one token is enough to trigger compilation + logger.info("Metal warmup complete") + except Exception as e: + logger.warning("Warmup failed (non-fatal): %s", e) + self._current_model_size = model_size self.model_size = model_size logger.info("MLX TTS model %s loaded successfully", model_size) @@ -117,7 +130,7 @@ async def create_voice_prompt( audio_path: str, reference_text: str, use_cache: bool = True, - ) -> Tuple[dict, bool]: + ) -> tuple[dict, bool]: """ Create voice prompt from reference audio. @@ -145,9 +158,8 @@ async def create_voice_prompt( cached_audio_path = cached_prompt.get("ref_audio") or cached_prompt.get("ref_audio_path") if cached_audio_path and Path(cached_audio_path).exists(): return cached_prompt, True - else: - # Cached file no longer exists, invalidate cache - logger.warning("Cached audio file not found: %s, regenerating prompt", cached_audio_path) + # Cached file no longer exists, invalidate cache + logger.warning("Cached audio file not found: %s, regenerating prompt", cached_audio_path) # MLX voice prompt format - store audio path and text # The model will process this during generation @@ -171,9 +183,9 @@ async def generate( text: str, voice_prompt: dict, language: str = "en", - seed: Optional[int] = None, - instruct: Optional[str] = None, - ) -> Tuple[np.ndarray, int]: + seed: int | None = None, + instruct: str | None = None, + ) -> tuple[np.ndarray, int]: """ Generate audio from text using voice prompt. @@ -223,11 +235,8 @@ def _generate_sync(): # legitimate metadata calls during generation. try: if ref_audio: - # Check if generate accepts ref_audio parameter - import inspect - - sig = inspect.signature(self.model.generate) - if "ref_audio" in sig.parameters: + # Use cached capability flag set at model load time + if self._supports_ref_audio: # Generate with voice cloning for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang): audio_chunks.append(np.array(result.audio)) @@ -279,7 +288,7 @@ def _is_model_cached(self, model_size: str) -> bool: hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}") return is_model_cached(hf_repo, weight_extensions=(".safetensors", ".bin", ".npz")) - async def load_model_async(self, model_size: Optional[str] = None): + async def load_model_async(self, model_size: str | None = None): """ Lazy load the MLX Whisper model. @@ -324,8 +333,8 @@ def unload_model(self): async def transcribe( self, audio_path: str, - language: Optional[str] = None, - model_size: Optional[str] = None, + language: str | None = None, + model_size: str | None = None, ) -> str: """ Transcribe audio to text. @@ -356,12 +365,11 @@ def _transcribe_sync(): # Extract text from result if isinstance(result, str): return result.strip() - elif isinstance(result, dict): + if isinstance(result, dict): return result.get("text", "").strip() - elif hasattr(result, "text"): + if hasattr(result, "text"): return result.text.strip() - else: - return str(result).strip() + return str(result).strip() # Run blocking transcription in thread pool return await asyncio.to_thread(_transcribe_sync) diff --git a/backend/routes/health.py b/backend/routes/health.py index e5ad86ec..d39e7f9a 100644 --- a/backend/routes/health.py +++ b/backend/routes/health.py @@ -1,18 +1,17 @@ """Health and infrastructure endpoints.""" import asyncio +import contextlib import os import signal from pathlib import Path import torch -from fastapi import APIRouter, Depends +from fastapi import APIRouter from fastapi.responses import FileResponse -from sqlalchemy.orm import Session from .. import config, models from ..services import tts -from ..database import get_db from ..utils.platform_detect import get_backend_type router = APIRouter() @@ -40,7 +39,7 @@ async def shutdown_async(): await asyncio.sleep(0.1) os.kill(os.getpid(), signal.SIGTERM) - asyncio.create_task(shutdown_async()) + asyncio.create_task(shutdown_async()) # noqa: RUF006 — fire-and-forget shutdown return {"message": "Shutting down..."} @@ -56,9 +55,10 @@ async def watchdog_disable(): @router.get("/health", response_model=models.HealthResponse) async def health(): """Health check endpoint.""" - from huggingface_hub import constants as hf_constants from pathlib import Path + from huggingface_hub import constants as hf_constants + tts_model = tts.get_tts_model() backend_type = get_backend_type() @@ -117,10 +117,8 @@ async def health(): if has_cuda: vram_used = torch.cuda.memory_allocated() / 1024 / 1024 elif has_xpu: - try: + with contextlib.suppress(Exception): # memory_allocated() may not be available on all IPEX versions vram_used = torch.xpu.memory_allocated() / 1024 / 1024 - except Exception: - pass # memory_allocated() may not be available on all IPEX versions model_loaded = False model_size = None @@ -175,7 +173,9 @@ async def health(): backend_type=backend_type, backend_variant=os.environ.get( "VOICEBOX_BACKEND_VARIANT", - "cuda" if torch.cuda.is_available() else ("xpu" if has_xpu else "cpu"), + "cuda" + if torch.cuda.is_available() + else ("xpu" if has_xpu else ("metal" if backend_type == "mlx" else "cpu")), ), gpu_compatibility_warning=gpu_compat_warning, ) @@ -211,10 +211,8 @@ async def filesystem_health(): except OSError as e: error = str(e) finally: - try: + with contextlib.suppress(Exception): probe.unlink(missing_ok=True) - except Exception: - pass else: error = "Directory does not exist"