Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 104 additions & 5 deletions app/src/components/Generation/EngineModelSelector.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { useEffect } from 'react';
import { useQuery } from '@tanstack/react-query';
import { Lock } from 'lucide-react';
import type { UseFormReturn } from 'react-hook-form';
import { FormControl } from '@/components/ui/form';
import {
Expand All @@ -8,6 +10,8 @@ import {
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip';
import { apiClient } from '@/lib/api/client';
import type { VoiceProfileResponse } from '@/lib/api/types';
import { getLanguageOptionsForEngine } from '@/lib/constants/languages';
import type { GenerationFormValues } from '@/lib/hooks/useGenerationForm';
Expand Down Expand Up @@ -39,6 +43,22 @@ const ENGINE_DESCRIPTIONS: Record<string, string> = {
kokoro: '82M params, CPU realtime, 8 langs',
};

/** Map from engine name to the model_name used by the backend status API. */
const ENGINE_TO_MODEL_NAME: Partial<Record<string, string>> = {
kokoro: 'kokoro',
luxtts: 'luxtts',
chatterbox: 'chatterbox-tts',
chatterbox_turbo: 'chatterbox-turbo',
};

/** Derive the backend model_name string for an ENGINE_OPTIONS entry. */
function deriveModelName(opt: (typeof ENGINE_OPTIONS)[number]): string {
if (opt.value.includes(':')) {
return opt.engine + '-tts-' + opt.value.split(':')[1];
}
return ENGINE_TO_MODEL_NAME[opt.engine] ?? '';
}

/** Engines that only support English and should force language to 'en' on select. */
const ENGLISH_ONLY_ENGINES = new Set(['luxtts', 'chatterbox_turbo']);

Expand Down Expand Up @@ -119,6 +139,54 @@ export function EngineModelSelector({ form, compact, selectedProfile }: EngineMo
const selectValue = getSelectValue(engine, modelSize);
const availableOptions = getAvailableOptions(selectedProfile);

// Fetch model status to get platform_compatible per engine.
// staleTime is long — this rarely changes within a session.
const { data: modelStatus } = useQuery({
queryKey: ['modelStatus'],
queryFn: () => apiClient.getModelStatus(),
staleTime: 1000 * 60 * 5,
});

// Build a set of engine names that are platform-incompatible.
const incompatibleEngines = new Set<string>();
const engineRequires: Record<string, string[]> = {};
if (modelStatus?.models) {
for (const m of modelStatus.models) {
if (!m.platform_compatible && m.requires && m.requires.length > 0) {
incompatibleEngines.add(m.model_name.split('-')[0]); // coarse key
// Map by engine derived from model_name patterns
// We'll match by checking the option engine field below
}
}
// More precise: map from engine name to requires via model entries
for (const m of modelStatus.models) {
if (m.requires && m.requires.length > 0 && !m.platform_compatible) {
// Derive engine name from model — use the requires list keyed by display
// We rely on the fact that incompatible means requires is non-empty
// Store by hf_repo_id to later map to engine
}
}
}

// Build engine -> {compatible, requires} from model status
const engineCompatibility: Record<string, { compatible: boolean; requires: string[] }> = {};
if (modelStatus?.models) {
for (const m of modelStatus.models) {
// Derive the engine name from model_name by checking ENGINE_OPTIONS
for (const opt of ENGINE_OPTIONS) {
const optModelName = deriveModelName(opt);
if (m.model_name === optModelName || m.model_name.startsWith(opt.engine.replace('_', '-'))) {
if (!engineCompatibility[opt.engine] || !m.platform_compatible) {
engineCompatibility[opt.engine] = {
compatible: m.platform_compatible,
requires: m.requires ?? [],
};
}
}
}
}
}

const currentEngineAvailable = availableOptions.some((opt) => opt.value === selectValue);

useEffect(() => {
Expand All @@ -140,11 +208,42 @@ export function EngineModelSelector({ form, compact, selectedProfile }: EngineMo
</SelectTrigger>
</FormControl>
<SelectContent>
{availableOptions.map((opt) => (
<SelectItem key={opt.value} value={opt.value} className={itemClass}>
{opt.label}
</SelectItem>
))}
{availableOptions.map((opt) => {
const compat = engineCompatibility[opt.engine];
const isIncompatible = compat && !compat.compatible && compat.requires.length > 0;
const requiresLabel = compat?.requires.join('/') ?? '';

if (isIncompatible) {
return (
<Tooltip key={opt.value}>
<TooltipTrigger asChild>
<div>
<SelectItem
key={opt.value}
value={opt.value}
className={`${itemClass ?? ''} opacity-50 cursor-not-allowed`}
disabled
>
<span className="flex items-center gap-1.5">
<Lock className="h-3 w-3 shrink-0" />
{opt.label}
</span>
</SelectItem>
</div>
</TooltipTrigger>
<TooltipContent side="right">
Requires {requiresLabel} hardware
</TooltipContent>
</Tooltip>
);
}

return (
<SelectItem key={opt.value} value={opt.value} className={itemClass}>
{opt.label}
</SelectItem>
);
})}
</SelectContent>
</Select>
);
Expand Down
13 changes: 13 additions & 0 deletions app/src/components/ServerSettings/ModelManagement.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
HardDrive,
Heart,
Loader2,
Lock,
RotateCcw,
Scale,
Trash2,
Expand Down Expand Up @@ -552,6 +553,8 @@ export function ModelManagement() {
<CircleX className="h-4 w-4 text-destructive" />
) : isDownloading ? (
<Loader2 className="h-4 w-4 animate-spin text-muted-foreground" />
) : !model.platform_compatible ? (
<Lock className="h-4 w-4 text-muted-foreground/50" />
) : model.loaded ? (
<CircleCheck className="h-4 w-4 text-accent" />
) : model.downloaded ? (
Expand Down Expand Up @@ -589,6 +592,16 @@ export function ModelManagement() {
{t('common.error')}
</Badge>
)}
{!model.platform_compatible && model.requires.length > 0 && (
<Badge
variant="outline"
className="text-[10px] h-5 text-muted-foreground border-muted-foreground/30"
title={`Requires: ${model.requires.join(', ')}`}
>
<Lock className="h-2.5 w-2.5 mr-1" />
{model.requires.join('/')}
</Badge>
)}
{model.loaded && (
<Badge className="text-[10px] h-5 bg-accent/15 text-accent border-accent/30 hover:bg-accent/15">
{t('models.status.loaded')}
Expand Down
4 changes: 4 additions & 0 deletions app/src/lib/api/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@ export interface ModelStatus {
downloading: boolean; // True if download is in progress
size_mb?: number;
loaded: boolean;
/** False when the engine has hardware requirements that this machine does not meet. */
platform_compatible: boolean;
/** Hardware platforms required — mirrors ModelConfig.requires on the backend. */
requires: string[];
}

export interface HuggingFaceModelInfo {
Expand Down
48 changes: 46 additions & 2 deletions backend/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# unconditional HuggingFace metadata call that otherwise raises on
# HF_HUB_OFFLINE=1 and on network failures.
from ..utils import hf_offline_patch # noqa: F401
from ..utils.platform_detect import get_backend_type, get_supported_platforms

import threading
from dataclasses import dataclass, field
Expand All @@ -21,8 +22,6 @@
DEFAULT_LLM_MAX_TOKENS = 512
DEFAULT_LLM_TEMPERATURE = 0.7

from ..utils.platform_detect import get_backend_type

LANGUAGE_CODE_TO_NAME = {
"zh": "chinese",
"en": "english",
Expand Down Expand Up @@ -58,6 +57,14 @@ class ModelConfig:
needs_trim: bool = False
supports_instruct: bool = False
languages: list[str] = field(default_factory=lambda: ["en"])
requires: list[str] = field(default_factory=list)
"""Hardware platforms required to run this engine.

Values: "cuda", "mps", "xpu", "rocm", "cpu"
An empty list means the engine runs on all platforms.
A non-empty list means the engine ONLY runs on (any of) the listed
platforms — the UI will hide/disable it on incompatible hardware.
"""


@runtime_checkable
Expand Down Expand Up @@ -510,8 +517,45 @@ def engine_has_model_sizes(engine: str) -> bool:
return len(configs) > 1


def is_engine_platform_compatible(engine: str) -> bool:
"""Return True if the current machine can run at least one variant of the engine.

An engine with an empty ``requires`` list is always compatible.
Otherwise the machine must support at least one of the required platforms
for at least one model variant.
"""
configs = [c for c in get_tts_model_configs() if c.engine == engine]
if not configs:
return True # unknown engine — don't gate
supported = get_supported_platforms()
return any(
(not cfg.requires) or any(p in supported for p in cfg.requires)
for cfg in configs
)

Comment thread
coderabbitai[bot] marked this conversation as resolved.

async def load_engine_model(engine: str, model_size: str = "default") -> None:
"""Load a model for the given engine, handling engines with multiple model sizes."""
from fastapi import HTTPException

# Hard guard — refuse to load an engine on an incompatible platform.
configs = [c for c in get_tts_model_configs() if c.engine == engine]
if configs:
# Use per-variant requirements when model_size is known; fall back to first.
selected = next((c for c in configs if c.model_size == model_size), configs[0])
requires = selected.requires
if requires:
supported = get_supported_platforms()
if not any(p in supported for p in requires):
raise HTTPException(
status_code=400,
detail=(
f"Engine '{engine}' requires one of: {requires}. "
f"This machine supports: {supported}. "
"Download a compatible engine from Settings → Models."
),
)

backend = get_tts_backend_for_engine(engine)
if engine in ("qwen", "qwen_custom_voice"):
await backend.load_model_async(model_size)
Expand Down
2 changes: 2 additions & 0 deletions backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,8 @@ class ModelStatus(BaseModel):
downloading: bool = False # True if download is in progress
size_mb: Optional[float] = None
loaded: bool = False
platform_compatible: bool = True # False when requires != [] and current platform not in requires
requires: List[str] = [] # Hardware platform requirements (mirrors ModelConfig.requires)


class ModelStatusListResponse(BaseModel):
Expand Down
11 changes: 10 additions & 1 deletion backend/routes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sqlalchemy.orm import Session

from .. import models
from ..utils.platform_detect import get_backend_type
from ..utils.platform_detect import get_backend_type, get_supported_platforms
from ..services.task_queue import create_background_task
from ..utils.progress import get_progress_manager
from ..utils.tasks import get_task_manager
Expand Down Expand Up @@ -242,13 +242,18 @@ async def get_model_status():

from ..backends import get_all_model_configs, check_model_loaded

supported_platforms = get_supported_platforms()
registry_configs = get_all_model_configs()
model_configs = [
{
"model_name": cfg.model_name,
"display_name": cfg.display_name,
"hf_repo_id": cfg.hf_repo_id,
"model_size": cfg.model_size,
"requires": cfg.requires,
"platform_compatible": (
not cfg.requires or any(p in supported_platforms for p in cfg.requires)
),
"check_loaded": lambda c=cfg: check_model_loaded(c),
}
for cfg in registry_configs
Expand Down Expand Up @@ -359,6 +364,8 @@ async def get_model_status():
downloading=is_downloading,
size_mb=size_mb,
loaded=loaded,
platform_compatible=config["platform_compatible"],
requires=config["requires"],
)
)
except Exception:
Expand All @@ -378,6 +385,8 @@ async def get_model_status():
downloading=is_downloading,
size_mb=None,
loaded=loaded,
platform_compatible=config["platform_compatible"],
requires=config["requires"],
)
)

Expand Down
45 changes: 44 additions & 1 deletion backend/utils/platform_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def is_apple_silicon() -> bool:
"""
Check if running on Apple Silicon (arm64 macOS).

Returns:
True if on Apple Silicon, False otherwise
"""
Expand All @@ -33,3 +33,46 @@ def get_backend_type() -> Literal["mlx", "pytorch"]:
# Fall through to PyTorch.
return "pytorch"
return "pytorch"


def get_supported_platforms() -> list[str]:
"""Return which compute platforms the current machine supports.

Possible values: "cuda", "mps", "xpu", "rocm", "cpu"

Rules:
- "cpu" is always included (every machine can run CPU inference).
- "cuda" is added when PyTorch reports a CUDA device available.
- "rocm" is added on ROCm builds (torch.version.hip is set).
- "mps" is added when the Metal Performance Shaders backend is available.
- "xpu" is added when Intel Extension for PyTorch detects an Arc/XPU device.

Apple Silicon machines therefore return ["mps", "cpu"], a typical
CUDA Linux machine returns ["cuda", "cpu"], an Intel Arc machine returns
["xpu", "cpu"], and a CPU-only machine returns ["cpu"].
"""
supported: list[str] = []

try:
import torch

if torch.cuda.is_available():
# Distinguish ROCm from CUDA — both report via cuda.is_available()
# on the ROCm PyTorch build, but torch.version.hip is non-None.
is_rocm = hasattr(torch.version, "hip") and torch.version.hip is not None
if is_rocm:
supported.append("rocm")
else:
supported.append("cuda")

if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
supported.append("mps")

if hasattr(torch, "xpu") and torch.xpu.is_available():
supported.append("xpu")

except ImportError:
pass # torch not available at all — only CPU

supported.append("cpu")
return supported