/** Map from engine name to the model_name used by the backend status API. */
const ENGINE_TO_MODEL_NAME: Partial<Record<string, string>> = {
  kokoro: 'kokoro',
  luxtts: 'luxtts',
  chatterbox: 'chatterbox-tts',
  chatterbox_turbo: 'chatterbox-turbo',
};

/**
 * The two fields of an ENGINE_OPTIONS entry that model-name derivation reads.
 * Declared structurally so the helpers below can be exercised without the
 * component; every ENGINE_OPTIONS entry is assignable to it.
 */
interface EngineOptionLike {
  readonly engine: string;
  readonly value: string;
}

/**
 * Derive the backend model_name string for an ENGINE_OPTIONS entry.
 *
 * - `value` of the form `<engine>:<size>` yields `<engine>-tts-<size>`.
 * - Otherwise the name comes from ENGINE_TO_MODEL_NAME, or `''` when the
 *   engine has no mapping.
 *
 * NOTE(review): for engines containing underscores the sized form keeps the
 * underscore (e.g. `qwen_custom_voice-tts-1.7B`) — confirm this matches the
 * backend's naming; `buildEngineCompatibility` falls back to a prefix match
 * when it does not.
 */
function deriveModelName(opt: EngineOptionLike): string {
  if (opt.value.includes(':')) {
    return `${opt.engine}-tts-${opt.value.split(':')[1]}`;
  }
  return ENGINE_TO_MODEL_NAME[opt.engine] ?? '';
}

/** Engines that only support English and should force language to 'en' on select. */
const ENGLISH_ONLY_ENGINES = new Set(['luxtts', 'chatterbox_turbo']);

/** Fields of a backend model-status entry that platform gating depends on. */
interface ModelPlatformStatus {
  readonly model_name: string;
  readonly platform_compatible: boolean;
  readonly requires?: readonly string[] | null;
}

/** Per-engine platform verdict consumed by the option renderer. */
interface EngineCompatibility {
  compatible: boolean;
  requires: string[];
}

/** Convert an engine id to the hyphenated form used in model names (all underscores). */
function engineSlug(engine: string): string {
  return engine.replace(/_/g, '-');
}

/**
 * Find the single engine a backend model belongs to.
 *
 * An exact match on the derived model name wins. Otherwise the model is
 * attributed to the engine whose hyphenated id is the LONGEST prefix ending
 * on a `-` boundary, so `chatterbox-turbo` resolves to `chatterbox_turbo`
 * and never to `chatterbox`.
 *
 * @returns the engine id, or `undefined` when no option matches.
 */
function resolveEngineForModel(
  modelName: string,
  options: readonly EngineOptionLike[],
): string | undefined {
  for (const opt of options) {
    const derived = deriveModelName(opt);
    if (derived !== '' && derived === modelName) {
      return opt.engine;
    }
  }

  let bestEngine: string | undefined;
  let bestLength = -1;
  for (const opt of options) {
    const slug = engineSlug(opt.engine);
    const matches = modelName === slug || modelName.startsWith(`${slug}-`);
    if (matches && slug.length > bestLength) {
      bestEngine = opt.engine;
      bestLength = slug.length;
    }
  }
  return bestEngine;
}

/**
 * Build the engine -> { compatible, requires } map from the model status list.
 *
 * Each model contributes to exactly one engine. The first model seen for an
 * engine sets its entry; any later platform-incompatible model of the same
 * engine overrides it, so one incompatible variant marks the engine as
 * incompatible.
 *
 * NOTE(review): the backend's `is_engine_platform_compatible` treats an engine
 * as compatible when at least ONE variant is; this map is stricter — confirm
 * which rule the UI should follow.
 *
 * Intended call site inside EngineModelSelector, after the `['modelStatus']`
 * query (`apiClient.getModelStatus()`, staleTime 5 minutes):
 *
 *   const engineCompatibility =
 *     buildEngineCompatibility(modelStatus?.models ?? [], ENGINE_OPTIONS);
 *
 * The renderer looks entries up by `opt.engine` and locks an option when
 * `!compatible && requires.length > 0`.
 *
 * @param models  entries from the model status response; may be empty.
 * @param options the selectable engine options (ENGINE_OPTIONS).
 * @returns a fresh record keyed by engine id; engines with no matching model are absent.
 */
function buildEngineCompatibility(
  models: readonly ModelPlatformStatus[],
  options: readonly EngineOptionLike[],
): Record<string, EngineCompatibility> {
  const byEngine: Record<string, EngineCompatibility> = {};
  for (const model of models) {
    const engine = resolveEngineForModel(model.model_name, options);
    if (engine === undefined) {
      continue; // model does not back any selectable engine
    }
    if (byEngine[engine] === undefined || !model.platform_compatible) {
      byEngine[engine] = {
        compatible: model.platform_compatible,
        requires: [...(model.requires ?? [])],
      };
    }
  }
  return byEngine;
}
+ + + + {opt.label} + + +
+
+ + Requires {requiresLabel} hardware + +
+ ); + } + + return ( + + {opt.label} + + ); + })}
); diff --git a/app/src/components/ServerSettings/ModelManagement.tsx b/app/src/components/ServerSettings/ModelManagement.tsx index b06783f9..748a31d7 100644 --- a/app/src/components/ServerSettings/ModelManagement.tsx +++ b/app/src/components/ServerSettings/ModelManagement.tsx @@ -11,6 +11,7 @@ import { HardDrive, Heart, Loader2, + Lock, RotateCcw, Scale, Trash2, @@ -552,6 +553,8 @@ export function ModelManagement() { ) : isDownloading ? ( + ) : !model.platform_compatible ? ( + ) : model.loaded ? ( ) : model.downloaded ? ( @@ -589,6 +592,16 @@ export function ModelManagement() { {t('common.error')} )} + {!model.platform_compatible && model.requires.length > 0 && ( + + + {model.requires.join('/')} + + )} {model.loaded && ( {t('models.status.loaded')} diff --git a/app/src/lib/api/types.ts b/app/src/lib/api/types.ts index 37ca4667..55832f76 100644 --- a/app/src/lib/api/types.ts +++ b/app/src/lib/api/types.ts @@ -310,6 +310,10 @@ export interface ModelStatus { downloading: boolean; // True if download is in progress size_mb?: number; loaded: boolean; + /** False when the engine has hardware requirements that this machine does not meet. */ + platform_compatible: boolean; + /** Hardware platforms required — mirrors ModelConfig.requires on the backend. */ + requires: string[]; } export interface HuggingFaceModelInfo { diff --git a/backend/backends/__init__.py b/backend/backends/__init__.py index 2437a87b..8b265828 100644 --- a/backend/backends/__init__.py +++ b/backend/backends/__init__.py @@ -11,6 +11,7 @@ # unconditional HuggingFace metadata call that otherwise raises on # HF_HUB_OFFLINE=1 and on network failures. 
def is_engine_platform_compatible(engine: str) -> bool:
    """Report whether this machine can run at least one variant of ``engine``.

    A variant qualifies when its ``requires`` list is empty, or when any
    platform it lists is among those returned by ``get_supported_platforms()``.

    Args:
        engine: Engine identifier matched against ``ModelConfig.engine``.

    Returns:
        True when some variant qualifies, and also when no TTS model config
        is registered for ``engine`` (unknown engines are not gated).
        False when every variant requires platforms this machine lacks.
    """
    variants = [cfg for cfg in get_tts_model_configs() if cfg.engine == engine]
    if not variants:
        # Unknown engine: nothing to gate on.
        return True

    available = get_supported_platforms()
    for variant in variants:
        if not variant.requires:
            return True
        if any(platform in available for platform in variant.requires):
            return True
    return False
+ configs = [c for c in get_tts_model_configs() if c.engine == engine] + if configs: + # Use per-variant requirements when model_size is known; fall back to first. + selected = next((c for c in configs if c.model_size == model_size), configs[0]) + requires = selected.requires + if requires: + supported = get_supported_platforms() + if not any(p in supported for p in requires): + raise HTTPException( + status_code=400, + detail=( + f"Engine '{engine}' requires one of: {requires}. " + f"This machine supports: {supported}. " + "Download a compatible engine from Settings → Models." + ), + ) + backend = get_tts_backend_for_engine(engine) if engine in ("qwen", "qwen_custom_voice"): await backend.load_model_async(model_size) diff --git a/backend/models.py b/backend/models.py index 06f321ac..07b7613e 100644 --- a/backend/models.py +++ b/backend/models.py @@ -474,6 +474,8 @@ class ModelStatus(BaseModel): downloading: bool = False # True if download is in progress size_mb: Optional[float] = None loaded: bool = False + platform_compatible: bool = True # False when requires != [] and current platform not in requires + requires: List[str] = [] # Hardware platform requirements (mirrors ModelConfig.requires) class ModelStatusListResponse(BaseModel): diff --git a/backend/routes/models.py b/backend/routes/models.py index 7cbb7b04..329002ca 100644 --- a/backend/routes/models.py +++ b/backend/routes/models.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Session from .. 
def get_supported_platforms() -> list[str]:
    """Return which compute platforms the current machine supports.

    Possible values: "cuda", "mps", "xpu", "rocm", "cpu"

    Rules:
    - "cpu" is always included and is always the last element.
    - "cuda" is added when ``torch.cuda.is_available()`` is true on a
      non-ROCm build.
    - "rocm" is added INSTEAD of "cuda" when ``torch.cuda.is_available()`` is
      true and ``torch.version.hip`` is set (ROCm builds report through the
      CUDA API).
    - "mps" is added when ``torch.backends.mps`` exists and reports available.
    - "xpu" is added when ``torch.xpu`` exists and reports available.

    Detection is best-effort: if torch cannot be imported (missing, or its
    native libraries fail to load) the result is ``["cpu"]``; if a single
    accelerator probe raises, that accelerator is treated as unavailable and
    the remaining probes still run. This function never raises.

    Returns:
        A new list on every call, e.g. ``["mps", "cpu"]``, ``["cuda", "cpu"]``,
        ``["rocm", "cpu"]``, ``["xpu", "cpu"]`` or ``["cpu"]``.
    """
    import logging

    log = logging.getLogger(__name__)

    try:
        import torch
    except (ImportError, OSError) as exc:
        # ImportError: torch is not installed. OSError: torch is installed but
        # a shared library it needs could not be loaded. Either way only CPU.
        log.debug("torch unavailable, reporting CPU only: %s", exc)
        return ["cpu"]

    def _is_available(label, probe) -> bool:
        """Run one accelerator probe; a probe that raises counts as unavailable."""
        try:
            return bool(probe())
        except Exception as exc:  # best-effort: a broken driver must not break callers
            log.debug("%s availability probe failed: %s", label, exc)
            return False

    supported: list[str] = []

    cuda = getattr(torch, "cuda", None)
    if cuda is not None and _is_available("cuda", cuda.is_available):
        # ROCm builds answer through torch.cuda; torch.version.hip tells them apart.
        is_rocm = getattr(getattr(torch, "version", None), "hip", None) is not None
        supported.append("rocm" if is_rocm else "cuda")

    mps = getattr(getattr(torch, "backends", None), "mps", None)
    if mps is not None and _is_available("mps", mps.is_available):
        supported.append("mps")

    xpu = getattr(torch, "xpu", None)
    if xpu is not None and _is_available("xpu", xpu.is_available):
        supported.append("xpu")

    supported.append("cpu")
    return supported