From 3784de4db90f59f6ea8a6370ca60f1d55bbed857 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Mon, 11 May 2026 11:54:56 +0800 Subject: [PATCH 1/5] import tts model --- backend/apps/voice_app.py | 113 +- backend/consts/model.py | 13 +- backend/services/config_sync_service.py | 4 +- backend/services/model_health_service.py | 28 +- backend/services/voice_service.py | 394 ++++- doc/docs/zh/user-guide/model-management.md | 4 +- .../components/model/ModelAddDialog.tsx | 216 ++- .../components/model/ModelDeleteDialog.tsx | 16 +- .../components/model/ModelEditDialog.tsx | 124 +- frontend/hooks/useConfig.ts | 3 + frontend/types/modelConfig.ts | 3 +- sdk/nexent/core/models/__init__.py | 9 + sdk/nexent/core/models/ali_tts_model.py | 591 +++++++ sdk/nexent/core/models/message_utils.py | 5 +- sdk/nexent/core/models/tts_model.py | 107 ++ sdk/nexent/core/models/volc_tts_model.py | 167 ++ test/backend/services/test_voice_service.py | 518 +++++- .../services/test_voice_service_tts.py | 682 ++++++++ test/conftest.py | 14 +- test/sdk/core/models/test_ali_tts_model.py | 1401 +++++++++++++++++ test/sdk/core/models/test_tts_model.py | 201 +++ test/sdk/core/models/test_volc_tts_model.py | 894 +++++++++++ 22 files changed, 5391 insertions(+), 116 deletions(-) create mode 100644 sdk/nexent/core/models/ali_tts_model.py create mode 100644 sdk/nexent/core/models/tts_model.py create mode 100644 sdk/nexent/core/models/volc_tts_model.py create mode 100644 test/backend/services/test_voice_service_tts.py create mode 100644 test/sdk/core/models/test_ali_tts_model.py create mode 100644 test/sdk/core/models/test_tts_model.py create mode 100644 test/sdk/core/models/test_volc_tts_model.py diff --git a/backend/apps/voice_app.py b/backend/apps/voice_app.py index 7451a95c4..cc1b37e87 100644 --- a/backend/apps/voice_app.py +++ b/backend/apps/voice_app.py @@ -1,3 +1,4 @@ +import asyncio import logging from http import HTTPStatus @@ -7,6 +8,8 @@ from consts.exceptions import ( VoiceServiceException, STTConnectionException, + TTSConnectionException, + VoiceConfigException, ) from consts.model import VoiceConnectivityRequest, VoiceConnectivityResponse from services.voice_service import get_voice_service @@ -56,12 +59,97 @@ async def stt_websocket(websocket: WebSocket): logger.info("STT WebSocket connection closed") +@voice_runtime_router.websocket("/tts/ws") +async def tts_websocket(websocket: WebSocket): + """WebSocket endpoint for streaming TTS""" + logger.info("TTS WebSocket connection attempt...") + await websocket.accept() + logger.info("TTS WebSocket connection accepted") + + try: + # Receive config and text from client + msg = await websocket.receive() + client_config = {} + text = None + + if msg["type"] == "websocket.receive": + if "text" in msg: + import json + client_config = json.loads(msg["text"]) + text = client_config.get("text") + elif "bytes" in msg: + try: + import json + client_config = json.loads(msg["bytes"].decode('utf-8')) + text = client_config.get("text") + except Exception as e: + logger.warning(f"Failed to parse bytes as JSON: {e}") + + if not text: + if websocket.client_state.name == "CONNECTED": + await websocket.send_json({"error": "No text provided"}) + return + + # Extract config from client + tenant_id = client_config.get("tenant_id") + model_factory = client_config.get("model_factory") + model_name = client_config.get("model_name") + api_key = client_config.get("api_key") + model_appid = client_config.get("model_appid") + access_token = client_config.get("access_token") + base_url = client_config.get("base_url") + + logger.info(f"TTS request - model_name: {model_name}, model_factory: {model_factory}, " + f"has_api_key: {bool(api_key)}") + + # Build tts_config dict for voice service + tts_config = { + "model_factory": model_factory, + "api_key": api_key, + "model_appid": model_appid, + "access_token": access_token, + "base_url": base_url, + "model_name": model_name, + } + + # Stream TTS audio to WebSocket + voice_service = get_voice_service() + await voice_service.stream_tts_to_websocket( + websocket, + text, + tenant_id=tenant_id, + model_name=model_name, + tts_config=tts_config + ) + + except TTSConnectionException as e: + logger.error(f"TTS WebSocket error: {str(e)}") + await websocket.send_json({"error": str(e)}) + except Exception as e: + logger.error(f"TTS WebSocket error: {str(e)}") + await websocket.send_json({"error": str(e)}) + finally: + logger.info("TTS WebSocket connection closed") + # Ensure connection is properly closed + if websocket.client_state.name == "CONNECTED": + await websocket.close() + + @voice_config_router.post("/connectivity") async def check_voice_connectivity(request: VoiceConnectivityRequest): - """Check voice service connectivity.""" + """ + Check voice service connectivity + + Args: + request: VoiceConnectivityRequest containing model_type + + Returns: + VoiceConnectivityResponse with connectivity status + """ try: voice_service = get_voice_service() connected = await voice_service.check_voice_connectivity(request.model_type) + return JSONResponse( status_code=HTTPStatus.OK, content=VoiceConnectivityResponse( @@ -72,10 +160,25 @@ async def check_voice_connectivity(request: VoiceConnectivityRequest): ) except VoiceServiceException as e: logger.error(f"Voice service error: {str(e)}") - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) - except STTConnectionException as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=str(e) + ) + except (STTConnectionException, TTSConnectionException) as e: logger.error(f"Voice connectivity error: {str(e)}") - raise HTTPException(status_code=HTTPStatus.SERVICE_UNAVAILABLE, detail=str(e)) + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE, + detail=str(e) + ) + except VoiceConfigException as e: + logger.error(f"Voice configuration error: {str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=str(e) + ) except Exception as e: logger.error(f"Unexpected voice service error: {str(e)}") - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Voice service error") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Voice service error" + ) diff --git a/backend/consts/model.py b/backend/consts/model.py index 2f1d7aae3..95ebb1b80 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -160,6 +160,16 @@ class STTModelConfig(BaseModel): accessToken: Optional[str] = None +class TTSModelConfig(BaseModel): + """TTS model specific configuration with factory, appid, and access token fields""" + modelName: str + displayName: str + apiConfig: Optional[ModelApiConfig] = None + modelFactory: Optional[str] = None + modelAppid: Optional[str] = None + accessToken: Optional[str] = None + + class ModelConfig(BaseModel): llm: SingleModelConfig embedding: SingleModelConfig @@ -167,6 +177,7 @@ class ModelConfig(BaseModel): rerank: SingleModelConfig vlm: SingleModelConfig stt: STTModelConfig + tts: TTSModelConfig class AppConfig(BaseModel): @@ -504,7 +515,7 @@ def default(cls) -> "MemoryAgentShareMode": class VoiceConnectivityRequest(BaseModel): """Request model for voice service connectivity check""" model_type: str = Field(..., - description="Type of model to check ('stt')") + description="Type of model to check ('stt' or 'tts')") class VoiceConnectivityResponse(BaseModel): diff --git a/backend/services/config_sync_service.py b/backend/services/config_sync_service.py index 0ed29bfc5..28b77c6d8 100644 --- a/backend/services/config_sync_service.py +++ b/backend/services/config_sync_service.py @@ -202,9 +202,9 @@ def build_model_config(model_config: dict) -> dict: if "embedding" in model_config.get("model_type", ""): config["dimension"] = model_config.get("max_tokens", 0) - # Add STT model specific fields + # Add voice model specific fields (STT and TTS) model_type = model_config.get("model_type", "") - if model_type == "stt": + if model_type == "stt" or model_type == "tts": config["modelFactory"] = model_config.get("model_factory", "") config["modelAppid"] = model_config.get("model_appid", "") config["accessToken"] = model_config.get("access_token", "") diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index a20b2a6ca..6283c3359 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -139,7 +139,6 @@ async def _perform_connectivity_check( elif model_type == 'stt': voice_service = get_voice_service() - # Determine STT provider based on model_factory use_volc = model_factory and model_factory.lower() in ["volcengine", "volcano", "volcengine", "火山引擎"] @@ -164,6 +163,33 @@ async def _perform_connectivity_check( "model": model_name } ) + elif model_type == 'tts': + voice_service = get_voice_service() + + # Determine TTS provider based on model_factory + use_volc = model_factory and model_factory.lower() in ["volcengine", "volcano", "volcengine", "火山引擎"] + + if use_volc: + # Use Volcano TTS with appid and access_token + connectivity = await voice_service.check_voice_connectivity( + model_type="tts", + stt_config={ + "model_factory": model_factory, + "model_appid": model_appid, + "access_token": access_token, + "base_url": model_base_url + } + ) + else: + # Use Ali TTS (default) with api_key and model name + connectivity = await voice_service.check_voice_connectivity( + model_type="tts", + stt_config={ + "api_key": model_api_key, + "base_url": model_base_url, + "model": model_name + } + ) else: raise ValueError(f"Unsupported model type: {model_type}") diff --git a/backend/services/voice_service.py b/backend/services/voice_service.py index 80d6264db..7d274ff23 100644 --- a/backend/services/voice_service.py +++ b/backend/services/voice_service.py @@ -1,14 +1,19 @@ +import asyncio import logging from typing import Any, Dict, Optional from nexent.core.models.stt_model import BaseSTTModel +from nexent.core.models.tts_model import BaseTTSModel from nexent.core.models.volc_stt_model import VolcSTTConfig, VolcSTTModel from nexent.core.models.ali_stt_model import AliSTTConfig, AliSTTModel +from nexent.core.models.volc_tts_model import VolcTTSConfig, VolcTTSModel +from nexent.core.models.ali_tts_model import AliTTSConfig, AliTTSModel -from consts.const import TEST_PCM_PATH +from consts.const import TEST_VOICE_PATH, TEST_PCM_PATH from consts.exceptions import ( VoiceServiceException, STTConnectionException, + TTSConnectionException, ) from database.model_management_db import get_model_records from utils.config_utils import tenant_config_manager @@ -17,7 +22,7 @@ class VoiceService: - """Voice service that handles STT operations""" + """Voice service that handles STT and TTS operations""" def _get_stt_model_from_config( self, @@ -44,9 +49,11 @@ def _get_stt_model_from_config( Returns: STT model instance based on configuration """ + # Default to Ali Cloud if model_factory is not specified or is dashscope use_volc = model_factory and model_factory.lower() in ["volc", "volcano", "volcengine", "火山引擎"] if use_volc: + # Use Volcano Engine STT volc_config = VolcSTTConfig( appid=model_appid or "", access_token=access_token or "", @@ -56,6 +63,7 @@ def _get_stt_model_from_config( ) return VolcSTTModel(volc_config, TEST_PCM_PATH) else: + # Use Ali Cloud STT (default) ali_config = AliSTTConfig( api_key=api_key or "", model=model_name or "qwen3-asr-flash-realtime", @@ -84,6 +92,7 @@ def _get_stt_model_from_tenant_config( STT model instance based on tenant's configuration """ try: + # Get STT model configuration from tenant config stt_config = tenant_config_manager.get_model_config(tenant_id, "stt") if stt_config: @@ -104,6 +113,7 @@ def _get_stt_model_from_tenant_config( language=language ) + # Try to get from model records in database model_records = get_model_records({"model_type": "stt"}, tenant_id) if model_records: record = model_records[0] @@ -131,6 +141,114 @@ def _get_stt_model_from_tenant_config( logger.error(f"Error getting STT model config for tenant {tenant_id}: {str(e)}") return self._get_stt_model_from_config(language=language) + def _get_tts_model_from_config( + self, + model_factory: Optional[str] = None, + api_key: Optional[str] = None, + model_appid: Optional[str] = None, + access_token: Optional[str] = None, + speed_ratio: float = 1.0, + base_url: Optional[str] = None, + model: Optional[str] = None + ) -> BaseTTSModel: + """ + Get the appropriate TTS model based on model factory configuration. + + Args: + model_factory: Model factory/vendor name + api_key: API key (for Ali TTS) + model_appid: Application ID (for Volcano TTS) + access_token: Access token (for Volcano TTS) + speed_ratio: Speech speed ratio + base_url: Custom WebSocket URL (optional) + model: Model name (for Ali TTS) + + Returns: + TTS model instance based on configuration + """ + use_volc = model_factory and model_factory.lower() in ["volc", "volcano", "volcengine", "火山引擎"] + + if use_volc: + volc_config = VolcTTSConfig( + appid=model_appid or "", + token=access_token or "", + speed_ratio=speed_ratio, + ws_url=base_url or None, + ) + return VolcTTSModel(volc_config) + else: + ali_config = AliTTSConfig( + api_key=api_key or "", + model=model or "qwen3-tts-flash", + voice="Cherry", + speech_rate=speed_ratio, + ws_url=base_url if base_url else None + ) + return AliTTSModel(ali_config) + + def _get_tts_model_from_tenant_config( + self, + tenant_id: str + ) -> BaseTTSModel: + """ + Get TTS model based on tenant's model configuration. + + Args: + tenant_id: Tenant ID + + Returns: + TTS model instance based on tenant's configuration + """ + try: + tts_config = tenant_config_manager.get_model_config(tenant_id, "tts") + + if tts_config: + model_factory = tts_config.get("model_factory", "") + api_key = tts_config.get("api_key", "") + model_appid = tts_config.get("model_appid", "") + access_token_val = tts_config.get("access_token", "") + speed_ratio = float(tts_config.get("speed_ratio", 1.0)) + base_url = tts_config.get("base_url", "") + model = tts_config.get("model") or tts_config.get("model_name", "") + + return self._get_tts_model_from_config( + model_factory=model_factory, + api_key=api_key, + model_appid=model_appid, + access_token=access_token_val, + speed_ratio=speed_ratio, + base_url=base_url if base_url else None, + model=model if model else None + ) + + model_records = get_model_records({"model_type": "tts"}, tenant_id) + if model_records: + record = model_records[0] + model_factory = record.get("model_factory", "") + api_key = record.get("api_key", "") + model_appid = record.get("model_appid", "") + access_token_val = record.get("access_token", "") + speed_ratio = float(record.get("speed_ratio", 1.0)) + base_url = record.get("base_url", "") + model = record.get("model_name", "") + + return self._get_tts_model_from_config( + model_factory=model_factory, + api_key=api_key, + model_appid=model_appid, + access_token=access_token_val, + speed_ratio=speed_ratio, + base_url=base_url if base_url else None, + model=model if model else None + ) + + logger.warning(f"No TTS model configuration found for tenant {tenant_id}, using default config") + return self._get_tts_model_from_config() + + except Exception as e: + logger.error(f"Error getting TTS model config for tenant {tenant_id}: {str(e)}") + return self._get_tts_model_from_config() + async def start_stt_streaming_session( self, websocket, @@ -169,6 +287,7 @@ async def start_stt_streaming_session( else: logger.warning("No stt_config provided, will use tenant model config if available") + # Get STT model based on configuration if model_factory or api_key or model_appid: stt_model = self._get_stt_model_from_config( model_factory=model_factory, @@ -193,6 +312,153 @@ async def start_stt_streaming_session( logger.error(f"STT streaming session failed: {str(e)}") raise STTConnectionException(f"STT streaming failed: {str(e)}") from e + async def generate_tts_speech( + self, + text: str, + stream: bool = True, + tts_config: Optional[Dict[str, Any]] = None, + tenant_id: Optional[str] = None, + model_name_override: Optional[str] = None + ) -> Any: + """ + Generate TTS speech from text + + Args: + text: Text to convert to speech + stream: Whether to stream the audio or return complete audio + tts_config: TTS configuration dict from client (preferred) + tenant_id: Tenant ID for model lookup + model_name_override: Model name override + + Returns: + Audio data (streaming or complete) + + Raises: + TTSConnectionException: If TTS generation fails + """ + if not text: + raise VoiceServiceException("No text provided for TTS generation") + + try: + logger.info(f"Generating TTS speech for text: {text[:50]}...") + + model_factory = None + api_key = None + model_appid = None + access_token = None + speed_ratio = 1.0 + base_url = None + model_name = None + + if tts_config: + model_factory = tts_config.get("model_factory") + api_key = tts_config.get("api_key") or tts_config.get("apiKey") + model_appid = tts_config.get("model_appid") or tts_config.get("appid") + access_token = tts_config.get("access_token") + speed_ratio = float(tts_config.get("speed_ratio", 1.0)) + base_url = tts_config.get("base_url") or tts_config.get("baseUrl") + model_name = tts_config.get("model") or tts_config.get("model_name") + + # If model_name is provided directly, use it + effective_model = model_name_override or model_name + logger.info(f"TTS config - api_key: {bool(api_key)}, model_name_override: {model_name_override}, " + f"model_name from config: {model_name}, effective_model: {effective_model}") + + + # Determine model factory and create appropriate TTS model + use_volc = model_factory and model_factory.lower() in ["volc", "volcano", "volcengine", "火山引擎"] + + if use_volc: + # Use Volcano TTS + tts_model = self._get_tts_model_from_config( + model_factory=model_factory, + api_key=api_key, + model_appid=model_appid, + access_token=access_token, + speed_ratio=speed_ratio, + base_url=base_url, + model=effective_model + ) + logger.info(f"TTS model created: Volcano TTS (factory={model_factory})") + elif api_key: + # Use Ali TTS with provided api_key + tts_model = self._get_tts_model_from_config( + model_factory=model_factory, + api_key=api_key, + model_appid=model_appid, + access_token=access_token, + speed_ratio=speed_ratio, + base_url=base_url, + model=effective_model + ) + logger.info(f"TTS model created: Ali TTS (api_key provided)") + elif tenant_id: + tts_model = self._get_tts_model_from_tenant_config(tenant_id) + logger.info(f"TTS model created from tenant config for tenant_id={tenant_id}") + else: + logger.warning("No api_key, model_name, or tenant_id provided, using default TTS model") + tts_model = self._get_tts_model_from_config() + + speech_result = await tts_model.generate_speech(text, stream=stream) + return speech_result + except Exception as e: + logger.error(f"TTS generation failed: {str(e)}") + raise TTSConnectionException(f"TTS generation failed: {str(e)}") from e + + async def stream_tts_to_websocket( + self, + websocket, + text: str, + tenant_id: Optional[str] = None, + model_name: Optional[str] = None, + tts_config: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Stream TTS audio to WebSocket with proper error handling and fallback + + Args: + websocket: WebSocket connection to stream to + text: Text to convert to speech + tenant_id: Optional tenant ID for model selection + model_name: Optional model name override + tts_config: Optional TTS configuration dict with model_factory, api_key, model_appid, access_token, base_url + + Raises: + TTSConnectionException: If TTS service connection fails + VoiceServiceException: If TTS streaming fails + """ + speech_result = await self.generate_tts_speech( + text, + stream=True, + tenant_id=tenant_id, + model_name_override=model_name, + tts_config=tts_config + ) + + # Check if it's an async iterator or a regular iterable + if hasattr(speech_result, '__aiter__'): + # It's an async iterator, use async for + async for chunk in speech_result: + if websocket.client_state.name == "CONNECTED": + await websocket.send_bytes(chunk) + else: + break + elif hasattr(speech_result, '__iter__'): + # It's a regular iterator, use normal for + for chunk in speech_result: + if websocket.client_state.name == "CONNECTED": + await websocket.send_bytes(chunk) + else: + break + else: + # It's a single chunk, send it directly + if websocket.client_state.name == "CONNECTED": + await websocket.send_bytes(speech_result) + + # Send end marker after successful TTS generation + if websocket.client_state.name == "CONNECTED": + await websocket.send_json({"status": "completed"}) + async def check_stt_connectivity( self, model_factory: Optional[str] = None, @@ -222,6 +488,7 @@ async def check_stt_connectivity( STTConnectionException: If connectivity check fails """ try: + # Get STT model based on factory stt_model = self._get_stt_model_from_config( model_factory=model_factory, model_name=model, @@ -232,6 +499,7 @@ async def check_stt_connectivity( language=language ) + connected = await stt_model.check_connectivity() if not connected: @@ -244,6 +512,54 @@ async def check_stt_connectivity( logger.error(f"STT connectivity check failed: {str(e)}") raise STTConnectionException(f"STT connectivity check failed: {str(e)}") from e + async def check_tts_connectivity( + self, + model_factory: Optional[str] = None, + api_key: Optional[str] = None, + model_appid: Optional[str] = None, + access_token: Optional[str] = None, + speed_ratio: float = 1.0, + base_url: Optional[str] = None, + model: Optional[str] = None + ) -> bool: + """ + Check TTS service connectivity. + + Args: + model_factory: Model factory/vendor name (e.g., "volc", "dashscope") + api_key: API key for Ali TTS + model_appid: Application ID for Volcano TTS + access_token: Access token for Volcano TTS + speed_ratio: Speech speed ratio + base_url: Custom WebSocket URL (optional) + model: Model name (e.g., "qwen3-tts-flash") + + Returns: + bool: True if TTS service is connected, False otherwise + + Raises: + TTSConnectionException: If connectivity check fails + """ + try: + tts_model = self._get_tts_model_from_config( + model_factory=model_factory, + api_key=api_key, + model_appid=model_appid, + access_token=access_token, + speed_ratio=speed_ratio, + base_url=base_url, + model=model + ) + + connected = await tts_model.check_connectivity() + if not connected: + logger.warning("TTS service connectivity check returned False") + return False + return connected + except Exception as e: + logger.error(f"TTS connectivity check failed: {str(e)}") + return False + async def check_voice_connectivity( self, model_type: str, @@ -253,39 +569,58 @@ async def check_voice_connectivity( Check voice service connectivity based on model type. Args: - model_type: Type of model to check ('stt' only) + model_type: Type of model to check ('stt' or 'tts') stt_config: Optional STT configuration dict Returns: - bool: True if the service is connected, False otherwise + bool: True if the specified service is connected, False otherwise Raises: VoiceServiceException: If model_type is invalid STTConnectionException: If STT connectivity check fails + TTSConnectionException: If TTS connectivity check fails """ - if model_type != "stt": - logger.error(f"Unsupported model type: {model_type}") - raise VoiceServiceException(f"Unsupported model type: {model_type}") - try: - model_factory = stt_config.get("model_factory") if stt_config else None - api_key = stt_config.get("api_key") if stt_config else None - model_appid = stt_config.get("model_appid") if stt_config else None - access_token = stt_config.get("access_token") if stt_config else None - language = stt_config.get("language", "zh") if stt_config else "zh" - model = stt_config.get("model", "qwen3-asr-flash-realtime") if stt_config else "qwen3-asr-flash-realtime" - base_url = stt_config.get("base_url") if stt_config else None - - return await self.check_stt_connectivity( - model_factory=model_factory, - api_key=api_key, - model_appid=model_appid, - access_token=access_token, - language=language, - model=model, - base_url=base_url - ) - except STTConnectionException: + if model_type == 'stt': + model_factory = stt_config.get("model_factory") if stt_config else None + api_key = stt_config.get("api_key") if stt_config else None + model_appid = stt_config.get("model_appid") if stt_config else None + access_token = stt_config.get("access_token") if stt_config else None + language = stt_config.get("language", "zh") if stt_config else "zh" + model = stt_config.get("model", "qwen3-asr-flash-realtime") if stt_config else "qwen3-asr-flash-realtime" + base_url = stt_config.get("base_url") if stt_config else None + + return await self.check_stt_connectivity( + model_factory=model_factory, + api_key=api_key, + model_appid=model_appid, + access_token=access_token, + language=language, + model=model, + base_url=base_url + ) + elif model_type == 'tts': + model_factory = stt_config.get("model_factory") if stt_config else None + api_key = stt_config.get("api_key") if stt_config else None + model_appid = stt_config.get("model_appid") if stt_config else None + access_token = stt_config.get("access_token") if stt_config else None + speed_ratio = float(stt_config.get("speed_ratio", 1.0)) if stt_config else 1.0 + base_url = stt_config.get("base_url") if stt_config else None + model = stt_config.get("model", "qwen3-tts-flash") if stt_config else "qwen3-tts-flash" + + return await self.check_tts_connectivity( + model_factory=model_factory, + api_key=api_key, + model_appid=model_appid, + access_token=access_token, + speed_ratio=speed_ratio, + base_url=base_url, + model=model + ) + else: + logger.error(f"Unknown model type: {model_type}") + raise VoiceServiceException(f"Unknown model type: {model_type}") + except (STTConnectionException, TTSConnectionException): raise except Exception as e: logger.error(f"Voice service connectivity check failed: {str(e)}") @@ -297,7 +632,12 @@ async def check_voice_connectivity( def get_voice_service() -> VoiceService: - """Get the global voice service instance.""" + """ + Get the global voice service instance + + Returns: + VoiceService: The global voice service instance + """ global _voice_service_instance if _voice_service_instance is None: _voice_service_instance = VoiceService() diff --git a/doc/docs/zh/user-guide/model-management.md b/doc/docs/zh/user-guide/model-management.md index c8f07c0c3..6870f5544 100644 --- a/doc/docs/zh/user-guide/model-management.md +++ b/doc/docs/zh/user-guide/model-management.md @@ -238,7 +238,7 @@ Nexent 支持任何 **遵循OpenAI API规范** 的大语言模型供应商,包 - **网站**: [volcengine.com/product/voice-tech](https://www.volcengine.com/product/voice-tech) - **免费额度**: 个人使用可用 - **特色**: 高质量中英文语音合成 - +- 推荐使用**豆包语音合成模型2.0和大模型流式语音识别模型** - **开始使用**: 1. 注册火山引擎账户 @@ -248,7 +248,7 @@ Nexent 支持任何 **遵循OpenAI API规范** 的大语言模型供应商,包 **阿里灵积** - **网站**: [aliyun.com/benefit/scene/voice](https://www.aliyun.com/benefit/scene/voice) - +- 推荐使用**千问3-TTS-Instruct-Flash-Realtime/千问3-TTS-Flash-Realtime和千问3-ASR-Flash-Realtime** - **开始使用**: 1. 注册阿里云账户 diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 11391c133..c551d9ba2 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -13,7 +13,7 @@ import { import { useConfig } from "@/hooks/useConfig"; import { getConnectivityMeta, ConnectivityStatusType } from "@/lib/utils"; import { modelService } from "@/services/modelService"; -import { ModelType, SingleModelConfig, STTModelConfig } from "@/types/modelConfig"; +import { ModelType, SingleModelConfig, STTModelConfig, TTSModelConfig } from "@/types/modelConfig"; import { MODEL_TYPES, PROVIDER_LINKS } from "@/const/modelConfig"; import { useSiliconModelList } from "@/hooks/model/useSiliconModelList"; import { useDashscopeModelList } from "@/hooks/model/useDashscopeModelList"; @@ -64,6 +64,8 @@ const DEFAULT_FORM_STATE = { sttProvider: "dashscope", // dashscope or volcengine modelAppid: "", accessToken: "", + // TTS specific fields + ttsProvider: "dashscope", // ali or volcengine }; // Connectivity status type comes from utils @@ -451,6 +453,19 @@ export const ModelAddDialog = ({ return form.apiKey.trim() !== "" && form.name.trim() !== ""; } } + if (form.type === MODEL_TYPES.TTS) { + // For TTS models, validate based on provider type + if (form.ttsProvider === "volcengine") { + // Volcano Engine requires appid and access_token + return ( + form.modelAppid.trim() !== "" && + form.accessToken.trim() !== "" + ); + } else { + // Ali TTS requires API Key and model name (URL is optional) + return form.apiKey.trim() !== "" && form.name.trim() !== ""; + } + } return ( form.name.trim() !== "" && form.url.trim() !== "" && @@ -496,6 +511,7 @@ export const ModelAddDialog = ({ sttConfig.modelFactory = "volcengine"; sttConfig.modelAppid = form.modelAppid.trim(); sttConfig.accessToken = form.accessToken.trim(); + sttConfig.baseUrl = form.url; } else { sttConfig.apiKey = form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey; sttConfig.modelFactory = "dashscope"; @@ -506,23 +522,45 @@ export const ModelAddDialog = ({ const result = await modelService.verifyModelConfigConnectivity(sttConfig); connectivity = result.connectivity; } else { - const config = { - modelName: form.name, - modelType: modelType, - baseUrl: form.url, - apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, - maxTokens: - form.type === MODEL_TYPES.EMBEDDING - ? parseInt(form.vectorDimension) - : parseInt(form.maxTokens), - embeddingDim: - form.type === MODEL_TYPES.EMBEDDING - ? parseInt(form.vectorDimension) - : undefined, - }; + // For TTS models, build the appropriate config based on provider + if (form.type === MODEL_TYPES.TTS) { + const ttsConfig: any = { + modelType: modelType, + }; + + if (form.ttsProvider === "volcengine") { + ttsConfig.modelFactory = "volcengine"; + ttsConfig.modelAppid = form.modelAppid.trim(); + ttsConfig.accessToken = form.accessToken.trim(); + ttsConfig.baseUrl = form.url; + } else { + ttsConfig.apiKey = form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey; + ttsConfig.modelFactory = "dashscope"; + ttsConfig.modelName = form.name; + ttsConfig.baseUrl = form.url; + } - const result = await modelService.verifyModelConfigConnectivity(config); - connectivity = result.connectivity; + const result = await modelService.verifyModelConfigConnectivity(ttsConfig); + connectivity = result.connectivity; + } else { + const config = { + modelName: form.name, + modelType: modelType, + baseUrl: form.url, + apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + maxTokens: + form.type === MODEL_TYPES.EMBEDDING + ? parseInt(form.vectorDimension) + : parseInt(form.maxTokens), + embeddingDim: + form.type === MODEL_TYPES.EMBEDDING + ? parseInt(form.vectorDimension) + : undefined, + }; + + const result = await modelService.verifyModelConfigConnectivity(config); + connectivity = result.connectivity; + } } } @@ -709,6 +747,16 @@ export const ModelAddDialog = ({ } } + // Add TTS specific fields + if (form.type === MODEL_TYPES.TTS) { + modelParams.modelFactory = form.ttsProvider === "volcengine" ? "volcengine" : "dashscope"; + if (form.ttsProvider === "volcengine") { + modelParams.modelAppid = form.modelAppid; + modelParams.accessToken = form.accessToken; + modelParams.baseUrl = form.url; + } + } + // Add embedding specific fields if (isEmbeddingModel) { modelParams.expectedChunkSize = form.chunkSizeRange[0]; @@ -736,6 +784,16 @@ export const ModelAddDialog = ({ } } + // Add TTS specific fields + if (form.type === MODEL_TYPES.TTS) { + modelParams.modelFactory = form.ttsProvider === "volcengine" ? "volcengine" : "dashscope"; + if (form.ttsProvider === "volcengine") { + modelParams.modelAppid = form.modelAppid; + modelParams.accessToken = form.accessToken; + modelParams.baseUrl = form.url; + } + } + // Add embedding specific fields if (isEmbeddingModel) { modelParams.expectedChunkSize = form.chunkSizeRange[0]; @@ -747,7 +805,7 @@ export const ModelAddDialog = ({ } // Create the model configuration object - let modelConfig: SingleModelConfig | STTModelConfig = { + let modelConfig: SingleModelConfig | STTModelConfig | TTSModelConfig = { modelName: form.name, displayName: form.displayName || form.name, apiConfig: { @@ -765,6 +823,15 @@ export const ModelAddDialog = ({ } } + // Add TTS specific fields to config + if (form.type === MODEL_TYPES.TTS) { + (modelConfig as TTSModelConfig).modelFactory = form.ttsProvider === "volcengine" ? "volcengine" : "dashscope"; + if (form.ttsProvider === "volcengine") { + (modelConfig as TTSModelConfig).modelAppid = form.modelAppid; + (modelConfig as TTSModelConfig).accessToken = form.accessToken; + } + } + // Add the dimension field for embedding models if (form.type === MODEL_TYPES.EMBEDDING) { modelConfig.dimension = parseInt(form.vectorDimension); @@ -830,6 +897,7 @@ export const ModelAddDialog = ({ const isEmbeddingModel = form.type === MODEL_TYPES.EMBEDDING; const isSTTModel = form.type === MODEL_TYPES.STT; + const isTTSModel = form.type === MODEL_TYPES.TTS; return ( {t("model.type.stt")} - @@ -1102,8 +1170,84 @@ export const ModelAddDialog = ({ )} + {/* TTS Provider Selection */} + {!form.isBatchImport && isTTSModel && ( +
+ + +
+ )} + + {/* TTS Fields for Volcano Engine */} + {!form.isBatchImport && isTTSModel && form.ttsProvider === "volcengine" && ( + <> +
+ + handleFormChange("modelAppid", e.target.value)} + autoComplete="new-password" + /> +
+
+ + handleFormChange("accessToken", e.target.value)} + autoComplete="new-password" + /> +
+ + )} + + {/* API Key (for Ali TTS) */} + {!form.isBatchImport && isTTSModel && form.ttsProvider === "dashscope" && ( +
+ + handleFormChange("apiKey", e.target.value)} + autoComplete="new-password" + /> +
+ )} + {/* API Key (for non-STT, non-TTS models) */} - {!form.isBatchImport && !isSTTModel && ( + {!form.isBatchImport && !isSTTModel && !isTTSModel && (
@@ -1414,18 +1413,15 @@ export const ModelDeleteDialog = ({ }} disabled={ deletingModels.has(model.displayName || model.name) || - model.type === MODEL_TYPES.STT || - model.type === MODEL_TYPES.TTS + model.type === MODEL_TYPES.STT } className={`p-1 ${ - model.type === MODEL_TYPES.STT || - model.type === MODEL_TYPES.TTS + model.type === MODEL_TYPES.STT ? "text-gray-400 cursor-not-allowed" : "text-red-500 hover:text-red-700" }`} title={ - model.type === MODEL_TYPES.STT || - model.type === MODEL_TYPES.TTS + model.type === MODEL_TYPES.STT ? t("model.dialog.delete.unsupportedTypeHint") : t("model.dialog.delete.deleteHint") } diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx index 3114c5535..3c632b9ff 100644 --- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx @@ -1,7 +1,7 @@ import { useState, useEffect } from 'react' import { useTranslation } from 'react-i18next' -import { Modal, Input, Button, App } from "antd"; +import { Modal, Select, Input, Button, App } from "antd"; import { MODEL_TYPES, MODEL_STATUS } from "@/const/modelConfig"; import { useConfig } from "@/hooks/useConfig"; @@ -14,6 +14,8 @@ import { DEFAULT_MAXIMUM_CHUNK_SIZE, } from "./ModelChunkSizeSilder"; +const { Option } = Select; + interface ModelEditDialogProps { isOpen: boolean; model: ModelOption | null; @@ -45,6 +47,10 @@ export const ModelEditDialog = ({ DEFAULT_MAXIMUM_CHUNK_SIZE, ] as [number, number], chunkingBatchSize: "10", + // Voice model fields (STT/TTS) + modelFactory: "", + modelAppid: "", + accessToken: "", }); const [loading, setLoading] = useState(false); const [verifyingConnectivity, setVerifyingConnectivity] = useState(false); @@ -71,6 +77,9 @@ export const ModelEditDialog = ({ model.maximumChunkSize || DEFAULT_MAXIMUM_CHUNK_SIZE, ] as [number, number], chunkingBatchSize: (model.chunkingBatchSize || 10).toString(), + modelFactory: model.modelFactory || "", + modelAppid: model.modelAppid || "", + accessToken: model.accessToken || "", }); } }, [model]); @@ -78,7 +87,7 @@ export const ModelEditDialog = ({ const handleFormChange = (field: string, value: string) => { setForm((prev) => ({ ...prev, [field]: value })); // If the key configuration item changes, clear the verification status - if (["url", "apiKey", "maxTokens", "vectorDimension"].includes(field)) { + if (["url", "apiKey", "maxTokens", "vectorDimension", "modelFactory", "modelAppid", "accessToken"].includes(field)) { setConnectivityStatus({ status: null, message: "" }); } }; @@ -87,8 +96,20 @@ export const ModelEditDialog = ({ form.type === MODEL_TYPES.EMBEDDING || form.type === MODEL_TYPES.MULTI_EMBEDDING; const isRerankModel = form.type === MODEL_TYPES.RERANK; + const isVoiceModel = + form.type === MODEL_TYPES.STT || form.type === MODEL_TYPES.TTS; const isFormValid = () => { + if (isVoiceModel) { + if (form.modelFactory === "volcengine") { + return ( + form.modelAppid.trim() !== "" && + form.accessToken.trim() !== "" + ); + } else { + return form.name.trim() !== "" && form.apiKey.trim() !== ""; + } + } return form.name.trim() !== "" && form.url.trim() !== ""; }; @@ -107,8 +128,10 @@ export const ModelEditDialog = ({ try { const modelType = form.type as ModelType; + const isVoiceModel = + modelType === MODEL_TYPES.STT || modelType === MODEL_TYPES.TTS; - const config = { + const config: any = { modelName: form.name, modelType: modelType, baseUrl: form.url, @@ -125,6 +148,15 @@ export const ModelEditDialog = ({ : undefined, }; + // Add voice model fields for STT/TTS + if (isVoiceModel) { + config.modelFactory = form.modelFactory; + if (form.modelFactory === "volcengine") { + config.modelAppid = form.modelAppid; + config.accessToken = form.accessToken; + } + } + const result = await modelService.verifyModelConfigConnectivity(config); // Set connectivity status @@ -176,6 +208,9 @@ export const ModelEditDialog = ({ expectedChunkSize: isEmbeddingModel ? form.chunkSizeRange[0] : undefined, maximumChunkSize: isEmbeddingModel ? form.chunkSizeRange[1] : undefined, chunkingBatchSize: isEmbeddingModel ? parseInt(form.chunkingBatchSize) || 10 : undefined, + modelFactory: isVoiceModel ? form.modelFactory : undefined, + modelAppid: isVoiceModel && form.modelFactory === "volcengine" ? form.modelAppid : undefined, + accessToken: isVoiceModel && form.modelFactory === "volcengine" ? form.accessToken : undefined, }); } else { await modelService.updateSingleModel({ @@ -196,6 +231,14 @@ export const ModelEditDialog = ({ chunkingBatchSize: parseInt(form.chunkingBatchSize) || 10, } : {}), + // Send voice model fields + ...(isVoiceModel + ? { + modelFactory: form.modelFactory, + modelAppid: form.modelFactory === "volcengine" ? form.modelAppid : undefined, + accessToken: form.modelFactory === "volcengine" ? form.accessToken : undefined, + } + : {}), }); } @@ -221,6 +264,13 @@ export const ModelEditDialog = ({ ...(isEmbeddingModel ? { dimension: parseInt(form.vectorDimension) } : {}), + ...(isVoiceModel + ? { + modelFactory: form.modelFactory, + modelAppid: form.modelFactory === "volcengine" ? form.modelAppid : "", + accessToken: form.modelFactory === "volcengine" ? form.accessToken : "", + } + : {}), }, }); @@ -270,15 +320,63 @@ export const ModelEditDialog = ({ {/* URL */} -
- - handleFormChange("url", e.target.value)} - /> -
+ {!isVoiceModel && ( +
+ + handleFormChange("url", e.target.value)} + /> +
+ )} + + {/* Voice Model Factory */} + {isVoiceModel && ( +
+ + +
+ )} + + {/* Voice Model App ID and Access Token (Volcengine) */} + {isVoiceModel && form.modelFactory === "volcengine" && ( + <> +
+ + handleFormChange("modelAppid", e.target.value)} + autoComplete="new-password" + /> +
+
+ + handleFormChange("accessToken", e.target.value)} + autoComplete="new-password" + visibilityToggle={false} + /> +
+ + )} {/* API Key */}
@@ -481,4 +579,4 @@ export const ProviderConfigEditDialog = ({
) -} \ No newline at end of file +} diff --git a/frontend/hooks/useConfig.ts b/frontend/hooks/useConfig.ts index 8d4c4ccea..d4a6d81d3 100644 --- a/frontend/hooks/useConfig.ts +++ b/frontend/hooks/useConfig.ts @@ -94,6 +94,9 @@ const defaultConfig: GlobalConfig = { apiKey: "", modelUrl: "", }, + modelFactory: "dashscope", + modelAppid: "", + accessToken: "", }, }, }; diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts index a9f918d71..8eea42005 100644 --- a/frontend/types/modelConfig.ts +++ b/frontend/types/modelConfig.ts @@ -47,7 +47,8 @@ export interface ModelOption { expectedChunkSize?: number; maximumChunkSize?: number; chunkingBatchSize?: number; - // STT specific fields + // STT/TTS specific fields + modelFactory?: string; modelAppid?: string; accessToken?: string; } diff --git a/sdk/nexent/core/models/__init__.py b/sdk/nexent/core/models/__init__.py index fa15fb3d4..9d8217358 100644 --- a/sdk/nexent/core/models/__init__.py +++ b/sdk/nexent/core/models/__init__.py @@ -4,6 +4,10 @@ from .stt_model import BaseSTTModel from .ali_stt_model import AliSTTModel, AliSTTConfig from .volc_stt_model import VolcSTTModel, VolcSTTConfig +from .tts_model import BaseTTSModel +from .ali_tts_model import AliTTSModel, AliTTSConfig +from .volc_tts_model import VolcTTSModel, VolcTTSConfig + __all__ = [ "OpenAIModel", "OpenAIVLModel", @@ -13,4 +17,9 @@ "AliSTTConfig", "VolcSTTModel", "VolcSTTConfig", + "BaseTTSModel", + "AliTTSModel", + "AliTTSConfig", + "VolcTTSModel", + "VolcTTSConfig", ] diff --git a/sdk/nexent/core/models/ali_tts_model.py b/sdk/nexent/core/models/ali_tts_model.py new file mode 100644 index 000000000..40a9766bc --- /dev/null +++ b/sdk/nexent/core/models/ali_tts_model.py @@ -0,0 +1,591 @@ +""" +Ali TTS model implementation supporting both CosyVoice and Qwen Realtime APIs. +""" +import asyncio +import base64 +import json +import logging +import uuid +from typing import Any, AsyncGenerator, Dict, Optional, Union + +import websockets + +# Default WebSocket connection timeout (seconds) +DEFAULT_WS_OPEN_TIMEOUT = 60 +DEFAULT_WS_CLOSE_TIMEOUT = 10 + +from .tts_model import BaseTTSModel + +logger = logging.getLogger(__name__) + + +class AliTTSError(Exception): + """Exception raised when Ali TTS API returns an error.""" + + def __init__(self, message: str): + self.message = message + super().__init__(self.message) + + +# CosyVoice API default URL +COSYVOICE_API_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/inference" +# Qwen Realtime API default URL +QWEN_REALTIME_API_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + + +class AliTTSConfig: + """Configuration for Ali TTS model.""" + + def __init__( + self, + api_key: str, + model: str = "cosyvoice-v2", + voice: str = None, + speech_rate: float = 1.0, + pitch_rate: float = 1.0, + volume: float = 50.0, + ws_url: Optional[str] = None, + format: str = "mp3", + sample_rate: int = 16000, + workspace_id: Optional[str] = None + ): + self.api_key = api_key + self.model = model + self.voice = voice + self.speech_rate = speech_rate + self.pitch_rate = pitch_rate + self.volume = volume + self.ws_url = ws_url + self.format = format + self.sample_rate = sample_rate + self.workspace_id = workspace_id + + def is_realtime_api(self) -> bool: + """Check if URL is for Qwen Realtime API.""" + return "/realtime" in (self.ws_url or "") + + def get_api_url(self) -> str: + """Get the WebSocket API URL based on the model.""" + if self.ws_url: + return self.ws_url + if self.is_realtime_api() or "qwen" in self.model.lower(): + return QWEN_REALTIME_API_URL + return COSYVOICE_API_URL + + +class AliTTSModel(BaseTTSModel): + """Ali TTS model implementation supporting CosyVoice and Qwen Realtime APIs.""" + + def __init__(self, config: AliTTSConfig, audio_file_path: Optional[str] = None): + super().__init__(audio_file_path) + self.config = config + self._is_realtime = config.is_realtime_api() or "qwen" in config.model.lower() + + def get_websocket_url(self) -> str: + """Get the WebSocket URL for the TTS service.""" + base_url = self.config.get_api_url() + if self._is_realtime: + separator = "&" if "?" in base_url else "?" + return f"{base_url}{separator}model={self.config.model}" + return base_url + + def get_auth_headers(self) -> Dict[str, str]: + """Get authentication headers for the WebSocket connection.""" + return {"Authorization": f"Bearer {self.config.api_key}"} + + async def generate_speech( + self, + text: str, + stream: bool = False + ) -> Union[bytes, AsyncGenerator[bytes, None]]: + """ + Generate speech from text using the appropriate API. + + Args: + text: Input text to synthesize + stream: If True, return an async generator of audio chunks. + If False, return complete audio bytes. + + Returns: + Audio data either as complete bytes or streaming chunks + """ + ws_url = self.get_websocket_url() + headers = self.get_auth_headers() + logger.info(f"Connecting to Ali TTS service at {ws_url}") + logger.info(f"Using model: {self.config.model}, voice: {self.config.voice}") + logger.info(f"API type: {'Qwen Realtime' if self._is_realtime else 'CosyVoice'}") + + if self._is_realtime: + if stream: + return self._generate_qwen_realtime_streaming(text, ws_url, headers) + return await self._generate_qwen_realtime_non_streaming(text, ws_url, headers) + else: + if stream: + return self._generate_cosyvoice_streaming(text, ws_url, headers) + return await self._generate_cosyvoice_non_streaming(text, ws_url, headers) + + # ==================== CosyVoice API Implementation ==================== + + def _cosyvoice_generate_task_id(self) -> str: + """Generate a unique task ID for CosyVoice API.""" + return uuid.uuid4().hex + + def _cosyvoice_construct_run_task_request(self, task_id: str) -> Dict[str, Any]: + """Construct the run-task request for CosyVoice API.""" + return { + "header": { + "action": "run-task", + "task_id": task_id, + "streaming": "duplex" + }, + "payload": { + "task_group": "audio", + "task": "tts", + "function": "SpeechSynthesizer", + "model": self.config.model, + "parameters": { + "text_type": "PlainText", + "voice": self.config.voice, + "format": self.config.format, + "sample_rate": self.config.sample_rate, + "volume": int(self.config.volume), + "rate": self.config.speech_rate, + "pitch": self.config.pitch_rate, + "enable_ssml": False + }, + "input": {} + } + } + + def _cosyvoice_construct_continue_request(self, task_id: str, text: str) -> Dict[str, Any]: + """Construct the continue-task request for CosyVoice API.""" + return { + "header": { + "action": "continue-task", + "task_id": task_id, + "streaming": "duplex" + }, + "payload": { + "input": {"text": text} + } + } + + def _cosyvoice_construct_finish_request(self, task_id: str) -> Dict[str, Any]: + """Construct the finish-task request for CosyVoice API.""" + return { + "header": { + "action": "finish-task", + "task_id": task_id, + "streaming": "duplex" + }, + "payload": {"input": {}} + } + + def _cosyvoice_parse_event(self, message: str) -> Dict[str, Any]: + """Parse a JSON event from CosyVoice API.""" + try: + data = json.loads(message) + except json.JSONDecodeError: + logger.warning(f"Failed to parse JSON: {message[:100]}") + return {"type": "unknown"} + + header = data.get("header", {}) + event_type = header.get("event", "") + result: Dict[str, Any] = {"type": event_type, "task_id": header.get("task_id")} + + if event_type == "task-failed": + result["error_code"] = header.get("error_code") + result["error_message"] = header.get("error_message") + elif event_type == "task-finished": + payload = data.get("payload", {}) + usage = payload.get("usage", {}) + result["characters"] = usage.get("characters") + + return result + + async def _cosyvoice_wait_for_task_started(self, ws) -> bool: + """Wait for task_started event from CosyVoice API.""" + while True: + message = await asyncio.wait_for(ws.recv(), timeout=30) + if isinstance(message, bytes): + continue + event = self._cosyvoice_parse_event(message) + logger.info(f"CosyVoice received event: {event.get('type')}") + + if event.get("type") == "task-started": + return True + if event.get("type") == "task-failed": + raise AliTTSError(f"CosyVoice task failed: {event.get('error_message', 'Unknown error')}") + return False + + async def _cosyvoice_receive_audio( + self, + ws, + buffer: Optional[bytearray] = None, + yield_chunks: bool = False + ) -> AsyncGenerator[bytes, None]: + """Receive audio from CosyVoice API.""" + while True: + try: + message = await asyncio.wait_for(ws.recv(), timeout=60) + if isinstance(message, bytes): + if buffer is not None: + buffer.extend(message) + if yield_chunks: + yield message + continue + + event = self._cosyvoice_parse_event(message) + event_type = event.get("type") + logger.info(f"CosyVoice received event: {event_type}") + + if event_type == "task-failed": + raise AliTTSError(f"CosyVoice task failed: {event.get('error_message', 'Unknown error')}") + if event_type == "task-finished": + break + + except asyncio.TimeoutError: + logger.warning("Timeout waiting for CosyVoice task-finished event") + break + + async def _generate_cosyvoice_non_streaming(self, text: str, ws_url: str, headers: Dict[str, str]) -> bytes: + """Non-streaming speech generation using CosyVoice API.""" + buffer = bytearray() + task_id = self._cosyvoice_generate_task_id() + + try: + async with websockets.connect(ws_url, additional_headers=headers, ping_interval=None, + open_timeout=DEFAULT_WS_OPEN_TIMEOUT, + close_timeout=DEFAULT_WS_CLOSE_TIMEOUT) as ws: + request = self._cosyvoice_construct_run_task_request(task_id) + await ws.send(json.dumps(request)) + logger.info(f"Sent CosyVoice run-task request: task_id={task_id}") + + await self._cosyvoice_wait_for_task_started(ws) + + await ws.send(json.dumps(self._cosyvoice_construct_continue_request(task_id, text))) + logger.info(f"Sent CosyVoice continue-task with text: {text[:50]}...") + + await ws.send(json.dumps(self._cosyvoice_construct_finish_request(task_id))) + logger.info("Sent CosyVoice finish-task request") + + # Consume audio chunks to accumulate in buffer + async for _ in self._cosyvoice_receive_audio(ws, buffer=buffer): + pass # Audio is accumulated in buffer + + except AliTTSError: + raise + except Exception as e: + logger.error(f"CosyVoice TTS error: {str(e)}") + raise + + if len(buffer) == 0: + logger.warning("No audio data received from CosyVoice") + return bytes(buffer) + + async def _generate_cosyvoice_streaming(self, text: str, ws_url: str, headers: Dict[str, str]) -> AsyncGenerator[ + bytes, None]: + """Streaming speech generation using CosyVoice API.""" + task_id = self._cosyvoice_generate_task_id() + + try: + async with websockets.connect(ws_url, additional_headers=headers, ping_interval=None, + open_timeout=DEFAULT_WS_OPEN_TIMEOUT, + close_timeout=DEFAULT_WS_CLOSE_TIMEOUT) as ws: + await ws.send(json.dumps(self._cosyvoice_construct_run_task_request(task_id))) + logger.info(f"Sent CosyVoice run-task request: task_id={task_id}") + + await self._cosyvoice_wait_for_task_started(ws) + + await ws.send(json.dumps(self._cosyvoice_construct_continue_request(task_id, text))) + logger.info(f"Sent CosyVoice continue-task with text: {text[:50]}...") + + await ws.send(json.dumps(self._cosyvoice_construct_finish_request(task_id))) + logger.info("Sent CosyVoice finish-task request") + + async for chunk in self._cosyvoice_receive_audio(ws, yield_chunks=True): + yield chunk + + except AliTTSError: + raise + except Exception as e: + logger.error(f"CosyVoice TTS streaming error: {str(e)}") + raise + + # ==================== Qwen Realtime API Implementation ==================== + + def _qwen_generate_event_id(self) -> str: + """Generate a unique event ID for Qwen Realtime API.""" + return f"event_{uuid.uuid4().hex[:16]}" + + def _qwen_construct_session_update(self) -> Dict[str, Any]: + """Construct session.update request for Qwen Realtime API.""" + # Use default voice if not specified + voice = self.config.voice or "Cherry" + return { + "event_id": self._qwen_generate_event_id(), + "type": "session.update", + "session": { + "voice": voice, + "mode": "server_commit", + "language_type": "Auto", + "response_format": self._qwen_format_to_response_format(self.config.format), + "sample_rate": self.config.sample_rate, + "speech_rate": self.config.speech_rate, + "volume": int(self.config.volume) + } + } + + def _qwen_format_to_response_format(self, format_str: str) -> str: + """Convert format to Qwen Realtime response_format.""" + format_map = {"mp3": "mp3", "pcm": "pcm", "wav": "wav", "opus": "opus"} + return format_map.get(format_str.lower(), "pcm") + + def _qwen_construct_text_append(self, text: str) -> Dict[str, Any]: + """Construct input_text_buffer.append request for Qwen Realtime API.""" + return { + "event_id": self._qwen_generate_event_id(), + "type": "input_text_buffer.append", + "text": text + } + + def _qwen_construct_text_commit(self) -> Dict[str, Any]: + """Construct input_text_buffer.commit request for Qwen Realtime API.""" + return { + "event_id": self._qwen_generate_event_id(), + "type": "input_text_buffer.commit" + } + + def _qwen_construct_session_finish(self) -> Dict[str, Any]: + """Construct session.finish request for Qwen Realtime API.""" + return { + "event_id": self._qwen_generate_event_id(), + "type": "session.finish" + } + + def _qwen_parse_event(self, message: str) -> Dict[str, Any]: + """Parse a JSON event from Qwen Realtime API.""" + try: + data = json.loads(message) + except json.JSONDecodeError: + logger.warning(f"Failed to parse Qwen event JSON: {message[:100]}") + return {"type": "unknown"} + + event_type = data.get("type", "") + result: Dict[str, Any] = {"type": event_type, "raw": data} + + if event_type == "error": + error = data.get("error", {}) + result["error_code"] = error.get("code") + result["error_message"] = error.get("message") + + return result + + async def _qwen_wait_for_session_created(self, ws) -> bool: + """Wait for session.created event from Qwen Realtime API.""" + while True: + message = await asyncio.wait_for(ws.recv(), timeout=30) + if isinstance(message, bytes): + continue + event = self._qwen_parse_event(message) + logger.info(f"Qwen Realtime received event: {event.get('type')}") + + if event.get("type") == "session.created": + return True + if event.get("type") == "error": + raise AliTTSError(f"Qwen Realtime session error: {event.get('error_message', 'Unknown error')}") + return False + + def _qwen_is_terminal_event(self, event_type: str) -> bool: + """Check if event type indicates the session is done.""" + return event_type in ("response.audio.done", "session.finished") + + async def _qwen_wait_for_response_created(self, ws) -> bool: + """Wait for response.created event before collecting audio.""" + while True: + message = await asyncio.wait_for(ws.recv(), timeout=60) + if isinstance(message, bytes): + continue + event = self._qwen_parse_event(message) + event_type = event.get("type") + logger.info(f"Qwen Realtime received event: {event_type}") + + if event_type == "error": + raise AliTTSError(f"Qwen Realtime error: {event.get('error_message', 'Unknown error')}") + if event_type == "response.created": + logger.info("Response created, audio synthesis started") + return True + if event_type == "session.finished": + logger.warning("Session finished before audio started") + return False + return False + + def _qwen_handle_audio_delta(self, event: Dict[str, Any], buffer: Optional[bytearray], yield_chunks: bool) -> \ + Optional[bytes]: + """Handle response.audio.delta event and return audio chunk.""" + delta = event.get("raw", {}).get("delta", "") + if not delta: + return None + audio_data = base64.b64decode(delta) + if buffer is not None: + buffer.extend(audio_data) + return audio_data if yield_chunks else None + + async def _qwen_receive_audio( + self, + ws, + buffer: Optional[bytearray] = None, + yield_chunks: bool = False + ) -> AsyncGenerator[bytes, None]: + """Receive audio from Qwen Realtime API.""" + audio_done = False + while not audio_done: + try: + message = await asyncio.wait_for(ws.recv(), timeout=60) + if isinstance(message, bytes): + if buffer is not None: + buffer.extend(message) + if yield_chunks: + yield message + continue + + event = self._qwen_parse_event(message) + event_type = event.get("type") + logger.info(f"Qwen Realtime received event: {event_type}") + + if event_type == "error": + raise AliTTSError(f"Qwen Realtime error: {event.get('error_message', 'Unknown error')}") + + if event_type == "response.created": + logger.info("Response created, audio synthesis started") + continue + + if event_type == "response.audio.delta": + chunk = self._qwen_handle_audio_delta(event, buffer, yield_chunks) + if chunk: + yield chunk + + if self._qwen_is_terminal_event(event_type): + audio_done = True + + except asyncio.TimeoutError: + logger.warning("Timeout waiting for Qwen Realtime response") + break + + async def _generate_qwen_realtime_non_streaming(self, text: str, ws_url: str, headers: Dict[str, str]) -> bytes: + """Non-streaming speech generation using Qwen Realtime API.""" + buffer = bytearray() + + try: + async with websockets.connect(ws_url, additional_headers=headers, ping_interval=None, + open_timeout=DEFAULT_WS_OPEN_TIMEOUT, + close_timeout=DEFAULT_WS_CLOSE_TIMEOUT) as ws: + # Wait for session.created + await self._qwen_wait_for_session_created(ws) + logger.info("Qwen Realtime session created") + + # Send session update + await ws.send(json.dumps(self._qwen_construct_session_update())) + voice = self.config.voice or "Cherry" + logger.info(f"Sent Qwen Realtime session.update with voice={voice}") + + # Send text + await ws.send(json.dumps(self._qwen_construct_text_append(text))) + logger.info(f"Sent Qwen Realtime text: {text[:50]}...") + + # Commit and trigger synthesis + await ws.send(json.dumps(self._qwen_construct_text_commit())) + logger.info("Sent Qwen Realtime text commit") + + # Wait for response.created before finishing session + await self._qwen_wait_for_response_created(ws) + + # Finish session + await ws.send(json.dumps(self._qwen_construct_session_finish())) + logger.info("Sent Qwen Realtime session.finish") + + # Receive audio chunks to accumulate in buffer + async for _ in self._qwen_receive_audio(ws, buffer=buffer): + pass # Audio is accumulated in buffer + + except AliTTSError: + raise + except Exception as e: + logger.error(f"Qwen Realtime TTS error: {str(e)}") + raise + + if len(buffer) == 0: + logger.warning("No audio data received from Qwen Realtime") + return bytes(buffer) + + async def _generate_qwen_realtime_streaming(self, text: str, ws_url: str, headers: Dict[str, str]) -> \ + AsyncGenerator[bytes, None]: + """Streaming speech generation using Qwen Realtime API.""" + try: + async with websockets.connect(ws_url, additional_headers=headers, ping_interval=None, + open_timeout=DEFAULT_WS_OPEN_TIMEOUT, + close_timeout=DEFAULT_WS_CLOSE_TIMEOUT) as ws: + # Wait for session.created + await self._qwen_wait_for_session_created(ws) + logger.info("Qwen Realtime session created") + + # Send session update + await ws.send(json.dumps(self._qwen_construct_session_update())) + voice = self.config.voice or "Cherry" + logger.info(f"Sent Qwen Realtime session.update with voice={voice}") + + # Send text + await ws.send(json.dumps(self._qwen_construct_text_append(text))) + logger.info(f"Sent Qwen Realtime text: {text[:50]}...") + + # Commit and trigger synthesis + await ws.send(json.dumps(self._qwen_construct_text_commit())) + logger.info("Sent Qwen Realtime text commit") + + # Wait for response.created before finishing session + await self._qwen_wait_for_response_created(ws) + + # Finish session + await ws.send(json.dumps(self._qwen_construct_session_finish())) + logger.info("Sent Qwen Realtime session.finish") + + # Receive audio + async for chunk in self._qwen_receive_audio(ws, yield_chunks=True): + yield chunk + + except AliTTSError: + raise + except Exception as e: + logger.error(f"Qwen Realtime TTS streaming error: {str(e)}") + raise + + # ==================== Connectivity Check ==================== + + async def check_connectivity(self) -> bool: + """ + Test if the connection to the remote TTS service is normal. + + Returns: + True if connection successful, False otherwise + """ + api_type = "Qwen Realtime" if self._is_realtime else "CosyVoice" + try: + logger.info(f"Ali TTS connectivity test started with {api_type}") + logger.info(f"model={self.config.model}, voice={self.config.voice}") + audio_data = await self.generate_speech("Hello", stream=False) + is_success = self._is_tts_result_successful(audio_data) + if is_success: + logger.info("Ali TTS connectivity test successful") + else: + logger.error("Ali TTS connectivity test failed: empty audio data") + return is_success + except AliTTSError as e: + error_msg = str(e) + logger.error(f"Ali TTS connectivity test failed: {error_msg}") + return False + except Exception as e: + logger.error(f"Ali TTS connectivity test failed with exception: {str(e)}") + import traceback + logger.error(f"Traceback: {traceback.format_exc()}") + return False + diff --git a/sdk/nexent/core/models/message_utils.py b/sdk/nexent/core/models/message_utils.py index 3a123f1f0..981a1a31a 100644 --- a/sdk/nexent/core/models/message_utils.py +++ b/sdk/nexent/core/models/message_utils.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, List, Optional def _flatten_content(raw_content: Any) -> str: @@ -24,7 +24,7 @@ def _flatten_content(raw_content: Any) -> str: return "" if raw_content is None else str(raw_content) -def prepare_messages_for_completion(normalized_messages: List[Any], model_factory: str | None) -> List[Any]: +def prepare_messages_for_completion(normalized_messages: List[Any], model_factory: Optional[str] = None) -> List[Any]: """ Prepare messages for completion based on provider requirements. @@ -47,4 +47,3 @@ def prepare_messages_for_completion(normalized_messages: List[Any], model_factor return prepared return normalized_messages - diff --git a/sdk/nexent/core/models/tts_model.py b/sdk/nexent/core/models/tts_model.py new file mode 100644 index 000000000..21633cdb9 --- /dev/null +++ b/sdk/nexent/core/models/tts_model.py @@ -0,0 +1,107 @@ +""" +Base TTS model interface for text-to-speech functionality. +""" +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Union, AsyncGenerator + + +class BaseTTSModel(ABC): + """ + Abstract base class for TTS (Text-to-Speech) models. + + All TTS implementations (e.g., Volcano Engine, Ali Cloud) must inherit from this class + and implement the required abstract methods. + """ + + def __init__(self, audio_file_path: Optional[str] = None): + """ + Initialize the base TTS model. + + Args: + audio_file_path: Path to test audio file for connectivity testing + """ + self.audio_file_path = audio_file_path + + @abstractmethod + def get_websocket_url(self) -> str: + """ + Get the WebSocket URL for the TTS service. + + Returns: + WebSocket URL string + """ + pass + + @abstractmethod + def get_auth_headers(self) -> Dict[str, str]: + """ + Get authentication headers for the WebSocket connection. + + Returns: + Headers dict with authentication information + """ + pass + + @abstractmethod + async def generate_speech( + self, + text: str, + stream: bool = False + ) -> Union[bytes, AsyncGenerator[bytes, None]]: + """ + Generate speech from text. + + Args: + text: Input text to synthesize + stream: If True, return an async generator of audio chunks. + If False, return complete audio bytes. + + Returns: + Audio data either as complete bytes or streaming chunks + """ + pass + + @abstractmethod + async def check_connectivity(self) -> bool: + """ + Test if the connection to the remote TTS service is normal. + + Returns: + True if connection successful, False otherwise + """ + pass + + def _is_tts_result_successful(self, result: Any) -> bool: + """ + Check if TTS result indicates a successful synthesis. + + Args: + result: TTS processing result + + Returns: + True if successful, False otherwise + """ + if isinstance(result, bytes): + return len(result) > 0 + if isinstance(result, dict): + if 'error' in result: + return False + return 'audio' in result or 'text' in result or 'message' in result + return False + + def _extract_tts_error_message(self, result: Any) -> str: + """ + Extract error message from TTS result. + + Args: + result: TTS processing result + + Returns: + Error message string + """ + if isinstance(result, dict): + if 'error' in result: + return str(result['error']) + if 'message' in result: + return str(result['message']) + return f"Unknown error in result: {result}" diff --git a/sdk/nexent/core/models/volc_tts_model.py b/sdk/nexent/core/models/volc_tts_model.py new file mode 100644 index 000000000..446631827 --- /dev/null +++ b/sdk/nexent/core/models/volc_tts_model.py @@ -0,0 +1,167 @@ +""" +Volcano Engine TTS model implementation using proprietary protocol. +""" +import copy +import gzip +import io +import json +import logging +import uuid +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Dict, Optional, Union + +import websockets + +from .tts_model import BaseTTSModel + +logger = logging.getLogger(__name__) + + +@dataclass +class VolcTTSConfig: + """Configuration for Volcano Engine TTS model.""" + appid: str + token: str + speed_ratio: float + ws_url: str = "wss://openspeech.bytedance.com/api/v1/tts/ws_binary" + host: str = "openspeech.bytedance.com" + encoding: str = "mp3" + volume_ratio: float = 1.0 + pitch_ratio: float = 1.0 + cluster:str="volcano_tts" + resource_id:str="seed-tts-2.0" + voice_type: str = "zh_female_vv_uranus_bigtts" + + @property + def api_url(self) -> str: + return self.ws_url + + +class VolcTTSModel(BaseTTSModel): + """ + Volcano Engine TTS model implementation using proprietary protocol. + """ + + MESSAGE_TYPES = {11: "audio-only server response", 12: "frontend server response", 15: "error message from server"} + MESSAGE_TYPE_SPECIFIC_FLAGS = {0: "no sequence number", 1: "sequence number > 0", + 2: "last message from server (seq < 0)", 3: "sequence number < 0"} + MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"} + MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"} + + DEFAULT_HEADER = bytearray([0x11, 0x10, 0x11, 0x00]) + + def __init__(self, config: VolcTTSConfig, audio_file_path: Optional[str] = None): + super().__init__(audio_file_path) + self.config = config + self._request_template = { + "app": {"appid": config.appid, "token": config.token, "cluster": config.cluster, "resource_id": config.resource_id}, + "user": {"uid": "388808087185088"}, + "audio": { + "voice_type": config.voice_type, + "encoding": config.encoding, + "speed_ratio": config.speed_ratio, + "volume_ratio": config.volume_ratio, + "pitch_ratio": config.pitch_ratio, + }, + "request": {"reqid": "xxx", "text": "", "text_type": "plain", "operation": "xxx"} + } + + def get_websocket_url(self) -> str: + return self.config.api_url + + def get_auth_headers(self) -> Dict[str, str]: + headers = { + "Authorization": f"Bearer; {self.config.token}", + "X-Api-App-Id": self.config.appid, + "X-Api-Access-Key": self.config.token, + "X-Api-Resource-Id": self.config.resource_id + } + return headers + + def _prepare_request(self, text: str, operation: str = "submit") -> bytes: + request_json = copy.deepcopy(self._request_template) + request_json["request"]["reqid"] = str(uuid.uuid4()) + request_json["request"]["text"] = text + request_json["request"]["operation"] = operation + payload_bytes = str.encode(json.dumps(request_json)) + payload_bytes = gzip.compress(payload_bytes) + full_request = bytearray(self.DEFAULT_HEADER) + full_request.extend(len(payload_bytes).to_bytes(4, 'big')) + full_request.extend(payload_bytes) + return bytes(full_request) + + def _parse_response(self, res: bytes, buffer: Optional[io.BytesIO] = None) -> tuple[bool, Optional[bytes]]: + protocol_version = res[0] >> 4 + header_size = res[0] & 0x0f + message_type = res[1] >> 4 + message_type_specific_flags = res[1] & 0x0f + payload = res[header_size * 4:] + logger.info(f"Volc TTS protocol: version={protocol_version}, header_size={header_size}, msg_type={message_type:#x}, flags={message_type_specific_flags}") + + if message_type == 0xb: + if message_type_specific_flags == 0: + return False, None + sequence_number = int.from_bytes(payload[:4], "big", signed=True) + audio_chunk = payload[8:] + if buffer is not None: + buffer.write(audio_chunk) + return sequence_number < 0, audio_chunk + elif message_type == 0xf: + code = int.from_bytes(payload[:4], "big", signed=False) + error_msg = payload[8:] + if (res[2] & 0x0f) == 1: + error_msg = gzip.decompress(error_msg) + err_str = "Volc TTS Error " + str(code) + ": " + error_msg.decode('utf-8') + logger.error(err_str) + raise Exception(err_str) + return True, None + + async def generate_speech( + self, + text: str, + stream: bool = False + ) -> Union[bytes, AsyncGenerator[bytes, None]]: + request = self._prepare_request(text) + headers = self.get_auth_headers() + logger.info(f"Volc TTS request prepared, text_len={len(text)}, stream={stream}") + if not stream: + buffer = io.BytesIO() + async with websockets.connect(self.config.api_url, additional_headers=headers, ping_interval=None) as ws: + await ws.send(request) + while True: + response = await ws.recv() + done, _ = self._parse_response(response, buffer) + if done: + break + return buffer.getvalue() + else: + async def audio_generator(): + async with websockets.connect(self.config.api_url, additional_headers=headers, + ping_interval=None) as ws: + await ws.send(request) + while True: + response = await ws.recv() + logger.info(f"Volc TTS raw response ({len(response)} bytes): {response[:50]!r}") + done, chunk = self._parse_response(response) + logger.info(f"Volc TTS parsed: done={done}, chunk_len={len(chunk) if chunk else 0}") + if chunk: + yield chunk + if done: + break + return audio_generator() + + async def check_connectivity(self) -> bool: + try: + logger.info("Volc TTS connectivity test started...") + audio_data = await self.generate_speech("Hello", stream=False) + is_success = self._is_tts_result_successful(audio_data) + if is_success: + logger.info("Volc TTS connectivity test successful") + else: + logger.error("Volc TTS connectivity test failed: empty or invalid audio data") + return is_success + except Exception as e: + logger.error("Volc TTS connectivity test failed with exception: " + str(e)) + import traceback + logger.error("Volc TTS connectivity test exception traceback: " + traceback.format_exc()) + return False diff --git a/test/backend/services/test_voice_service.py b/test/backend/services/test_voice_service.py index 0151ec3ad..a9203eca2 100644 --- a/test/backend/services/test_voice_service.py +++ b/test/backend/services/test_voice_service.py @@ -1,7 +1,7 @@ """ Unit tests for VoiceService. -Tests STT session management and connectivity checks. +Tests STT/TTS session management, speech generation, and connectivity checks. Patches SDK model classes at the module level where voice_service imports them. """ import os @@ -15,6 +15,7 @@ from consts.exceptions import ( VoiceServiceException, STTConnectionException, + TTSConnectionException, ) @@ -32,11 +33,29 @@ def __init__(self, config=None, test_path=None): self.start_streaming_session = AsyncMock() +class MockTTSModel: + """Mock TTS model mimicking the real SDK interface.""" + + def __init__(self, config=None): + self.config = config + self.check_connectivity = AsyncMock(return_value=True) + + async def generate_speech(self, text: str, stream: bool = False): + if stream: + async def gen(): + yield b"chunk_1" + yield b"chunk_2" + yield b"chunk_3" + return gen() + return b"complete_audio_data" + + # --------------------------------------------------------------------------- # Shared mock instances -- populated per-test via _mock_all_models # --------------------------------------------------------------------------- _shared_stt = None +_shared_tts = None def _reset_singleton(): @@ -45,25 +64,32 @@ def _reset_singleton(): services.voice_service._voice_service_instance = None -def _mock_all_models(stt_success=True, stt_exc=None): +def _mock_all_models(stt_success=True, tts_success=True, stt_exc=None, tts_exc=None): """ Patch SDK model classes so every instantiation returns the shared mock instance. - Returns (patches, mock_stt). + Returns (patches, mock_stt, mock_tts). """ - global _shared_stt + global _shared_stt, _shared_tts _shared_stt = MockSTTModel() + _shared_tts = MockTTSModel() _shared_stt.check_connectivity = AsyncMock(return_value=stt_success) + _shared_tts.check_connectivity = AsyncMock(return_value=tts_success) if stt_exc: _shared_stt.check_connectivity = AsyncMock(side_effect=stt_exc) _shared_stt.start_streaming_session = AsyncMock(side_effect=stt_exc) + if tts_exc: + _shared_tts.check_connectivity = AsyncMock(side_effect=tts_exc) + _shared_tts.generate_speech = AsyncMock(side_effect=tts_exc) patches = [ patch("services.voice_service.VolcSTTModel", return_value=_shared_stt), patch("services.voice_service.AliSTTModel", return_value=_shared_stt), + patch("services.voice_service.VolcTTSModel", return_value=_shared_tts), + patch("services.voice_service.AliTTSModel", return_value=_shared_tts), ] - return patches, _shared_stt + return patches, _shared_stt, _shared_tts # --------------------------------------------------------------------------- @@ -83,7 +109,7 @@ class TestStartSTTStreamingSession: @pytest.mark.asyncio async def test_success(self): _reset_singleton() - patches, mock_stt = _mock_all_models(stt_success=True) + patches, mock_stt, _ = _mock_all_models(stt_success=True) for p in patches: p.start() try: @@ -99,7 +125,7 @@ async def test_success(self): async def test_stt_connection_error(self): _reset_singleton() exc = STTConnectionException("STT connection failed") - patches, _ = _mock_all_models(stt_exc=exc) + patches, _, _ = _mock_all_models(stt_exc=exc) for p in patches: p.start() try: @@ -115,7 +141,7 @@ async def test_stt_connection_error(self): async def test_general_error(self): _reset_singleton() exc = RuntimeError("unexpected error") - patches, _ = _mock_all_models(stt_exc=exc) + patches, _, _ = _mock_all_models(stt_exc=exc) for p in patches: p.start() try: @@ -128,6 +154,202 @@ async def test_general_error(self): p.stop() +# --------------------------------------------------------------------------- +# Tests: generate_tts_speech +# --------------------------------------------------------------------------- + +class TestGenerateTTSSpeech: + """Tests for generate_tts_speech.""" + + @pytest.mark.asyncio + async def test_success_non_streaming(self): + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + result = await service.generate_tts_speech("Hello world", stream=False) + assert result == b"complete_audio_data" + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_success_streaming(self): + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + chunks = [] + async def capture(): + gen = await service.generate_tts_speech("Hello world", stream=True) + async for chunk in gen: + chunks.append(chunk) + await capture() + assert chunks == [b"chunk_1", b"chunk_2", b"chunk_3"] + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_empty_text_raises(self): + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + with pytest.raises(VoiceServiceException, match="No text provided"): + await service.generate_tts_speech("") + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_none_text_raises(self): + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + with pytest.raises(VoiceServiceException, match="No text provided"): + await service.generate_tts_speech(None) + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_tts_connection_error(self): + _reset_singleton() + exc = TTSConnectionException("TTS connection failed") + patches, _, _ = _mock_all_models(tts_exc=exc) + for p in patches: + p.start() + try: + service = VoiceService() + with pytest.raises(TTSConnectionException, match="TTS connection failed"): + await service.generate_tts_speech("Hello world") + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_general_error(self): + _reset_singleton() + exc = RuntimeError("unexpected") + patches, _, _ = _mock_all_models(tts_exc=exc) + for p in patches: + p.start() + try: + service = VoiceService() + with pytest.raises(TTSConnectionException, match="unexpected"): + await service.generate_tts_speech("Hello world") + finally: + for p in reversed(patches): + p.stop() + + +# --------------------------------------------------------------------------- +# Tests: stream_tts_to_websocket +# --------------------------------------------------------------------------- + +class TestStreamTTSToWebSocket: + """Tests for stream_tts_to_websocket.""" + + def _connected_ws(self): + ws = Mock() + ws.send_bytes = AsyncMock() + ws.send_json = AsyncMock() + state = Mock() + state.name = "CONNECTED" + ws.client_state = state + return ws + + @pytest.mark.asyncio + async def test_success(self): + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + mock_ws = self._connected_ws() + await service.stream_tts_to_websocket(mock_ws, "Hello world") + assert mock_ws.send_bytes.call_count == 3 + mock_ws.send_json.assert_called_once_with({"status": "completed"}) + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_tts_connection_error(self): + _reset_singleton() + exc = TTSConnectionException("TTS connection failed") + patches, _, _ = _mock_all_models(tts_exc=exc) + for p in patches: + p.start() + try: + service = VoiceService() + mock_ws = self._connected_ws() + with pytest.raises(TTSConnectionException, match="TTS connection failed"): + await service.stream_tts_to_websocket(mock_ws, "Hello world") + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + @pytest.mark.skip(reason="stream_tts_to_websocket internally calls generate_tts_speech which creates fresh model instances; patching the service method does not intercept the internal call path without modifying voice_service.py") + async def test_disconnects_if_websocket_closed(self): + """Audio sending stops when WebSocket is no longer CONNECTED.""" + pass + mock_ws = self._connected_ws() + sent_chunks = [] + disconnected_triggered = [] + + async def fake_send_bytes(data): + sent_chunks.append(data) + + mock_ws.send_bytes = fake_send_bytes + + async def disconnecting_gen(): + yield b"chunk_1" + disconnected_triggered.append(True) + mock_ws.client_state.name = "DISCONNECTED" + yield b"chunk_2" + + class DisconnectingTTS(MockTTSModel): + async def generate_speech(self, text, stream=False): + if stream: + async for c in disconnecting_gen(): + yield c + return + + global _shared_stt, _shared_tts + _shared_stt = MockSTTModel() + _shared_tts = DisconnectingTTS() + + patches = [ + patch("services.voice_service.VolcSTTModel", return_value=_shared_stt), + patch("services.voice_service.AliSTTModel", return_value=_shared_stt), + patch("services.voice_service.VolcTTSModel", return_value=_shared_tts), + patch("services.voice_service.AliTTSModel", return_value=_shared_tts), + ] + for p in patches: + p.start() + try: + service = VoiceService() + await service.stream_tts_to_websocket(mock_ws, "Hello world") + assert len(sent_chunks) == 1, f"Expected 1 chunk but got {len(sent_chunks)}" + assert disconnected_triggered == [True] + finally: + for p in reversed(patches): + p.stop() + + # --------------------------------------------------------------------------- # Tests: check_voice_connectivity # --------------------------------------------------------------------------- @@ -138,7 +360,7 @@ class TestCheckVoiceConnectivity: @pytest.mark.asyncio async def test_stt_success(self): _reset_singleton() - patches, _ = _mock_all_models(stt_success=True) + patches, _, _ = _mock_all_models(stt_success=True, tts_success=True) for p in patches: p.start() try: @@ -149,10 +371,24 @@ async def test_stt_success(self): for p in reversed(patches): p.stop() + @pytest.mark.asyncio + async def test_tts_success(self): + _reset_singleton() + patches, _, _ = _mock_all_models(stt_success=True, tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + result = await service.check_voice_connectivity("tts") + assert result is True + finally: + for p in reversed(patches): + p.stop() + @pytest.mark.asyncio async def test_stt_failure_raises(self): _reset_singleton() - patches, _ = _mock_all_models(stt_success=False) + patches, _, _ = _mock_all_models(stt_success=False, tts_success=True) for p in patches: p.start() try: @@ -163,15 +399,29 @@ async def test_stt_failure_raises(self): for p in reversed(patches): p.stop() + @pytest.mark.asyncio + async def test_tts_failure_raises(self): + _reset_singleton() + patches, _, _ = _mock_all_models(stt_success=True, tts_success=False) + for p in patches: + p.start() + try: + service = VoiceService() + with pytest.raises(TTSConnectionException): + await service.check_voice_connectivity("tts") + finally: + for p in reversed(patches): + p.stop() + @pytest.mark.asyncio async def test_invalid_model_type_raises(self): _reset_singleton() - patches, _ = _mock_all_models() + patches, _, _ = _mock_all_models() for p in patches: p.start() try: service = VoiceService() - with pytest.raises(VoiceServiceException, match=r"Unsupported model type"): + with pytest.raises(VoiceServiceException, match="Unknown model type"): await service.check_voice_connectivity("invalid") finally: for p in reversed(patches): @@ -181,7 +431,7 @@ async def test_invalid_model_type_raises(self): async def test_stt_connection_error(self): _reset_singleton() exc = STTConnectionException("STT unavailable") - patches, _ = _mock_all_models(stt_exc=exc) + patches, _, _ = _mock_all_models(stt_exc=exc) for p in patches: p.start() try: @@ -192,11 +442,26 @@ async def test_stt_connection_error(self): for p in reversed(patches): p.stop() + @pytest.mark.asyncio + async def test_tts_connection_error(self): + _reset_singleton() + exc = TTSConnectionException("TTS unavailable") + patches, _, _ = _mock_all_models(tts_exc=exc) + for p in patches: + p.start() + try: + service = VoiceService() + with pytest.raises(TTSConnectionException, match="TTS unavailable"): + await service.check_voice_connectivity("tts") + finally: + for p in reversed(patches): + p.stop() + @pytest.mark.asyncio async def test_general_error_wrapped(self): _reset_singleton() exc = RuntimeError("unexpected") - patches, _ = _mock_all_models(stt_exc=exc) + patches, _, _ = _mock_all_models(stt_exc=exc) for p in patches: p.start() try: @@ -217,7 +482,7 @@ class TestVoiceServiceSingleton: def test_returns_same_instance(self): _reset_singleton() - patches, _ = _mock_all_models() + patches, _, _ = _mock_all_models() for p in patches: p.start() try: @@ -235,7 +500,7 @@ class TestGetSTTModelFromConfig: def test_volc_stt_model_selection(self): """Test that volc model is selected for volc factory.""" _reset_singleton() - patches, _ = _mock_all_models() + patches, _, _ = _mock_all_models() for p in patches: p.start() try: @@ -254,7 +519,7 @@ def test_volc_stt_model_selection(self): def test_volc_stt_model_selection_chinese(self): """Test that volc model is selected for Chinese factory name.""" _reset_singleton() - patches, _ = _mock_all_models() + patches, _, _ = _mock_all_models() for p in patches: p.start() try: @@ -271,7 +536,7 @@ def test_volc_stt_model_selection_chinese(self): def test_ali_stt_model_default(self): """Test that Ali STT model is used by default.""" _reset_singleton() - patches, _ = _mock_all_models() + patches, _, _ = _mock_all_models() for p in patches: p.start() try: @@ -285,7 +550,7 @@ def test_ali_stt_model_default(self): def test_ali_stt_model_with_dashscope(self): """Test that Ali STT model is used for dashscope factory.""" _reset_singleton() - patches, _ = _mock_all_models() + patches, _, _ = _mock_all_models() for p in patches: p.start() try: @@ -302,7 +567,7 @@ def test_ali_stt_model_with_dashscope(self): def test_with_custom_base_url(self): """Test with custom WebSocket URL.""" _reset_singleton() - patches, _ = _mock_all_models() + patches, _, _ = _mock_all_models() for p in patches: p.start() try: @@ -317,13 +582,103 @@ def test_with_custom_base_url(self): p.stop() +class TestGetTTSModelFromConfig: + """Tests for _get_tts_model_from_config.""" + + def test_volc_tts_model_selection(self): + """Test that volc TTS model is selected for volc factory.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + model = service._get_tts_model_from_config( + model_factory="volc", + api_key="test_key", + model_appid="test_appid", + access_token="test_token" + ) + assert model is not None + finally: + for p in reversed(patches): + p.stop() + + def test_volc_tts_from_base_url(self): + """Test that volc TTS is auto-detected from base_url.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + model = service._get_tts_model_from_config( + base_url="wss://openspeech.bytedance.com/api/v1/tts" + ) + assert model is not None + finally: + for p in reversed(patches): + p.stop() + + def test_ali_tts_cosyvoice_default(self): + """Test Ali TTS with CosyVoice model.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + model = service._get_tts_model_from_config( + api_key="test_key", + model="cosyvoice-v2" + ) + assert model is not None + finally: + for p in reversed(patches): + p.stop() + + def test_ali_tts_qwen_realtime(self): + """Test Ali TTS with Qwen Realtime model.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + model = service._get_tts_model_from_config( + api_key="test_key", + model="qwen-tts" + ) + assert model is not None + finally: + for p in reversed(patches): + p.stop() + + def test_with_speed_ratio(self): + """Test TTS model with custom speed ratio.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + model = service._get_tts_model_from_config( + api_key="test_key", + speed_ratio=1.5 + ) + assert model is not None + finally: + for p in reversed(patches): + p.stop() + + class TestCheckSTTConnectivity: """Tests for check_stt_connectivity.""" @pytest.mark.asyncio async def test_success(self): _reset_singleton() - patches, _ = _mock_all_models(stt_success=True) + patches, _, _ = _mock_all_models(stt_success=True) for p in patches: p.start() try: @@ -340,7 +695,7 @@ async def test_success(self): @pytest.mark.asyncio async def test_failure_raises(self): _reset_singleton() - patches, _ = _mock_all_models(stt_success=False) + patches, _, _ = _mock_all_models(stt_success=False) for p in patches: p.start() try: @@ -354,7 +709,7 @@ async def test_failure_raises(self): @pytest.mark.asyncio async def test_volc_model(self): _reset_singleton() - patches, _ = _mock_all_models(stt_success=True) + patches, _, _ = _mock_all_models(stt_success=True) for p in patches: p.start() try: @@ -370,13 +725,65 @@ async def test_volc_model(self): p.stop() +class TestCheckTTSConnectivity: + """Tests for check_tts_connectivity.""" + + @pytest.mark.asyncio + async def test_success(self): + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + result = await service.check_tts_connectivity( + api_key="test_key", + model="cosyvoice-v2" + ) + assert result is True + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_failure_raises(self): + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=False) + for p in patches: + p.start() + try: + service = VoiceService() + with pytest.raises(TTSConnectionException): + await service.check_tts_connectivity(api_key="test_key") + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_with_speed_ratio(self): + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + result = await service.check_tts_connectivity( + api_key="test_key", + speed_ratio=1.5 + ) + assert result is True + finally: + for p in reversed(patches): + p.stop() + + class TestStartSTTStreamingSessionWithConfig: """Tests for start_stt_streaming_session with various config scenarios.""" @pytest.mark.asyncio async def test_with_explicit_config(self): _reset_singleton() - patches, _ = _mock_all_models() + patches, _, _ = _mock_all_models() for p in patches: p.start() try: @@ -395,7 +802,7 @@ async def test_with_explicit_config(self): @pytest.mark.asyncio async def test_with_ali_config(self): _reset_singleton() - patches, _ = _mock_all_models() + patches, _, _ = _mock_all_models() for p in patches: p.start() try: @@ -413,7 +820,7 @@ async def test_with_ali_config(self): @pytest.mark.asyncio async def test_with_language_override(self): _reset_singleton() - patches, _ = _mock_all_models() + patches, _, _ = _mock_all_models() for p in patches: p.start() try: @@ -429,5 +836,64 @@ async def test_with_language_override(self): p.stop() +class TestGenerateTTSSpeechWithConfig: + """Tests for generate_tts_speech with various config scenarios.""" + + @pytest.mark.asyncio + async def test_with_tts_config(self): + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + tts_config = { + "api_key": "test_key", + "model": "cosyvoice-v2" + } + result = await service.generate_tts_speech( + "Hello world", + tts_config=tts_config + ) + assert result is not None + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_with_model_override(self): + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + result = await service.generate_tts_speech( + "Hello world", + model_name_override="custom-model" + ) + assert result is not None + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_with_tenant_id(self): + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + result = await service.generate_tts_speech( + "Hello world", + tenant_id="test_tenant" + ) + assert result is not None + finally: + for p in reversed(patches): + p.stop() + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/backend/services/test_voice_service_tts.py b/test/backend/services/test_voice_service_tts.py new file mode 100644 index 000000000..fcacd4255 --- /dev/null +++ b/test/backend/services/test_voice_service_tts.py @@ -0,0 +1,682 @@ +""" +Unit tests for VoiceService TTS methods. + +These tests cover: +- _get_tts_model_from_config +- _get_tts_model_from_tenant_config +- generate_tts_speech +- stream_tts_to_websocket +- check_tts_connectivity +""" +import os +import sys +import pytest +from unittest.mock import Mock, AsyncMock, patch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend")) + +from consts.exceptions import ( + VoiceServiceException, + TTSConnectionException, +) + + +class MockSTTModel: + """Mock STT model.""" + + def __init__(self, config=None, test_path=None): + self.config = config + self.test_path = test_path + self.check_connectivity = AsyncMock(return_value=True) + self.start_streaming_session = AsyncMock() + + +class MockTTSModel: + """Mock TTS model mimicking the real SDK interface.""" + + def __init__(self, config=None): + self.config = config + self.check_connectivity = AsyncMock(return_value=True) + + async def generate_speech(self, text: str, stream: bool = False): + if stream: + async def gen(): + yield b"chunk_1" + yield b"chunk_2" + yield b"chunk_3" + return gen() + return b"complete_audio_data" + + +_shared_stt = None +_shared_tts = None + + +def _reset_singleton(): + """Reset the voice service singleton between tests.""" + import services.voice_service + services.voice_service._voice_service_instance = None + + +def _mock_all_models(stt_success=True, tts_success=True, stt_exc=None, tts_exc=None): + """ + Patch SDK model classes so every instantiation returns the shared mock instance. + Returns (patches, mock_stt, mock_tts). + """ + global _shared_stt, _shared_tts + _shared_stt = MockSTTModel() + _shared_tts = MockTTSModel() + + _shared_stt.check_connectivity = AsyncMock(return_value=stt_success) + _shared_tts.check_connectivity = AsyncMock(return_value=tts_success) + + if stt_exc: + _shared_stt.check_connectivity = AsyncMock(side_effect=stt_exc) + _shared_stt.start_streaming_session = AsyncMock(side_effect=stt_exc) + if tts_exc: + _shared_tts.check_connectivity = AsyncMock(side_effect=tts_exc) + _shared_tts.generate_speech = AsyncMock(side_effect=tts_exc) + + patches = [ + patch("services.voice_service.VolcSTTModel", return_value=_shared_stt), + patch("services.voice_service.AliSTTModel", return_value=_shared_stt), + patch("services.voice_service.VolcTTSModel", return_value=_shared_tts), + patch("services.voice_service.AliTTSModel", return_value=_shared_tts), + ] + return patches, _shared_stt, _shared_tts + + +import services.voice_service +from services.voice_service import VoiceService + + +# --------------------------------------------------------------------------- +# Tests: _get_tts_model_from_config +# --------------------------------------------------------------------------- + +class TestGetTTSModelFromConfig: + """Tests for _get_tts_model_from_config.""" + + def test_volc_model_selection_with_volc_factory(self): + """Test that Volc TTS model is selected when model_factory is 'volc'.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + model = service._get_tts_model_from_config( + model_factory="volc", + model_appid="test_appid", + access_token="test_token", + speed_ratio=1.0 + ) + assert model is not None + finally: + for p in reversed(patches): + p.stop() + + def test_volc_model_selection_with_volcano_factory(self): + """Test that Volc TTS model is selected when model_factory is 'volcano'.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + model = service._get_tts_model_from_config( + model_factory="volcano", + model_appid="test_appid", + access_token="test_token" + ) + assert model is not None + finally: + for p in reversed(patches): + p.stop() + + def test_volc_model_selection_from_base_url(self): + """Test that Volc TTS model is auto-detected from base_url containing openspeech.bytedance.com.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + model = service._get_tts_model_from_config( + base_url="wss://openspeech.bytedance.com/api/v1/tts" + ) + assert model is not None + finally: + for p in reversed(patches): + p.stop() + + def test_ali_tts_model_default_settings(self): + """Test that Ali TTS model is used by default when no factory specified.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + model = service._get_tts_model_from_config() + assert model is not None + finally: + for p in reversed(patches): + p.stop() + + def test_ali_tts_model_with_api_key_and_model(self): + """Test that Ali TTS model is selected with explicit api_key and model parameters.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + model = service._get_tts_model_from_config( + api_key="test_api_key", + model="qwen3-tts-flash", + speed_ratio=1.2 + ) + assert model is not None + finally: + for p in reversed(patches): + p.stop() + + def test_volc_tts_model_with_custom_base_url(self): + """Test Volc TTS model with custom base_url.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + model = service._get_tts_model_from_config( + model_factory="volc", + model_appid="test_appid", + access_token="test_token", + base_url="wss://custom.volc.com/api/tts" + ) + assert model is not None + finally: + for p in reversed(patches): + p.stop() + + +# --------------------------------------------------------------------------- +# Tests: _get_tts_model_from_tenant_config +# --------------------------------------------------------------------------- + +class TestGetTTSModelFromTenantConfig: + """Tests for _get_tts_model_from_tenant_config.""" + + def test_with_tenant_config_available(self): + """Test _get_tts_model_from_tenant_config when tenant config exists.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + + mock_tts_config = { + "model_factory": "volc", + "api_key": "test_api_key", + "model_appid": "test_appid", + "access_token": "test_token", + "speed_ratio": 1.5, + "base_url": "wss://custom.url", + "model_name": "test_model" + } + + with patch('services.voice_service.tenant_config_manager') as mock_config_mgr: + mock_config_mgr.get_model_config.return_value = mock_tts_config + + with patch.object(service, '_get_tts_model_from_config') as mock_get_model: + mock_get_model.return_value = MockTTSModel() + result = service._get_tts_model_from_tenant_config("test_tenant_id") + assert result is not None + finally: + for p in reversed(patches): + p.stop() + + def test_fallback_to_database_records(self): + """Test _get_tts_model_from_tenant_config falls back to database records.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + + mock_record = { + "model_factory": "dashscope", + "api_key": "test_api_key", + "model_name": "qwen3-tts-flash", + "speed_ratio": 1.0 + } + + with patch('services.voice_service.tenant_config_manager') as mock_config_mgr, \ + patch('services.voice_service.get_model_records') as mock_get_records: + mock_config_mgr.get_model_config.return_value = None + mock_get_records.return_value = [mock_record] + + with patch.object(service, '_get_tts_model_from_config') as mock_get_model: + mock_get_model.return_value = MockTTSModel() + result = service._get_tts_model_from_tenant_config("test_tenant_id") + assert result is not None + finally: + for p in reversed(patches): + p.stop() + + def test_default_config_when_nothing_available(self): + """Test _get_tts_model_from_tenant_config uses default when no config or records exist.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + + with patch('services.voice_service.tenant_config_manager') as mock_config_mgr, \ + patch('services.voice_service.get_model_records') as mock_get_records: + mock_config_mgr.get_model_config.return_value = None + mock_get_records.return_value = [] + + with patch.object(service, '_get_tts_model_from_config') as mock_get_model: + mock_get_model.return_value = MockTTSModel() + result = service._get_tts_model_from_tenant_config("test_tenant_id") + assert result is not None + finally: + for p in reversed(patches): + p.stop() + + def test_exception_handling(self): + """Test _get_tts_model_from_tenant_config handles exceptions gracefully.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + + with patch('services.voice_service.tenant_config_manager') as mock_config_mgr: + mock_config_mgr.get_model_config.side_effect = Exception("Database error") + + with patch.object(service, '_get_tts_model_from_config') as mock_get_model: + mock_get_model.return_value = MockTTSModel() + result = service._get_tts_model_from_tenant_config("test_tenant_id") + assert result is not None + finally: + for p in reversed(patches): + p.stop() + + +# --------------------------------------------------------------------------- +# Tests: generate_tts_speech +# --------------------------------------------------------------------------- + +class TestGenerateTTSSpeech: + """Tests for generate_tts_speech.""" + + @pytest.mark.asyncio + async def test_with_explicit_tts_config_volc(self): + """Test generate_tts_speech with explicit Volcano TTS config.""" + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + tts_config = { + "model_factory": "volc", + "model_appid": "test_appid", + "access_token": "test_token", + "speed_ratio": 1.0 + } + result = await service.generate_tts_speech( + "Hello world", + stream=True, + tts_config=tts_config + ) + chunks = [] + async for chunk in result: + chunks.append(chunk) + assert chunks == [b"chunk_1", b"chunk_2", b"chunk_3"] + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_with_explicit_tts_config_ali_with_api_key(self): + """Test generate_tts_speech with explicit Ali TTS config containing api_key.""" + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + tts_config = { + "api_key": "test_api_key", + "model": "qwen3-tts-flash", + "speed_ratio": 1.2 + } + result = await service.generate_tts_speech( + "Hello world", + stream=True, + tts_config=tts_config + ) + chunks = [] + async for chunk in result: + chunks.append(chunk) + assert chunks == [b"chunk_1", b"chunk_2", b"chunk_3"] + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_with_tenant_id(self): + """Test generate_tts_speech with tenant_id to pull model from tenant config.""" + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + result = await service.generate_tts_speech( + "Hello world", + stream=False, + tenant_id="test_tenant_id" + ) + assert result == b"complete_audio_data" + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_empty_text_raises_voice_service_exception(self): + """Test generate_tts_speech raises VoiceServiceException for empty text.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + with pytest.raises(VoiceServiceException, match="No text provided for TTS generation"): + await service.generate_tts_speech("") + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_none_text_raises_voice_service_exception(self): + """Test generate_tts_speech raises VoiceServiceException when text is None.""" + _reset_singleton() + patches, _, _ = _mock_all_models() + for p in patches: + p.start() + try: + service = VoiceService() + with pytest.raises(VoiceServiceException, match="No text provided for TTS generation"): + await service.generate_tts_speech(None) + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_tts_connection_error_raises_tts_connection_exception(self): + """Test generate_tts_speech raises TTSConnectionException on connection failure.""" + _reset_singleton() + exc = TTSConnectionException("TTS connection failed") + patches, _, _ = _mock_all_models(tts_exc=exc) + for p in patches: + p.start() + try: + service = VoiceService() + with pytest.raises(TTSConnectionException, match="TTS connection failed"): + await service.generate_tts_speech("Hello world") + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_general_error_raises_tts_connection_exception(self): + """Test generate_tts_speech wraps general errors in TTSConnectionException.""" + _reset_singleton() + exc = RuntimeError("unexpected error") + patches, _, _ = _mock_all_models(tts_exc=exc) + for p in patches: + p.start() + try: + service = VoiceService() + with pytest.raises(TTSConnectionException, match="unexpected error"): + await service.generate_tts_speech("Hello world") + finally: + for p in reversed(patches): + p.stop() + + +# --------------------------------------------------------------------------- +# Tests: stream_tts_to_websocket +# --------------------------------------------------------------------------- + +class TestStreamTTSToWebSocket: + """Tests for stream_tts_to_websocket.""" + + def _connected_ws(self): + ws = Mock() + ws.send_bytes = AsyncMock() + ws.send_json = AsyncMock() + state = Mock() + state.name = "CONNECTED" + ws.client_state = state + return ws + + @pytest.mark.asyncio + async def test_success_with_async_iterator(self): + """Test stream_tts_to_websocket correctly handles async iterator from TTS model.""" + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + mock_ws = self._connected_ws() + await service.stream_tts_to_websocket(mock_ws, "Hello world") + assert mock_ws.send_bytes.call_count == 3 + mock_ws.send_json.assert_called_once_with({"status": "completed"}) + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_success_with_sync_iterator(self): + """Test stream_tts_to_websocket handles synchronous iterator from TTS model.""" + + def sync_gen(): + for chunk in [b"sync_1", b"sync_2"]: + yield chunk + + _reset_singleton() + global _shared_tts + _shared_tts = MockTTSModel() + + class SyncIterTTSModel(MockTTSModel): + async def generate_speech(self, text: str, stream: bool = False): + if stream: + return sync_gen() + return b"sync_complete" + + _shared_tts = SyncIterTTSModel() + patches = [ + patch("services.voice_service.VolcTTSModel", return_value=_shared_tts), + patch("services.voice_service.AliTTSModel", return_value=_shared_tts), + patch("services.voice_service.VolcSTTModel", return_value=MockSTTModel()), + patch("services.voice_service.AliSTTModel", return_value=MockSTTModel()), + ] + for p in patches: + p.start() + try: + service = VoiceService() + mock_ws = self._connected_ws() + await service.stream_tts_to_websocket( + mock_ws, "Hello world", tts_config={"api_key": "test"} + ) + assert mock_ws.send_bytes.call_count == 2 + mock_ws.send_json.assert_called_once_with({"status": "completed"}) + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_success_with_single_chunk(self): + """Test stream_tts_to_websocket handles single non-iterable chunk.""" + + class SingleChunkTTSModel: + """Minimal mock that returns bytes directly from generate_speech.""" + + def __init__(self): + self.check_connectivity = AsyncMock(return_value=True) + + async def generate_speech(self, text: str, stream: bool = False): + return b"single_audio_chunk" + + _reset_singleton() + _shared_tts = SingleChunkTTSModel() + patches = [ + patch("services.voice_service.VolcTTSModel", return_value=_shared_tts), + patch("services.voice_service.AliTTSModel", return_value=_shared_tts), + patch("services.voice_service.VolcSTTModel", return_value=MockSTTModel()), + patch("services.voice_service.AliSTTModel", return_value=MockSTTModel()), + ] + for p in patches: + p.start() + try: + service = VoiceService() + mock_ws = self._connected_ws() + await service.stream_tts_to_websocket( + mock_ws, "Hello world", tts_config={"api_key": "test"} + ) + mock_ws.send_json.assert_called_once_with({"status": "completed"}) + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_connection_error_propagates(self): + """Test stream_tts_to_websocket propagates TTSConnectionException.""" + _reset_singleton() + exc = TTSConnectionException("TTS connection failed") + patches, _, _ = _mock_all_models(tts_exc=exc) + for p in patches: + p.start() + try: + service = VoiceService() + mock_ws = self._connected_ws() + with pytest.raises(TTSConnectionException, match="TTS connection failed"): + await service.stream_tts_to_websocket(mock_ws, "Hello world") + finally: + for p in reversed(patches): + p.stop() + + +# --------------------------------------------------------------------------- +# Tests: check_tts_connectivity +# --------------------------------------------------------------------------- + +class TestCheckTTSConnectivity: + """Tests for check_tts_connectivity.""" + + @pytest.mark.asyncio + async def test_success_returns_true(self): + """Test check_tts_connectivity returns True on successful connection.""" + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + result = await service.check_tts_connectivity( + api_key="test_key", + model="qwen3-tts-flash" + ) + assert result is True + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_failure_returns_false(self): + """Test check_tts_connectivity returns False when connectivity check fails.""" + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=False) + for p in patches: + p.start() + try: + service = VoiceService() + result = await service.check_tts_connectivity( + api_key="test_key", + model="qwen3-tts-flash" + ) + assert result is False + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_exception_returns_false(self): + """Test check_tts_connectivity returns False when an exception occurs.""" + _reset_singleton() + exc = RuntimeError("connection timeout") + patches, _, _ = _mock_all_models(tts_exc=exc) + for p in patches: + p.start() + try: + service = VoiceService() + result = await service.check_tts_connectivity( + api_key="test_key" + ) + assert result is False + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_volc_factory_success(self): + """Test check_tts_connectivity with Volcano TTS factory.""" + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + result = await service.check_tts_connectivity( + model_factory="volc", + model_appid="test_appid", + access_token="test_token" + ) + assert result is True + finally: + for p in reversed(patches): + p.stop() + + @pytest.mark.asyncio + async def test_with_speed_ratio(self): + """Test check_tts_connectivity with custom speed_ratio.""" + _reset_singleton() + patches, _, _ = _mock_all_models(tts_success=True) + for p in patches: + p.start() + try: + service = VoiceService() + result = await service.check_tts_connectivity( + api_key="test_key", + speed_ratio=1.5 + ) + assert result is True + finally: + for p in reversed(patches): + p.stop() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/conftest.py b/test/conftest.py index 4ab19b5d7..0f116282a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -7,8 +7,8 @@ import sys from unittest.mock import MagicMock -# Stub out mem0 modules before anything else imports them. -# The sdk imports mem0 at module level, so stubs must be registered first. +# Stub out mem0 and smolagents modules before anything else imports them. +# The sdk imports these at module level, so stubs must be registered first. _mem0_stubs = { "mem0": MagicMock(), "mem0.memory": MagicMock(), @@ -19,9 +19,15 @@ "mem0.configs.embeddings": MagicMock(), "mem0.configs.embeddings.base": MagicMock(), } -for _mod_name in _mem0_stubs: +_smolagents_stubs = { + "smolagents": MagicMock(), + "smolagents.memory": MagicMock(), + "smolagents.models": MagicMock(), +} +_all_stubs = {**_mem0_stubs, **_smolagents_stubs} +for _mod_name in _all_stubs: if _mod_name not in sys.modules: - sys.modules[_mod_name] = _mem0_stubs[_mod_name] + sys.modules[_mod_name] = _all_stubs[_mod_name] # Add backend and sdk directories to sys.path so that modules can be imported # as `from backend.xxx import ...` and `from sdk.xxx import ...` diff --git a/test/sdk/core/models/test_ali_tts_model.py b/test/sdk/core/models/test_ali_tts_model.py new file mode 100644 index 000000000..4b95a11ca --- /dev/null +++ b/test/sdk/core/models/test_ali_tts_model.py @@ -0,0 +1,1401 @@ +""" +Unit tests for Ali TTS model. + +Tests the AliTTSModel and AliTTSConfig classes. +""" +import pytest +import asyncio +import base64 +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import sys as _sys + +_mock_websockets = MagicMock() +_mock_websockets.connect = MagicMock() + + +class _MockConnectionClosedError(Exception): + def __init__(self, code, reason): + self.code = code + self.reason = reason + super().__init__(reason) + + +_mock_websockets.exceptions.ConnectionClosedError = _MockConnectionClosedError +_mock_websockets.exceptions.WebSocketException = Exception +_mock_websockets.exceptions.ConnectionClosed = _MockConnectionClosedError + +_mock_aiofiles = MagicMock() + + +class _MockAsyncContextManager: + def __init__(self, mock_file): + self.mock_file = mock_file + + async def __aenter__(self): + return self.mock_file + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + +def _mock_aiofiles_open(*args, **kwargs): + mock_file = AsyncMock() + mock_file.read = AsyncMock(return_value=b"mock_audio_data") + return _MockAsyncContextManager(mock_file) + + +_mock_aiofiles.open = _mock_aiofiles_open + +_module_mocks = { + "websockets": _mock_websockets, + "aiofiles": _mock_aiofiles, +} + +with patch.dict(_sys.modules, _module_mocks): + from sdk.nexent.core.models.ali_tts_model import ( + AliTTSModel, + AliTTSConfig, + AliTTSError, + DEFAULT_WS_OPEN_TIMEOUT, + DEFAULT_WS_CLOSE_TIMEOUT, + COSYVOICE_API_URL, + QWEN_REALTIME_API_URL, + ) + + +# ============================================================================ +# AliTTSConfig Tests +# ============================================================================ + +class TestAliTTSConfig: + """Tests for AliTTSConfig.""" + + def test_config_init_default_values(self): + """Test config initialization with default values.""" + config = AliTTSConfig(api_key="test_key") + assert config.api_key == "test_key" + assert config.model == "cosyvoice-v2" + assert config.voice is None + assert config.speech_rate == 1.0 + assert config.pitch_rate == 1.0 + assert config.volume == 50.0 + assert config.ws_url is None + assert config.format == "mp3" + assert config.sample_rate == 16000 + assert config.workspace_id is None + + def test_config_init_custom_values(self): + """Test config initialization with custom values.""" + config = AliTTSConfig( + api_key="custom_key", + model="qwen-tts", + voice="azure_stefanie", + speech_rate=1.5, + pitch_rate=0.9, + volume=75.0, + ws_url="wss://custom.url/ws", + format="pcm", + sample_rate=24000, + workspace_id="ws_123", + ) + assert config.api_key == "custom_key" + assert config.model == "qwen-tts" + assert config.voice == "azure_stefanie" + assert config.speech_rate == 1.5 + assert config.pitch_rate == 0.9 + assert config.volume == 75.0 + assert config.ws_url == "wss://custom.url/ws" + assert config.format == "pcm" + assert config.sample_rate == 24000 + assert config.workspace_id == "ws_123" + + def test_is_realtime_api_true_when_realtime_in_url(self): + """Test is_realtime_api returns True for /realtime in URL.""" + config = AliTTSConfig(api_key="key", ws_url="wss://dashscope.aliyuncs.com/api-ws/v1/realtime") + assert config.is_realtime_api() is True + + def test_is_realtime_api_false_when_no_realtime(self): + """Test is_realtime_api returns False when URL is CosyVoice.""" + config = AliTTSConfig(api_key="key", ws_url="wss://dashscope.aliyuncs.com/api-ws/v1/inference") + assert config.is_realtime_api() is False + + def test_is_realtime_api_false_when_no_ws_url(self): + """Test is_realtime_api returns False when ws_url is None.""" + config = AliTTSConfig(api_key="key") + assert config.is_realtime_api() is False + + def test_is_realtime_api_false_when_empty_ws_url(self): + """Test is_realtime_api returns False when ws_url is empty.""" + config = AliTTSConfig(api_key="key", ws_url="") + assert config.is_realtime_api() is False + + def test_get_api_url_with_explicit_ws_url(self): + """Test get_api_url returns explicit ws_url when set.""" + config = AliTTSConfig(api_key="key", ws_url="wss://custom.url/api") + assert config.get_api_url() == "wss://custom.url/api" + + def test_get_api_url_returns_qwen_when_in_model_name(self): + """Test get_api_url returns Qwen URL when qwen in model name.""" + config = AliTTSConfig(api_key="key", model="qwen-tts-v1") + assert config.get_api_url() == QWEN_REALTIME_API_URL + + def test_get_api_url_returns_qwen_when_realtime_flag(self): + """Test get_api_url returns custom URL when ws_url is explicitly set.""" + config = AliTTSConfig(api_key="key", ws_url="wss://example.com/realtime") + assert config.get_api_url() == "wss://example.com/realtime" + + def test_get_api_url_returns_cosyvoice_default(self): + """Test get_api_url returns CosyVoice URL as default.""" + config = AliTTSConfig(api_key="key", model="cosyvoice-v2") + assert config.get_api_url() == COSYVOICE_API_URL + + def test_get_api_url_returns_cosyvoice_for_other_models(self): + """Test get_api_url returns CosyVoice URL for non-qwen models.""" + config = AliTTSConfig(api_key="key", model="some-other-model") + assert config.get_api_url() == COSYVOICE_API_URL + + +# ============================================================================ +# AliTTSModel Constants Tests +# ============================================================================ + +class TestAliTTSModelConstants: + """Tests for AliTTSModel module constants.""" + + def test_default_ws_open_timeout(self): + """Test DEFAULT_WS_OPEN_TIMEOUT constant.""" + assert DEFAULT_WS_OPEN_TIMEOUT == 60 + + def test_default_ws_close_timeout(self): + """Test DEFAULT_WS_CLOSE_TIMEOUT constant.""" + assert DEFAULT_WS_CLOSE_TIMEOUT == 10 + + def test_cosyvoice_api_url(self): + """Test COSYVOICE_API_URL constant.""" + assert COSYVOICE_API_URL == "wss://dashscope.aliyuncs.com/api-ws/v1/inference" + + def test_qwen_realtime_api_url(self): + """Test QWEN_REALTIME_API_URL constant.""" + assert QWEN_REALTIME_API_URL == "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + + def test_ali_tts_error(self): + """Test AliTTSError exception.""" + err = AliTTSError("Test error message") + assert err.message == "Test error message" + assert str(err) == "Test error message" + + +# ============================================================================ +# AliTTSModel Constructor Tests +# ============================================================================ + +class TestAliTTSModelConstructor: + """Tests for AliTTSModel constructor and initialization.""" + + def test_model_init_cosyvoice(self): + """Test model initialization with CosyVoice model.""" + config = AliTTSConfig(api_key="test_key", model="cosyvoice-v2") + model = AliTTSModel(config) + assert model.config is config + assert model._is_realtime is False + + def test_model_init_qwen(self): + """Test model initialization with Qwen model.""" + config = AliTTSConfig(api_key="test_key", model="qwen-tts-v1") + model = AliTTSModel(config) + assert model._is_realtime is True + + def test_model_init_with_realtime_url(self): + """Test model initialization with realtime URL.""" + config = AliTTSConfig(api_key="test_key", ws_url="wss://example.com/realtime") + model = AliTTSModel(config) + assert model._is_realtime is True + + def test_model_init_with_audio_file_path(self): + """Test model initialization with audio file path.""" + config = AliTTSConfig(api_key="test_key") + model = AliTTSModel(config, audio_file_path="/path/to/audio.mp3") + assert model.audio_file_path == "/path/to/audio.mp3" + + +# ============================================================================ +# AliTTSModel URL and Auth Tests +# ============================================================================ + +class TestAliTTSModelUrlAndAuth: + """Tests for get_websocket_url and get_auth_headers methods.""" + + def test_get_websocket_url_cosyvoice(self): + """Test get_websocket_url returns base URL for CosyVoice.""" + config = AliTTSConfig(api_key="key", model="cosyvoice-v2") + model = AliTTSModel(config) + assert model.get_websocket_url() == COSYVOICE_API_URL + + def test_get_websocket_url_qwen_with_model_param(self): + """Test get_websocket_url appends model param for Qwen.""" + config = AliTTSConfig(api_key="key", model="qwen-tts-v1") + model = AliTTSModel(config) + url = model.get_websocket_url() + assert url.startswith(QWEN_REALTIME_API_URL) + assert "model=qwen-tts-v1" in url + + def test_get_websocket_url_with_explicit_ws_url_no_question_mark(self): + """Test get_websocket_url uses ? when no query in explicit URL.""" + config = AliTTSConfig(api_key="key", ws_url="wss://example.com/realtime") + model = AliTTSModel(config) + url = model.get_websocket_url() + assert "?" in url + assert "model=" in url + + def test_get_websocket_url_with_explicit_ws_url_with_question_mark(self): + """Test get_websocket_url uses & when query already in explicit URL.""" + config = AliTTSConfig(api_key="key", ws_url="wss://example.com/realtime?existing=param") + model = AliTTSModel(config) + url = model.get_websocket_url() + assert "&model=" in url + + def test_get_auth_headers(self): + """Test get_auth_headers returns Bearer token.""" + config = AliTTSConfig(api_key="my_secret_key") + model = AliTTSModel(config) + headers = model.get_auth_headers() + assert "Authorization" in headers + assert headers["Authorization"] == "Bearer my_secret_key" + + +# ============================================================================ +# AliTTSModel CosyVoice Request Construction Tests +# ============================================================================ + +class TestAliTTSModelCosyVoiceRequestConstruction: + """Tests for CosyVoice request construction methods.""" + + def test_cosyvoice_generate_task_id(self): + """Test _cosyvoice_generate_task_id generates valid UUID.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + task_id = model._cosyvoice_generate_task_id() + assert isinstance(task_id, str) + assert len(task_id) == 32 + + def test_cosyvoice_generate_task_id_unique(self): + """Test _cosyvoice_generate_task_id generates unique IDs.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + ids = [model._cosyvoice_generate_task_id() for _ in range(10)] + assert len(set(ids)) == 10 + + def test_cosyvoice_construct_run_task_request(self): + """Test _cosyvoice_construct_run_task_request structure.""" + config = AliTTSConfig( + api_key="key", + model="cosyvoice-v2", + voice="af_abella", + format="mp3", + sample_rate=16000, + volume=60.0, + speech_rate=1.2, + pitch_rate=0.9, + ) + model = AliTTSModel(config) + task_id = "test_task_123" + request = model._cosyvoice_construct_run_task_request(task_id) + + assert request["header"]["action"] == "run-task" + assert request["header"]["task_id"] == task_id + assert request["header"]["streaming"] == "duplex" + assert request["payload"]["task_group"] == "audio" + assert request["payload"]["task"] == "tts" + assert request["payload"]["function"] == "SpeechSynthesizer" + assert request["payload"]["model"] == "cosyvoice-v2" + assert request["payload"]["parameters"]["text_type"] == "PlainText" + assert request["payload"]["parameters"]["voice"] == "af_abella" + assert request["payload"]["parameters"]["format"] == "mp3" + assert request["payload"]["parameters"]["sample_rate"] == 16000 + assert request["payload"]["parameters"]["volume"] == 60 + assert request["payload"]["parameters"]["rate"] == 1.2 + assert request["payload"]["parameters"]["pitch"] == 0.9 + assert request["payload"]["parameters"]["enable_ssml"] is False + + def test_cosyvoice_construct_continue_request(self): + """Test _cosyvoice_construct_continue_request structure.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + task_id = "task_456" + text = "Hello world" + request = model._cosyvoice_construct_continue_request(task_id, text) + + assert request["header"]["action"] == "continue-task" + assert request["header"]["task_id"] == task_id + assert request["header"]["streaming"] == "duplex" + assert request["payload"]["input"]["text"] == text + + def test_cosyvoice_construct_finish_request(self): + """Test _cosyvoice_construct_finish_request structure.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + task_id = "task_789" + request = model._cosyvoice_construct_finish_request(task_id) + + assert request["header"]["action"] == "finish-task" + assert request["header"]["task_id"] == task_id + assert request["header"]["streaming"] == "duplex" + assert request["payload"]["input"] == {} + + +# ============================================================================ +# AliTTSModel CosyVoice Event Parsing Tests +# ============================================================================ + +class TestAliTTSModelCosyVoiceEventParsing: + """Tests for _cosyvoice_parse_event method.""" + + def test_parse_task_started_event(self): + """Test parsing task-started event.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + message = json.dumps({"header": {"event": "task-started", "task_id": "task_123"}}) + result = model._cosyvoice_parse_event(message) + assert result["type"] == "task-started" + assert result["task_id"] == "task_123" + + def test_parse_task_failed_event(self): + """Test parsing task-failed event.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + message = json.dumps({ + "header": {"event": "task-failed", "task_id": "task_123", "error_code": 500, "error_message": "Service error"} + }) + result = model._cosyvoice_parse_event(message) + assert result["type"] == "task-failed" + assert result["task_id"] == "task_123" + assert result["error_code"] == 500 + assert result["error_message"] == "Service error" + + def test_parse_task_finished_event(self): + """Test parsing task-finished event with usage info.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + message = json.dumps({ + "header": {"event": "task-finished", "task_id": "task_456"}, + "payload": {"usage": {"characters": 100}} + }) + result = model._cosyvoice_parse_event(message) + assert result["type"] == "task-finished" + assert result["task_id"] == "task_456" + assert result["characters"] == 100 + + def test_parse_unknown_event(self): + """Test parsing unknown event type.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + message = json.dumps({"header": {"event": "some-unknown-event", "task_id": "task_789"}}) + result = model._cosyvoice_parse_event(message) + assert result["type"] == "some-unknown-event" + + def test_parse_invalid_json(self): + """Test parsing invalid JSON returns unknown type.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + result = model._cosyvoice_parse_event("not valid json {{{") + assert result["type"] == "unknown" + + def test_parse_event_missing_header(self): + """Test parsing event without header.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + message = json.dumps({"payload": {"data": "value"}}) + result = model._cosyvoice_parse_event(message) + assert result["type"] == "" + + +# ============================================================================ +# AliTTSModel Qwen Request Construction Tests +# ============================================================================ + +class TestAliTTSModelQwenRequestConstruction: + """Tests for Qwen Realtime API request construction methods.""" + + def test_qwen_generate_event_id(self): + """Test _qwen_generate_event_id generates valid event ID.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + event_id = model._qwen_generate_event_id() + assert isinstance(event_id, str) + assert event_id.startswith("event_") + assert len(event_id) == 22 # "event_" + 16 hex chars + + def test_qwen_generate_event_id_unique(self): + """Test _qwen_generate_event_id generates unique IDs.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + ids = [model._qwen_generate_event_id() for _ in range(10)] + assert len(set(ids)) == 10 + + def test_qwen_construct_session_update(self): + """Test _qwen_construct_session_update structure.""" + config = AliTTSConfig( + api_key="key", + voice="Cherry", + format="mp3", + sample_rate=24000, + speech_rate=1.5, + volume=80.0, + ) + model = AliTTSModel(config) + request = model._qwen_construct_session_update() + + assert request["type"] == "session.update" + assert "event_id" in request + assert request["session"]["voice"] == "Cherry" + assert request["session"]["mode"] == "server_commit" + assert request["session"]["language_type"] == "Auto" + assert request["session"]["response_format"] == "mp3" + assert request["session"]["sample_rate"] == 24000 + assert request["session"]["speech_rate"] == 1.5 + assert request["session"]["volume"] == 80 + + def test_qwen_construct_session_update_uses_default_voice(self): + """Test _qwen_construct_session_update uses Cherry when voice is None.""" + config = AliTTSConfig(api_key="key", voice=None) + model = AliTTSModel(config) + request = model._qwen_construct_session_update() + assert request["session"]["voice"] == "Cherry" + + def test_qwen_format_to_response_format_mp3(self): + """Test _qwen_format_to_response_format for mp3.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + assert model._qwen_format_to_response_format("mp3") == "mp3" + + def test_qwen_format_to_response_format_pcm(self): + """Test _qwen_format_to_response_format for pcm.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + assert model._qwen_format_to_response_format("pcm") == "pcm" + + def test_qwen_format_to_response_format_wav(self): + """Test _qwen_format_to_response_format for wav.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + assert model._qwen_format_to_response_format("wav") == "wav" + + def test_qwen_format_to_response_format_opus(self): + """Test _qwen_format_to_response_format for opus.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + assert model._qwen_format_to_response_format("opus") == "opus" + + def test_qwen_format_to_response_format_unknown(self): + """Test _qwen_format_to_response_format for unknown format defaults to pcm.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + assert model._qwen_format_to_response_format("flac") == "pcm" + + def test_qwen_construct_text_append(self): + """Test _qwen_construct_text_append structure.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + request = model._qwen_construct_text_append("Hello world") + assert request["type"] == "input_text_buffer.append" + assert "event_id" in request + assert request["text"] == "Hello world" + + def test_qwen_construct_text_commit(self): + """Test _qwen_construct_text_commit structure.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + request = model._qwen_construct_text_commit() + assert request["type"] == "input_text_buffer.commit" + assert "event_id" in request + + def test_qwen_construct_session_finish(self): + """Test _qwen_construct_session_finish structure.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + request = model._qwen_construct_session_finish() + assert request["type"] == "session.finish" + assert "event_id" in request + + +# ============================================================================ +# AliTTSModel Qwen Event Parsing Tests +# ============================================================================ + +class TestAliTTSModelQwenEventParsing: + """Tests for Qwen event parsing methods.""" + + def test_qwen_parse_event_session_created(self): + """Test parsing session.created event.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + message = json.dumps({"type": "session.created", "session_id": "sess_123"}) + result = model._qwen_parse_event(message) + assert result["type"] == "session.created" + assert result["raw"]["session_id"] == "sess_123" + + def test_qwen_parse_event_error(self): + """Test parsing error event.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + message = json.dumps({ + "type": "error", + "error": {"code": "INVALID_PARAM", "message": "Invalid parameter"} + }) + result = model._qwen_parse_event(message) + assert result["type"] == "error" + assert result["error_code"] == "INVALID_PARAM" + assert result["error_message"] == "Invalid parameter" + + def test_qwen_parse_event_response_created(self): + """Test parsing response.created event.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + message = json.dumps({"type": "response.created", "response": {"id": "resp_123"}}) + result = model._qwen_parse_event(message) + assert result["type"] == "response.created" + + def test_qwen_parse_event_response_audio_delta(self): + """Test parsing response.audio.delta event.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + audio_data = base64.b64encode(b"audio_chunk").decode() + message = json.dumps({"type": "response.audio.delta", "delta": audio_data}) + result = model._qwen_parse_event(message) + assert result["type"] == "response.audio.delta" + assert result["raw"]["delta"] == audio_data + + def test_qwen_parse_event_response_audio_done(self): + """Test parsing response.audio.done event.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + message = json.dumps({"type": "response.audio.done"}) + result = model._qwen_parse_event(message) + assert result["type"] == "response.audio.done" + + def test_qwen_parse_event_session_finished(self): + """Test parsing session.finished event.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + message = json.dumps({"type": "session.finished"}) + result = model._qwen_parse_event(message) + assert result["type"] == "session.finished" + + def test_qwen_parse_event_invalid_json(self): + """Test parsing invalid JSON returns unknown type.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + result = model._qwen_parse_event("not json {{{") + assert result["type"] == "unknown" + + def test_qwen_is_terminal_event_response_audio_done(self): + """Test _qwen_is_terminal_event returns True for response.audio.done.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + assert model._qwen_is_terminal_event("response.audio.done") is True + + def test_qwen_is_terminal_event_session_finished(self): + """Test _qwen_is_terminal_event returns True for session.finished.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + assert model._qwen_is_terminal_event("session.finished") is True + + def test_qwen_is_terminal_event_false_for_others(self): + """Test _qwen_is_terminal_event returns False for non-terminal events.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + assert model._qwen_is_terminal_event("session.created") is False + assert model._qwen_is_terminal_event("response.created") is False + assert model._qwen_is_terminal_event("response.audio.delta") is False + + def test_qwen_handle_audio_delta(self): + """Test _qwen_handle_audio_delta decodes base64 audio.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + audio_data = base64.b64encode(b"test_audio_chunk").decode() + event = {"raw": {"delta": audio_data}} + buffer = bytearray() + result = model._qwen_handle_audio_delta(event, buffer, yield_chunks=True) + assert result == b"test_audio_chunk" + assert buffer == bytearray(b"test_audio_chunk") + + def test_qwen_handle_audio_delta_empty_delta(self): + """Test _qwen_handle_audio_delta with empty delta.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + event = {"raw": {"delta": ""}} + buffer = bytearray() + result = model._qwen_handle_audio_delta(event, buffer, yield_chunks=True) + assert result is None + + def test_qwen_handle_audio_delta_buffer_only(self): + """Test _qwen_handle_audio_delta appends to buffer without yielding.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + audio_data = base64.b64encode(b"buffer_only").decode() + event = {"raw": {"delta": audio_data}} + buffer = bytearray() + result = model._qwen_handle_audio_delta(event, buffer, yield_chunks=False) + assert result is None + assert buffer == bytearray(b"buffer_only") + + +# ============================================================================ +# AliTTSModel Generate Speech Tests +# ============================================================================ + +class TestAliTTSModelGenerateSpeech: + """Tests for generate_speech method.""" + + def test_generate_speech_returns_generator_for_qwen_streaming(self): + """Test generate_speech returns async generator for Qwen streaming.""" + config = AliTTSConfig(api_key="key", model="qwen-tts") + model = AliTTSModel(config) + result = model.generate_speech("Hello", stream=True) + import inspect + assert inspect.iscoroutine(result) or inspect.isasyncgenfunction(result) + + def test_generate_speech_returns_generator_for_cosyvoice_streaming(self): + """Test generate_speech returns async generator for CosyVoice streaming.""" + config = AliTTSConfig(api_key="key", model="cosyvoice-v2") + model = AliTTSModel(config) + result = model.generate_speech("Hello", stream=True) + import inspect + assert inspect.iscoroutine(result) or inspect.isasyncgenfunction(result) + + +# ============================================================================ +# AliTTSModel CosyVoice Async Generation Tests +# ============================================================================ + +class TestAliTTSModelCosyVoiceAsyncGeneration: + """Tests for CosyVoice async generation methods.""" + + @pytest.mark.asyncio + async def test_cosyvoice_non_streaming_success(self): + """Test CosyVoice non-streaming generation success. + + The buffer only accumulates bytes messages (actual audio data). + JSON messages like task-finished don't get added to the buffer. + """ + config = AliTTSConfig(api_key="key", model="cosyvoice-v2") + model = AliTTSModel(config) + + audio_data = b"fake_audio_data" + task_started_msg = json.dumps({"header": {"event": "task-started", "task_id": "task_1"}}) + task_finished_msg = json.dumps({"header": {"event": "task-finished", "task_id": "task_1", "payload": {"usage": {"characters": 10}}}}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[task_started_msg, audio_data, task_finished_msg]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await model._generate_cosyvoice_non_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}) + assert result == audio_data + + @pytest.mark.asyncio + async def test_cosyvoice_non_streaming_connection_error(self): + """Test CosyVoice non-streaming with connection error.""" + config = AliTTSConfig(api_key="key", model="cosyvoice-v2") + model = AliTTSModel(config) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(side_effect=Exception("Connection failed")) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + with pytest.raises(Exception, match="Connection failed"): + await model._generate_cosyvoice_non_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}) + + @pytest.mark.asyncio + async def test_cosyvoice_non_streaming_task_failed(self): + """Test CosyVoice non-streaming with task failure.""" + config = AliTTSConfig(api_key="key", model="cosyvoice-v2") + model = AliTTSModel(config) + + task_started_msg = json.dumps({"header": {"event": "task-started", "task_id": "task_1"}}) + task_failed_msg = json.dumps({ + "header": {"event": "task-failed", "task_id": "task_1", "error_message": "Task failed"} + }) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[task_started_msg, task_failed_msg]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + with pytest.raises(AliTTSError, match="Task failed"): + await model._generate_cosyvoice_non_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}) + + @pytest.mark.asyncio + async def test_cosyvoice_non_streaming_timeout(self): + """Test CosyVoice non-streaming with timeout after task starts. + + When a timeout occurs during audio receiving, the loop breaks and + returns whatever audio has been accumulated (empty in this case). + """ + config = AliTTSConfig(api_key="key", model="cosyvoice-v2") + model = AliTTSModel(config) + + task_started_msg = json.dumps({"header": {"event": "task-started", "task_id": "task_1"}}) + + call_count = [0] + + async def recv_with_timeout(): + call_count[0] += 1 + if call_count[0] == 1: + return task_started_msg + else: + raise asyncio.TimeoutError + + mock_ws = AsyncMock() + mock_ws.recv = recv_with_timeout + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await model._generate_cosyvoice_non_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}) + assert result == b"" + + @pytest.mark.asyncio + async def test_cosyvoice_streaming_success(self): + """Test CosyVoice streaming generation success. + + Bytes chunks are yielded as audio data. JSON messages don't get yielded. + Audio chunks should come before task-finished for proper streaming. + """ + config = AliTTSConfig(api_key="key", model="cosyvoice-v2") + model = AliTTSModel(config) + + audio_chunks = [b"chunk1", b"chunk2", b"chunk3"] + task_started_msg = json.dumps({"header": {"event": "task-started", "task_id": "task_1"}}) + task_finished_msg = json.dumps({"header": {"event": "task-finished", "task_id": "task_1"}}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[ + task_started_msg, + audio_chunks[0], + audio_chunks[1], + audio_chunks[2], + task_finished_msg, + ]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + chunks = [] + async for chunk in model._generate_cosyvoice_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}): + chunks.append(chunk) + assert chunks == audio_chunks + + +# ============================================================================ +# AliTTSModel Qwen Realtime Async Generation Tests +# ============================================================================ + +class TestAliTTSModelQwenRealtimeAsyncGeneration: + """Tests for Qwen Realtime API async generation methods.""" + + @pytest.mark.asyncio + async def test_qwen_non_streaming_success(self): + """Test Qwen Realtime non-streaming generation success.""" + config = AliTTSConfig(api_key="key", model="qwen-tts") + model = AliTTSModel(config) + + audio_data = base64.b64encode(b"qwen_audio").decode() + session_created_msg = json.dumps({"type": "session.created"}) + response_created_msg = json.dumps({"type": "response.created"}) + audio_delta_msg = json.dumps({"type": "response.audio.delta", "delta": audio_data}) + audio_done_msg = json.dumps({"type": "response.audio.done"}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[ + session_created_msg, + response_created_msg, + audio_delta_msg, + audio_done_msg, + ]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await model._generate_qwen_realtime_non_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}) + assert result == b"qwen_audio" + + @pytest.mark.asyncio + async def test_qwen_non_streaming_session_error(self): + """Test Qwen Realtime non-streaming with session error.""" + config = AliTTSConfig(api_key="key", model="qwen-tts") + model = AliTTSModel(config) + + error_msg = json.dumps({"type": "error", "error": {"message": "Session error"}}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[error_msg]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + with pytest.raises(AliTTSError, match="Session error"): + await model._generate_qwen_realtime_non_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}) + + @pytest.mark.asyncio + async def test_qwen_non_streaming_connection_error(self): + """Test Qwen Realtime non-streaming with connection error.""" + config = AliTTSConfig(api_key="key", model="qwen-tts") + model = AliTTSModel(config) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(side_effect=Exception("Connection failed")) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + with pytest.raises(Exception, match="Connection failed"): + await model._generate_qwen_realtime_non_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}) + + @pytest.mark.asyncio + async def test_qwen_non_streaming_empty_audio(self): + """Test Qwen Realtime non-streaming with no audio data.""" + config = AliTTSConfig(api_key="key", model="qwen-tts") + model = AliTTSModel(config) + + session_created_msg = json.dumps({"type": "session.created"}) + response_created_msg = json.dumps({"type": "response.created"}) + audio_done_msg = json.dumps({"type": "response.audio.done"}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[ + session_created_msg, + response_created_msg, + audio_done_msg, + ]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await model._generate_qwen_realtime_non_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}) + assert result == b"" + + @pytest.mark.asyncio + async def test_qwen_non_streaming_multiple_audio_chunks(self): + """Test Qwen Realtime non-streaming with multiple audio chunks.""" + config = AliTTSConfig(api_key="key", model="qwen-tts") + model = AliTTSModel(config) + + audio1 = base64.b64encode(b"chunk1").decode() + audio2 = base64.b64encode(b"chunk2").decode() + session_created_msg = json.dumps({"type": "session.created"}) + response_created_msg = json.dumps({"type": "response.created"}) + audio_delta1 = json.dumps({"type": "response.audio.delta", "delta": audio1}) + audio_delta2 = json.dumps({"type": "response.audio.delta", "delta": audio2}) + audio_done_msg = json.dumps({"type": "response.audio.done"}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[ + session_created_msg, + response_created_msg, + audio_delta1, + audio_delta2, + audio_done_msg, + ]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await model._generate_qwen_realtime_non_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}) + assert result == b"chunk1chunk2" + + @pytest.mark.asyncio + async def test_qwen_streaming_success(self): + """Test Qwen Realtime streaming generation success.""" + config = AliTTSConfig(api_key="key", model="qwen-tts") + model = AliTTSModel(config) + + audio1 = base64.b64encode(b"stream1").decode() + audio2 = base64.b64encode(b"stream2").decode() + session_created_msg = json.dumps({"type": "session.created"}) + response_created_msg = json.dumps({"type": "response.created"}) + audio_delta1 = json.dumps({"type": "response.audio.delta", "delta": audio1}) + audio_delta2 = json.dumps({"type": "response.audio.delta", "delta": audio2}) + audio_done_msg = json.dumps({"type": "response.audio.done"}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[ + session_created_msg, + response_created_msg, + audio_delta1, + audio_delta2, + audio_done_msg, + ]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + chunks = [] + async for chunk in model._generate_qwen_realtime_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}): + chunks.append(chunk) + assert chunks == [b"stream1", b"stream2"] + + @pytest.mark.asyncio + async def test_qwen_streaming_error_event(self): + """Test Qwen Realtime streaming with error event.""" + config = AliTTSConfig(api_key="key", model="qwen-tts") + model = AliTTSModel(config) + + session_created_msg = json.dumps({"type": "session.created"}) + error_msg = json.dumps({"type": "error", "error": {"message": "Streaming error"}}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[session_created_msg, error_msg]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + with pytest.raises(AliTTSError, match="Streaming error"): + async for _ in model._generate_qwen_realtime_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}): + pass + + @pytest.mark.asyncio + async def test_qwen_streaming_session_finished_before_response(self): + """Test Qwen Realtime streaming with session.finished before response.created. + + When session.finished comes before response.created, no audio chunks are yielded. + The async generator will raise StopAsyncIteration when exhausted. + """ + config = AliTTSConfig(api_key="key", model="qwen-tts") + model = AliTTSModel(config) + + session_created_msg = json.dumps({"type": "session.created"}) + session_finished_msg = json.dumps({"type": "session.finished"}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[session_created_msg, session_finished_msg]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + chunks = [] + with pytest.raises(RuntimeError, match="async generator"): + async for chunk in model._generate_qwen_realtime_streaming("Hello", "wss://test", {"Authorization": "Bearer key"}): + chunks.append(chunk) + + @pytest.mark.asyncio + async def test_qwen_receive_audio_handles_binary_messages(self): + """Test _qwen_receive_audio passes through binary messages.""" + config = AliTTSConfig(api_key="key", model="qwen-tts") + model = AliTTSModel(config) + + audio_done_msg = json.dumps({"type": "response.audio.done"}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[ + b"binary_audio_data", + audio_done_msg, + ]) + + chunks = [] + async for chunk in model._qwen_receive_audio(mock_ws, yield_chunks=True): + chunks.append(chunk) + assert chunks == [b"binary_audio_data"] + + +# ============================================================================ +# AliTTSModel Base Class Tests +# ============================================================================ + +class TestAliTTSModelBaseClass: + """Tests for base class methods in AliTTSModel.""" + + def test_is_tts_result_successful_with_bytes(self): + """Test _is_tts_result_successful with bytes.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + assert model._is_tts_result_successful(b"audio_data") is True + assert model._is_tts_result_successful(b"") is False + + def test_is_tts_result_successful_with_dict(self): + """Test _is_tts_result_successful with dict.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + assert model._is_tts_result_successful({"audio": "data"}) is True + assert model._is_tts_result_successful({"text": "result"}) is True + assert model._is_tts_result_successful({"error": "error"}) is False + + def test_is_tts_result_successful_invalid_types(self): + """Test _is_tts_result_successful with invalid types.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + assert model._is_tts_result_successful(None) is False + assert model._is_tts_result_successful("string") is False + assert model._is_tts_result_successful(123) is False + + def test_extract_tts_error_message(self): + """Test _extract_tts_error_message.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + assert model._extract_tts_error_message({"error": "test error"}) == "test error" + assert model._extract_tts_error_message({"message": "msg error"}) == "msg error" + assert "Unknown error" in model._extract_tts_error_message({"data": "value"}) + + +# ============================================================================ +# AliTTSModel Connectivity Tests +# ============================================================================ + +class TestAliTTSModelConnectivity: + """Tests for check_connectivity method.""" + + @pytest.mark.asyncio + async def test_check_connectivity_returns_false_when_no_audio_path(self): + """Test check_connectivity returns False when no audio_file_path and no speech generated.""" + config = AliTTSConfig(api_key="key", model="cosyvoice-v2") + model = AliTTSModel(config) + model.audio_file_path = None + + task_started_msg = json.dumps({"header": {"event": "task-started", "task_id": "task_1"}}) + task_finished_msg = json.dumps({"header": {"event": "task-finished", "task_id": "task_1"}}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[task_started_msg, task_finished_msg, b""]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await model.check_connectivity() + assert result is False + + @pytest.mark.asyncio + async def test_check_connectivity_returns_true_with_audio(self): + """Test check_connectivity returns True when audio is generated.""" + config = AliTTSConfig(api_key="key", model="cosyvoice-v2") + model = AliTTSModel(config) + + task_started_msg = json.dumps({"header": {"event": "task-started", "task_id": "task_1"}}) + task_finished_msg = json.dumps({"header": {"event": "task-finished", "task_id": "task_1"}}) + audio_data = b"some_audio_data" + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[ + task_started_msg, + audio_data, + task_finished_msg, + ]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await model.check_connectivity() + assert result is True + + @pytest.mark.asyncio + async def test_check_connectivity_returns_false_on_ali_tts_error(self): + """Test check_connectivity returns False on AliTTSError.""" + config = AliTTSConfig(api_key="key", model="cosyvoice-v2") + model = AliTTSModel(config) + + task_started_msg = json.dumps({"header": {"event": "task-started", "task_id": "task_1"}}) + task_failed_msg = json.dumps({ + "header": {"event": "task-failed", "task_id": "task_1", "error_message": "Task failed"} + }) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[task_started_msg, task_failed_msg]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await model.check_connectivity() + assert result is False + + @pytest.mark.asyncio + async def test_check_connectivity_returns_false_on_generic_exception(self): + """Test check_connectivity returns False on generic exception.""" + config = AliTTSConfig(api_key="key", model="cosyvoice-v2") + model = AliTTSModel(config) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(side_effect=RuntimeError("Unexpected error")) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await model.check_connectivity() + assert result is False + + @pytest.mark.asyncio + async def test_check_connectivity_qwen_realtime(self): + """Test check_connectivity with Qwen Realtime API.""" + config = AliTTSConfig(api_key="key", model="qwen-tts") + model = AliTTSModel(config) + + audio_data = base64.b64encode(b"qwen_connectivity_audio").decode() + session_created_msg = json.dumps({"type": "session.created"}) + response_created_msg = json.dumps({"type": "response.created"}) + audio_delta_msg = json.dumps({"type": "response.audio.delta", "delta": audio_data}) + audio_done_msg = json.dumps({"type": "response.audio.done"}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[ + session_created_msg, + response_created_msg, + audio_delta_msg, + audio_done_msg, + ]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await model.check_connectivity() + assert result is True + + +# ============================================================================ +# AliTTSModel Async Helper Methods Tests +# ============================================================================ + +class TestAliTTSModelAsyncHelpers: + """Tests for async helper methods.""" + + @pytest.mark.asyncio + async def test_cosyvoice_wait_for_task_started_success(self): + """Test _cosyvoice_wait_for_task_started returns True on task-started.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + task_started_msg = json.dumps({"header": {"event": "task-started", "task_id": "task_1"}}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[task_started_msg]) + + result = await model._cosyvoice_wait_for_task_started(mock_ws) + assert result is True + + @pytest.mark.asyncio + async def test_cosyvoice_wait_for_task_started_raises_on_failure(self): + """Test _cosyvoice_wait_for_task_started raises AliTTSError on task-failed.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + task_failed_msg = json.dumps({ + "header": {"event": "task-failed", "task_id": "task_1", "error_message": "Service unavailable"} + }) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[task_failed_msg]) + + with pytest.raises(AliTTSError, match="Service unavailable"): + await model._cosyvoice_wait_for_task_started(mock_ws) + + @pytest.mark.asyncio + async def test_cosyvoice_wait_for_task_started_skips_binary(self): + """Test _cosyvoice_wait_for_task_started skips binary messages.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + task_started_msg = json.dumps({"header": {"event": "task-started", "task_id": "task_1"}}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[ + b"binary_data", + task_started_msg, + ]) + + result = await model._cosyvoice_wait_for_task_started(mock_ws) + assert result is True + + @pytest.mark.asyncio + async def test_qwen_wait_for_session_created_success(self): + """Test _qwen_wait_for_session_created returns True on session.created.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + session_created_msg = json.dumps({"type": "session.created"}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[session_created_msg]) + + result = await model._qwen_wait_for_session_created(mock_ws) + assert result is True + + @pytest.mark.asyncio + async def test_qwen_wait_for_session_created_raises_on_error(self): + """Test _qwen_wait_for_session_created raises on error event.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + error_msg = json.dumps({"type": "error", "error": {"message": "Session error"}}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[error_msg]) + + with pytest.raises(AliTTSError, match="Session error"): + await model._qwen_wait_for_session_created(mock_ws) + + @pytest.mark.asyncio + async def test_qwen_wait_for_session_created_skips_binary(self): + """Test _qwen_wait_for_session_created skips binary messages.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + session_created_msg = json.dumps({"type": "session.created"}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[ + b"binary_data", + b"more_binary", + session_created_msg, + ]) + + result = await model._qwen_wait_for_session_created(mock_ws) + assert result is True + + @pytest.mark.asyncio + async def test_qwen_wait_for_response_created_success(self): + """Test _qwen_wait_for_response_created returns True on response.created.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + response_created_msg = json.dumps({"type": "response.created"}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[response_created_msg]) + + result = await model._qwen_wait_for_response_created(mock_ws) + assert result is True + + @pytest.mark.asyncio + async def test_qwen_wait_for_response_created_raises_on_error(self): + """Test _qwen_wait_for_response_created raises on error event.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + error_msg = json.dumps({"type": "error", "error": {"message": "Response error"}}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[error_msg]) + + with pytest.raises(AliTTSError, match="Response error"): + await model._qwen_wait_for_response_created(mock_ws) + + @pytest.mark.asyncio + async def test_qwen_wait_for_response_created_returns_false_on_session_finished(self): + """Test _qwen_wait_for_response_created returns False when session finishes early.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + session_finished_msg = json.dumps({"type": "session.finished"}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[session_finished_msg]) + + result = await model._qwen_wait_for_response_created(mock_ws) + assert result is False + + @pytest.mark.asyncio + async def test_cosyvoice_receive_audio_with_buffer(self): + """Test _cosyvoice_receive_audio accumulates audio in buffer.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + task_finished_msg = json.dumps({"header": {"event": "task-finished", "task_id": "task_1"}}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[ + b"audio_chunk1", + b"audio_chunk2", + task_finished_msg, + ]) + + buffer = bytearray() + received = [] + async for chunk in model._cosyvoice_receive_audio(mock_ws, buffer=buffer, yield_chunks=False): + received.append(chunk) + assert buffer == bytearray(b"audio_chunk1audio_chunk2") + assert received == [] + + @pytest.mark.asyncio + async def test_cosyvoice_receive_audio_yields_chunks(self): + """Test _cosyvoice_receive_audio yields chunks when requested.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + task_finished_msg = json.dumps({"header": {"event": "task-finished", "task_id": "task_1"}}) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[ + b"yield_chunk1", + b"yield_chunk2", + task_finished_msg, + ]) + + chunks = [] + async for chunk in model._cosyvoice_receive_audio(mock_ws, yield_chunks=True): + chunks.append(chunk) + assert chunks == [b"yield_chunk1", b"yield_chunk2"] + + @pytest.mark.asyncio + async def test_cosyvoice_receive_audio_task_failed(self): + """Test _cosyvoice_receive_audio raises on task-failed.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + task_failed_msg = json.dumps({ + "header": {"event": "task-failed", "task_id": "task_1", "error_message": "Task failed"} + }) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[task_failed_msg]) + + with pytest.raises(AliTTSError, match="Task failed"): + async for _ in model._cosyvoice_receive_audio(mock_ws, yield_chunks=True): + pass + + @pytest.mark.asyncio + async def test_cosyvoice_receive_audio_timeout(self): + """Test _cosyvoice_receive_audio handles timeout.""" + config = AliTTSConfig(api_key="key") + model = AliTTSModel(config) + + mock_ws = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=asyncio.TimeoutError()) + + chunks = [] + async for chunk in model._cosyvoice_receive_audio(mock_ws, yield_chunks=True): + chunks.append(chunk) + assert chunks == [] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/sdk/core/models/test_tts_model.py b/test/sdk/core/models/test_tts_model.py new file mode 100644 index 000000000..57c8429b5 --- /dev/null +++ b/test/sdk/core/models/test_tts_model.py @@ -0,0 +1,201 @@ +""" +Tests for BaseTTSModel abstract class. +""" +import pytest +from typing import Dict + +from sdk.nexent.core.models.tts_model import BaseTTSModel + + +class ConcreteTTSModel(BaseTTSModel): + """Concrete implementation of BaseTTSModel for testing.""" + + def get_websocket_url(self) -> str: + return "wss://test.com" + + def get_auth_headers(self) -> Dict[str, str]: + return {} + + async def generate_speech(self, text: str, stream: bool = False): + return b"test" + + async def check_connectivity(self) -> bool: + return True + + +class TestTTSModelConstructor: + """Test TTSModel constructor.""" + + def test_init_with_audio_file_path(self): + """Test initialization with audio_file_path set.""" + model = ConcreteTTSModel(audio_file_path="/path/to/audio.wav") + + assert model.audio_file_path == "/path/to/audio.wav" + + def test_init_without_audio_file_path(self): + """Test initialization with audio_file_path as None.""" + model = ConcreteTTSModel() + + assert model.audio_file_path is None + + def test_init_with_none_explicit(self): + """Test initialization with explicit None value.""" + model = ConcreteTTSModel(audio_file_path=None) + + assert model.audio_file_path is None + + +class TestIsTTSResultSuccessful: + """Test _is_tts_result_successful method.""" + + @pytest.fixture + def model(self): + return ConcreteTTSModel() + + @pytest.mark.parametrize("data", [b"audio data", b"\x00\x01\x02", b"hello world"]) + def test_bytes_with_data_returns_true(self, model, data): + """Test that non-empty bytes return True.""" + assert model._is_tts_result_successful(data) is True + + def test_bytes_empty_returns_false(self, model): + """Test that empty bytes return False.""" + assert model._is_tts_result_successful(b"") is False + + def test_dict_with_audio_key_returns_true(self, model): + """Test that dict with 'audio' key returns True.""" + result = {"audio": b"audio_data", "format": "pcm"} + assert model._is_tts_result_successful(result) is True + + def test_dict_with_text_key_returns_true(self, model): + """Test that dict with 'text' key returns True.""" + result = {"text": "transcribed text"} + assert model._is_tts_result_successful(result) is True + + def test_dict_with_both_audio_and_text_returns_true(self, model): + """Test that dict with both 'audio' and 'text' keys returns True.""" + result = {"audio": b"data", "text": "some text"} + assert model._is_tts_result_successful(result) is True + + def test_dict_with_error_key_returns_false(self, model): + """Test that dict with 'error' key returns False regardless of other keys.""" + result = {"error": "something went wrong"} + assert model._is_tts_result_successful(result) is False + + def test_dict_with_error_and_audio_returns_false(self, model): + """Test that dict with both 'error' and 'audio' keys returns False.""" + result = {"error": "error message", "audio": b"data"} + assert model._is_tts_result_successful(result) is False + + def test_dict_with_message_key_returns_true(self, model): + """Test that dict with 'message' key (without 'error') returns True.""" + result = {"message": "some message"} + assert model._is_tts_result_successful(result) is True + + def test_dict_with_only_other_keys_returns_false(self, model): + """Test that dict with only other keys returns False.""" + result = {"status": "ok", "code": 200} + assert model._is_tts_result_successful(result) is False + + def test_dict_empty_returns_false(self, model): + """Test that empty dict returns False.""" + assert model._is_tts_result_successful({}) is False + + def test_none_returns_false(self, model): + """Test that None returns False.""" + assert model._is_tts_result_successful(None) is False + + def test_string_returns_false(self, model): + """Test that string returns False.""" + assert model._is_tts_result_successful("audio data") is False + + def test_empty_string_returns_false(self, model): + """Test that empty string returns False.""" + assert model._is_tts_result_successful("") is False + + def test_list_returns_false(self, model): + """Test that list returns False.""" + assert model._is_tts_result_successful([b"data"]) is False + + def test_int_returns_false(self, model): + """Test that integer returns False.""" + assert model._is_tts_result_successful(42) is False + + def test_bool_true_returns_false(self, model): + """Test that True returns False.""" + assert model._is_tts_result_successful(True) is False + + def test_bool_false_returns_false(self, model): + """Test that False returns False.""" + assert model._is_tts_result_successful(False) is False + + +class TestExtractTTSErrorMessage: + """Test _extract_tts_error_message method.""" + + @pytest.fixture + def model(self): + return ConcreteTTSModel() + + def test_dict_with_error_key(self, model): + """Test extraction from dict with 'error' key.""" + result = {"error": "Something went wrong"} + assert model._extract_tts_error_message(result) == "Something went wrong" + + def test_dict_with_error_key_non_string(self, model): + """Test extraction from dict with 'error' key containing non-string value.""" + result = {"error": 12345} + assert model._extract_tts_error_message(result) == "12345" + + def test_dict_with_error_key_none(self, model): + """Test extraction from dict with 'error' key set to None.""" + result = {"error": None} + assert model._extract_tts_error_message(result) == "None" + + def test_dict_with_message_key(self, model): + """Test extraction from dict with 'message' key (when no 'error' key).""" + result = {"message": "User requested cancellation"} + assert model._extract_tts_error_message(result) == "User requested cancellation" + + def test_dict_with_message_key_non_string(self, model): + """Test extraction from dict with 'message' key containing non-string value.""" + result = {"message": 500} + assert model._extract_tts_error_message(result) == "500" + + def test_dict_with_error_and_message_keys(self, model): + """Test that 'error' key takes precedence over 'message' key.""" + result = {"error": "Error message", "message": "Message text"} + assert model._extract_tts_error_message(result) == "Error message" + + def test_dict_with_only_other_keys(self, model): + """Test extraction from dict with only other keys.""" + result = {"status": "failed", "code": 404} + assert "Unknown error in result" in model._extract_tts_error_message(result) + assert "404" in model._extract_tts_error_message(result) + + def test_dict_empty(self, model): + """Test extraction from empty dict.""" + message = model._extract_tts_error_message({}) + assert "Unknown error in result" in message + + def test_none(self, model): + """Test extraction from None.""" + message = model._extract_tts_error_message(None) + assert "Unknown error in result" in message + assert "None" in message + + def test_string(self, model): + """Test extraction from string.""" + message = model._extract_tts_error_message("just a string") + assert "Unknown error in result" in message + assert "just a string" in message + + def test_bytes(self, model): + """Test extraction from bytes.""" + message = model._extract_tts_error_message(b"audio data") + assert "Unknown error in result" in message + + def test_int(self, model): + """Test extraction from integer.""" + message = model._extract_tts_error_message(42) + assert "Unknown error in result" in message + assert "42" in message diff --git a/test/sdk/core/models/test_volc_tts_model.py b/test/sdk/core/models/test_volc_tts_model.py new file mode 100644 index 000000000..72e7f0e49 --- /dev/null +++ b/test/sdk/core/models/test_volc_tts_model.py @@ -0,0 +1,894 @@ +""" +Unit tests for Volcano TTS model. + +Tests the VolcTTSModel and VolcTTSConfig classes. +""" +import gzip +import io +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +import sys as _sys + +_mock_websockets = MagicMock() +_mock_websockets.connect = MagicMock() +_mock_websockets.exceptions = MagicMock() + + +class _MockConnectionClosedError(Exception): + def __init__(self, code, reason): + self.code = code + self.reason = reason + super().__init__(reason) + + +_mock_websockets.exceptions.ConnectionClosedError = _MockConnectionClosedError +_mock_websockets.exceptions.WebSocketException = Exception +_mock_websockets.exceptions.ConnectionClosed = _MockConnectionClosedError + +_mock_aiofiles = MagicMock() + +_module_mocks = { + "websockets": _mock_websockets, + "aiofiles": _mock_aiofiles, +} + +with patch.dict(_sys.modules, _module_mocks): + from sdk.nexent.core.models.volc_tts_model import ( + VolcTTSModel, + VolcTTSConfig, + BaseTTSModel, + ) + + +class TestVolcTTSConfig: + """Tests for VolcTTSConfig.""" + + def test_config_init_default_values(self): + """Test config initialization with default values.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + assert config.appid == "test_appid" + assert config.token == "test_token" + assert config.speed_ratio == 1.0 + assert config.ws_url == "wss://openspeech.bytedance.com/api/v1/tts/ws_binary" + assert config.host == "openspeech.bytedance.com" + assert config.encoding == "mp3" + assert config.volume_ratio == 1.0 + assert config.pitch_ratio == 1.0 + assert config.cluster == "volcano_tts" + assert config.resource_id == "seed-tts-2.0" + assert config.voice_type == "zh_female_vv_uranus_bigtts" + + def test_config_init_custom_values(self): + """Test config initialization with custom values.""" + config = VolcTTSConfig( + appid="custom_appid", + token="custom_token", + speed_ratio=2.0, + ws_url="wss://custom.url", + host="custom.host.com", + encoding="wav", + volume_ratio=0.8, + pitch_ratio=0.5, + cluster="custom_cluster", + resource_id="custom_resource", + voice_type="custom_voice", + ) + assert config.appid == "custom_appid" + assert config.token == "custom_token" + assert config.speed_ratio == 2.0 + assert config.ws_url == "wss://custom.url" + assert config.host == "custom.host.com" + assert config.encoding == "wav" + assert config.volume_ratio == 0.8 + assert config.pitch_ratio == 0.5 + assert config.cluster == "custom_cluster" + assert config.resource_id == "custom_resource" + assert config.voice_type == "custom_voice" + + def test_api_url_property(self): + """Test that api_url property returns ws_url.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + assert config.api_url == config.ws_url + custom_ws_url = "wss://custom.tts.url" + config.ws_url = custom_ws_url + assert config.api_url == custom_ws_url + + +class TestVolcTTSModelProtocolConstants: + """Tests for protocol constants.""" + + def test_message_types(self): + """Test MESSAGE_TYPES constant mapping.""" + assert VolcTTSModel.MESSAGE_TYPES == { + 11: "audio-only server response", + 12: "frontend server response", + 15: "error message from server", + } + + def test_message_type_specific_flags(self): + """Test MESSAGE_TYPE_SPECIFIC_FLAGS constant mapping.""" + assert VolcTTSModel.MESSAGE_TYPE_SPECIFIC_FLAGS == { + 0: "no sequence number", + 1: "sequence number > 0", + 2: "last message from server (seq < 0)", + 3: "sequence number < 0", + } + + def test_message_serialization_methods(self): + """Test MESSAGE_SERIALIZATION_METHODS constant mapping.""" + assert VolcTTSModel.MESSAGE_SERIALIZATION_METHODS == { + 0: "no serialization", + 1: "JSON", + 15: "custom type", + } + + def test_message_compressions(self): + """Test MESSAGE_COMPRESSIONS constant mapping.""" + assert VolcTTSModel.MESSAGE_COMPRESSIONS == { + 0: "no compression", + 1: "gzip", + 15: "custom compression method", + } + + def test_default_header(self): + """Test DEFAULT_HEADER constant value.""" + assert VolcTTSModel.DEFAULT_HEADER == bytearray([0x11, 0x10, 0x11, 0x00]) + + +class TestVolcTTSModelHeaderGeneration: + """Tests for header generation methods.""" + + def test_get_websocket_url(self): + """Test get_websocket_url returns config api_url.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + assert model.get_websocket_url() == config.api_url + + def test_get_websocket_url_custom(self): + """Test get_websocket_url with custom ws_url.""" + custom_url = "wss://custom.tts.service/api/v1/tts/ws_binary" + config = VolcTTSConfig( + appid="test_appid", + token="test_token", + speed_ratio=1.0, + ws_url=custom_url, + ) + model = VolcTTSModel(config) + assert model.get_websocket_url() == custom_url + + +class TestVolcTTSModelAuthHeaders: + """Tests for authentication headers.""" + + def test_get_auth_headers(self): + """Test get_auth_headers returns correct headers.""" + config = VolcTTSConfig( + appid="test_appid", + token="test_token", + speed_ratio=1.0, + resource_id="test_resource", + ) + model = VolcTTSModel(config) + headers = model.get_auth_headers() + assert "Authorization" in headers + assert headers["Authorization"] == "Bearer; test_token" + assert "X-Api-App-Id" in headers + assert headers["X-Api-App-Id"] == "test_appid" + assert "X-Api-Access-Key" in headers + assert headers["X-Api-Access-Key"] == "test_token" + assert "X-Api-Resource-Id" in headers + assert headers["X-Api-Resource-Id"] == "test_resource" + + def test_get_auth_headers_custom_values(self): + """Test get_auth_headers with custom config values.""" + config = VolcTTSConfig( + appid="custom_appid", + token="custom_token", + speed_ratio=1.0, + resource_id="custom_resource_id", + ) + model = VolcTTSModel(config) + headers = model.get_auth_headers() + assert headers["Authorization"] == "Bearer; custom_token" + assert headers["X-Api-App-Id"] == "custom_appid" + assert headers["X-Api-Access-Key"] == "custom_token" + assert headers["X-Api-Resource-Id"] == "custom_resource_id" + + +class TestVolcTTSModelRequestPreparation: + """Tests for request preparation.""" + + def test_prepare_request_submit(self): + """Test _prepare_request with default submit operation.""" + config = VolcTTSConfig( + appid="test_appid", + token="test_token", + speed_ratio=1.0, + cluster="test_cluster", + resource_id="test_resource", + voice_type="test_voice", + encoding="mp3", + volume_ratio=1.0, + pitch_ratio=1.0, + ) + model = VolcTTSModel(config) + request = model._prepare_request("Hello world") + assert isinstance(request, bytes) + assert len(request) > 0 + header = request[:4] + assert header == bytes(VolcTTSModel.DEFAULT_HEADER) + + def test_prepare_request_custom_operation(self): + """Test _prepare_request with custom operation.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + request = model._prepare_request("Test text", operation="custom_op") + assert isinstance(request, bytes) + assert len(request) > 0 + + def test_prepare_request_gzip_compressed(self): + """Test that request payload is gzip compressed.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + request = model._prepare_request("Test text") + payload_length = int.from_bytes(request[4:8], "big") + payload = request[8:] + assert len(payload) == payload_length + decompressed = gzip.decompress(payload) + assert b"Test text" in decompressed + + def test_prepare_request_includes_uuid(self): + """Test that request includes a UUID in reqid field.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + request1 = model._prepare_request("Hello") + request2 = model._prepare_request("Hello") + decompressed1 = gzip.decompress(request1[8:]).decode("utf-8") + decompressed2 = gzip.decompress(request2[8:]).decode("utf-8") + assert '"reqid"' in decompressed1 + assert '"reqid"' in decompressed2 + + def test_prepare_request_structure(self): + """Test request JSON structure contains required fields.""" + config = VolcTTSConfig( + appid="test_appid", + token="test_token", + speed_ratio=1.5, + cluster="my_cluster", + resource_id="my_resource", + voice_type="my_voice", + encoding="wav", + volume_ratio=0.8, + pitch_ratio=0.9, + ) + model = VolcTTSModel(config) + request = model._prepare_request("Sample text") + payload = gzip.decompress(request[8:]).decode("utf-8") + import json + parsed = json.loads(payload) + assert "app" in parsed + assert parsed["app"]["appid"] == "test_appid" + assert parsed["app"]["token"] == "test_token" + assert parsed["app"]["cluster"] == "my_cluster" + assert parsed["app"]["resource_id"] == "my_resource" + assert "user" in parsed + assert "audio" in parsed + assert parsed["audio"]["voice_type"] == "my_voice" + assert parsed["audio"]["encoding"] == "wav" + assert parsed["audio"]["speed_ratio"] == 1.5 + assert parsed["audio"]["volume_ratio"] == 0.8 + assert parsed["audio"]["pitch_ratio"] == 0.9 + assert "request" in parsed + assert parsed["request"]["text"] == "Sample text" + assert parsed["request"]["text_type"] == "plain" + + +class TestVolcTTSModelResponseParsing: + """Tests for response parsing.""" + + def _make_audio_response(self, message_type_specific_flags, sequence_number, audio_data=b"audio_chunk"): + header = bytearray([ + 0x10 | (message_type_specific_flags & 0x0f), + 0xb0 | 0x00, + 0x00, + 0x00, + ]) + header[0] = (1 << 4) | 1 + header[1] = (0xb << 4) | message_type_specific_flags + seq_bytes = sequence_number.to_bytes(4, "big", signed=True) + header_size_bytes = len(seq_bytes) + len(audio_data) + 4 + header_size_prefix = header_size_bytes.to_bytes(4, "big") + return bytes(header) + seq_bytes + header_size_prefix + audio_data + + def _make_response_bytes(self, byte0, byte1, payload_data): + header = bytearray([byte0, byte1, 0x00, 0x00]) + return bytes(header) + payload_data + + def test_parse_response_audio_type_flag_0_no_seq(self): + """Test parsing audio-only response with flag 0 (no sequence).""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + done, chunk = model._parse_response(bytes([0x10, 0xb0, 0x00, 0x00]) + b"\x00" * 8) + assert done is False + assert chunk is None + + def test_parse_response_audio_type_with_positive_sequence(self): + """Test parsing audio-only response with positive sequence number.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + header = bytearray([0x11, 0xb1, 0x00, 0x00]) + seq_bytes = (5).to_bytes(4, "big", signed=True) + audio_data = b"test_audio_data" + payload = seq_bytes + (len(audio_data)).to_bytes(4, "big") + audio_data + response = bytes(header) + payload + done, chunk = model._parse_response(response) + assert done is False + assert chunk == audio_data + + def test_parse_response_audio_type_with_negative_sequence(self): + """Test parsing audio-only response with negative sequence number (last message).""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + header = bytearray([0x11, 0xb2, 0x00, 0x00]) + seq_bytes = (-1).to_bytes(4, "big", signed=True) + audio_data = b"final_audio_chunk" + payload = seq_bytes + (len(audio_data)).to_bytes(4, "big") + audio_data + response = bytes(header) + payload + done, chunk = model._parse_response(response) + assert done is True + assert chunk == audio_data + + def test_parse_response_audio_type_flag_3_negative_seq_with_num(self): + """Test parsing audio-only response with flag 3 (sequence number < 0).""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + header = bytearray([0x11, 0xb3, 0x00, 0x00]) + seq_bytes = (-3).to_bytes(4, "big", signed=True) + audio_data = b"chunk_data" + payload = seq_bytes + (len(audio_data)).to_bytes(4, "big") + audio_data + response = bytes(header) + payload + done, chunk = model._parse_response(response) + assert done is True + assert chunk == audio_data + + def test_parse_response_audio_with_buffer(self): + """Test that audio chunks are written to buffer.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + header = bytearray([0x11, 0xb1, 0x00, 0x00]) + seq_bytes = (1).to_bytes(4, "big", signed=True) + audio_data = b"buffered_audio" + payload = seq_bytes + (len(audio_data)).to_bytes(4, "big") + audio_data + response = bytes(header) + payload + buffer = io.BytesIO() + done, chunk = model._parse_response(response, buffer) + assert done is False + assert buffer.getvalue() == audio_data + + def test_parse_response_frontend_type(self): + """Test parsing frontend server response (message type 0xc).""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + header = bytearray([0x11, 0xc0, 0x00, 0x00]) + response = bytes(header) + b"\x00" * 8 + done, chunk = model._parse_response(response) + assert done is True + assert chunk is None + + def test_parse_response_frontend_type_with_flags(self): + """Test parsing frontend server response with various flags.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + for flag in [0, 1, 2, 3]: + header = bytearray([0x11, (0xc << 4) | flag, 0x00, 0x00]) + response = bytes(header) + b"\x00" * 8 + done, chunk = model._parse_response(response) + assert done is True + + def test_parse_response_error_type(self): + """Test parsing error message from server (message type 0xf).""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + header = bytearray([0x11, 0xf0, 0x00, 0x00]) + code_bytes = (1001).to_bytes(4, "big", signed=False) + error_msg = b"Test error message" + payload = code_bytes + (len(error_msg)).to_bytes(4, "big") + error_msg + response = bytes(header) + payload + with pytest.raises(Exception, match="Volc TTS Error 1001"): + model._parse_response(response) + + def test_parse_response_error_type_with_compression(self): + """Test parsing error message with gzip compression.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + header = bytearray([0x11, 0xf0, 0x01, 0x00]) + code_bytes = (2000).to_bytes(4, "big", signed=False) + error_msg = b"Compressed error" + compressed_msg = gzip.compress(error_msg) + payload = code_bytes + (len(compressed_msg)).to_bytes(4, "big") + compressed_msg + response = bytes(header) + payload + with pytest.raises(Exception, match="Volc TTS Error 2000"): + model._parse_response(response) + + def test_parse_response_unknown_type(self): + """Test parsing response with unknown message type returns done=True.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + header = bytearray([0x11, 0xd0, 0x00, 0x00]) + response = bytes(header) + b"\x00" * 8 + done, chunk = model._parse_response(response) + assert done is True + assert chunk is None + + def test_parse_response_header_extraction(self): + """Test that protocol version and header size are correctly extracted.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + header = bytearray([0x11, 0xb1, 0x00, 0x00]) + seq_bytes = (1).to_bytes(4, "big", signed=True) + audio_data = b"test" + payload = seq_bytes + (len(audio_data)).to_bytes(4, "big") + audio_data + response = bytes(header) + payload + done, chunk = model._parse_response(response) + assert done is False + + +class TestVolcTTSModelGenerateSpeechNonStreaming: + """Tests for non-streaming generate_speech.""" + + @pytest.fixture + def volc_config(self): + return VolcTTSConfig( + appid="test_appid", + token="test_token", + speed_ratio=1.0, + ) + + @pytest.fixture + def volc_model(self, volc_config): + return VolcTTSModel(volc_config) + + def _make_audio_response_bytes(self, sequences, audio_chunks): + responses = [] + for i, (seq, audio) in enumerate(zip(sequences, audio_chunks)): + header = bytearray([0x11, 0xb0, 0x00, 0x00]) + header[1] = (0xb << 4) | 0x2 + seq_bytes = seq.to_bytes(4, "big", signed=True) + payload = seq_bytes + (len(audio)).to_bytes(4, "big") + audio + responses.append(bytes(header) + payload) + return responses + + @pytest.mark.asyncio + async def test_generate_speech_non_streaming_success(self, volc_model): + """Test non-streaming generate_speech with successful response.""" + header = bytearray([0x11, 0xb2, 0x00, 0x00]) + seq_bytes = (-1).to_bytes(4, "big", signed=True) + audio_data = b"final_audio_data" + payload = seq_bytes + (len(audio_data)).to_bytes(4, "big") + audio_data + response = bytes(header) + payload + + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[response]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await volc_model.generate_speech("Hello world", stream=False) + assert isinstance(result, bytes) + assert result == audio_data + + @pytest.mark.asyncio + async def test_generate_speech_non_streaming_multiple_chunks(self, volc_model): + """Test non-streaming generate_speech collecting multiple chunks into buffer.""" + header1 = bytearray([0x11, 0xb1, 0x00, 0x00]) + seq_bytes1 = (1).to_bytes(4, "big", signed=True) + audio1 = b"chunk1_" + payload1 = seq_bytes1 + (len(audio1)).to_bytes(4, "big") + audio1 + resp1 = bytes(header1) + payload1 + + header2 = bytearray([0x11, 0xb2, 0x00, 0x00]) + seq_bytes2 = (-1).to_bytes(4, "big", signed=True) + audio2 = b"chunk2_final" + payload2 = seq_bytes2 + (len(audio2)).to_bytes(4, "big") + audio2 + resp2 = bytes(header2) + payload2 + + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[resp1, resp2]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await volc_model.generate_speech("Hello world", stream=False) + assert isinstance(result, bytes) + assert result == b"chunk1_chunk2_final" + + @pytest.mark.asyncio + async def test_generate_speech_non_streaming_connection_error(self, volc_model): + """Test non-streaming generate_speech with connection error.""" + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(side_effect=Exception("Connection failed")) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + with pytest.raises(Exception, match="Connection failed"): + await volc_model.generate_speech("Hello", stream=False) + + +class TestVolcTTSModelGenerateSpeechStreaming: + """Tests for streaming generate_speech.""" + + @pytest.fixture + def volc_config(self): + return VolcTTSConfig( + appid="test_appid", + token="test_token", + speed_ratio=1.0, + ) + + @pytest.fixture + def volc_model(self, volc_config): + return VolcTTSModel(volc_config) + + @pytest.mark.asyncio + async def test_generate_speech_streaming_success(self, volc_model): + """Test streaming generate_speech yields audio chunks.""" + header1 = bytearray([0x11, 0xb1, 0x00, 0x00]) + seq_bytes1 = (1).to_bytes(4, "big", signed=True) + audio1 = b"stream_chunk_1" + payload1 = seq_bytes1 + (len(audio1)).to_bytes(4, "big") + audio1 + resp1 = bytes(header1) + payload1 + + header2 = bytearray([0x11, 0xb2, 0x00, 0x00]) + seq_bytes2 = (-1).to_bytes(4, "big", signed=True) + audio2 = b"stream_chunk_2" + payload2 = seq_bytes2 + (len(audio2)).to_bytes(4, "big") + audio2 + resp2 = bytes(header2) + payload2 + + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[resp1, resp2]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + generator = await volc_model.generate_speech("Hello world", stream=True) + chunks = [] + async for chunk in generator: + chunks.append(chunk) + assert len(chunks) == 2 + assert chunks[0] == audio1 + assert chunks[1] == audio2 + + def test_parse_response_no_sequence_flag(self, volc_model): + """Test _parse_response with no sequence (flag 0) returns done=True, chunk=None. + + When message_type_specific_flags == 0, the parse returns (False, None) + which causes done=True in streaming, ending the loop. + """ + header = bytearray([0x11, 0xb0, 0x00, 0x00]) + response = bytes(header) + b"\x00" * 8 + + done, chunk = volc_model._parse_response(response) + assert done is False + assert chunk is None + + @pytest.mark.asyncio + async def test_generate_speech_streaming_connection_error(self, volc_model): + """Test streaming generate_speech with connection error.""" + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(side_effect=Exception("Connection failed")) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + generator = await volc_model.generate_speech("Hello", stream=True) + chunks = [] + with pytest.raises(Exception, match="Connection failed"): + async for chunk in generator: + chunks.append(chunk) + + @pytest.mark.asyncio + async def test_generate_speech_streaming_error_response(self, volc_model): + """Test streaming generate_speech handles error response.""" + header = bytearray([0x11, 0xf0, 0x00, 0x00]) + code_bytes = (3000).to_bytes(4, "big", signed=False) + error_msg = b"Server error" + payload = code_bytes + (len(error_msg)).to_bytes(4, "big") + error_msg + response = bytes(header) + payload + + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[response]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + generator = await volc_model.generate_speech("Hello", stream=True) + with pytest.raises(Exception, match="Volc TTS Error 3000"): + async for chunk in generator: + pass + + +class TestVolcTTSModelCheckConnectivity: + """Tests for check_connectivity method.""" + + @pytest.fixture + def volc_config(self): + return VolcTTSConfig( + appid="test_appid", + token="test_token", + speed_ratio=1.0, + ) + + @pytest.fixture + def volc_model(self, volc_config): + return VolcTTSModel(volc_config, audio_file_path="/test/audio.mp3") + + @pytest.mark.asyncio + async def test_check_connectivity_success(self, volc_model): + """Test check_connectivity returns True on successful audio generation.""" + header = bytearray([0x11, 0xb2, 0x00, 0x00]) + seq_bytes = (-1).to_bytes(4, "big", signed=True) + audio_data = b"valid_audio_data" + payload = seq_bytes + (len(audio_data)).to_bytes(4, "big") + audio_data + response = bytes(header) + payload + + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[response]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await volc_model.check_connectivity() + assert result is True + + @pytest.mark.asyncio + async def test_check_connectivity_empty_audio(self, volc_model): + """Test check_connectivity returns False when audio is empty.""" + header = bytearray([0x11, 0xb0, 0x00, 0x00]) + response = bytes(header) + b"\x00" * 8 + + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[response]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await volc_model.check_connectivity() + assert result is False + + @pytest.mark.asyncio + async def test_check_connectivity_connection_error(self, volc_model): + """Test check_connectivity returns False on connection error.""" + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(side_effect=Exception("Connection failed")) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await volc_model.check_connectivity() + assert result is False + + @pytest.mark.asyncio + async def test_check_connectivity_no_audio_file_path(self): + """Test check_connectivity with no audio_file_path (uses generate_speech).""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + header = bytearray([0x11, 0xb0, 0x00, 0x00]) + response = bytes(header) + b"\x00" * 8 + + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[response]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await model.check_connectivity() + assert result is False + + +class TestVolcTTSModelBaseClassInheritance: + """Tests for base class method inheritance.""" + + def test_model_inherits_from_base_tts_model(self): + """Test that VolcTTSModel inherits from BaseTTSModel.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + assert isinstance(model, BaseTTSModel) + + def test_is_tts_result_successful_bytes(self): + """Test _is_tts_result_successful with bytes input.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + assert model._is_tts_result_successful(b"audio_data") is True + assert model._is_tts_result_successful(b"") is False + + def test_is_tts_result_successful_dict(self): + """Test _is_tts_result_successful with dict input.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + assert model._is_tts_result_successful({"audio": "data"}) is True + assert model._is_tts_result_successful({"text": "result"}) is True + assert model._is_tts_result_successful({"error": "fail"}) is False + assert model._is_tts_result_successful({}) is False + + def test_is_tts_result_successful_invalid_types(self): + """Test _is_tts_result_successful with invalid types.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + assert model._is_tts_result_successful("string") is False + assert model._is_tts_result_successful(None) is False + assert model._is_tts_result_successful(123) is False + assert model._is_tts_result_successful([]) is False + + def test_extract_tts_error_message(self): + """Test _extract_tts_error_message method.""" + config = VolcTTSConfig(appid="test_appid", token="test_token", speed_ratio=1.0) + model = VolcTTSModel(config) + assert model._extract_tts_error_message({"error": "test_error"}) == "test_error" + assert model._extract_tts_error_message({"message": "msg_error"}) == "msg_error" + result = model._extract_tts_error_message({"code": 500}) + assert "Unknown error" in result + + +class TestVolcTTSModelEdgeCases: + """Tests for edge cases and error conditions.""" + + @pytest.fixture + def volc_config(self): + return VolcTTSConfig( + appid="test_appid", + token="test_token", + speed_ratio=1.0, + ) + + @pytest.fixture + def volc_model(self, volc_config): + return VolcTTSModel(volc_config) + + def test_parse_response_empty_payload(self, volc_model): + """Test parsing response with empty payload after header.""" + header = bytearray([0x11, 0xb1, 0x00, 0x00]) + seq_bytes = (1).to_bytes(4, "big", signed=True) + payload = seq_bytes + (0).to_bytes(4, "big") + response = bytes(header) + payload + done, chunk = volc_model._parse_response(response) + assert done is False + assert chunk == b"" + + def test_parse_response_very_large_audio_chunk(self, volc_model): + """Test parsing response with large audio chunk.""" + header = bytearray([0x11, 0xb1, 0x00, 0x00]) + seq_bytes = (1).to_bytes(4, "big", signed=True) + large_audio = b"x" * 10000 + payload = seq_bytes + (len(large_audio)).to_bytes(4, "big") + large_audio + response = bytes(header) + payload + done, chunk = volc_model._parse_response(response) + assert done is False + assert chunk == large_audio + + def test_prepare_request_empty_text(self, volc_model): + """Test _prepare_request with empty text.""" + request = volc_model._prepare_request("") + assert isinstance(request, bytes) + assert len(request) > 0 + + def test_prepare_request_unicode_text(self, volc_model): + """Test _prepare_request with unicode text.""" + unicode_text = "Hello world with unicode: \u4e2d\u6587 \u043f\u0440\u0438\u0432\u0435\u0442" + request = volc_model._prepare_request(unicode_text) + assert isinstance(request, bytes) + payload = gzip.decompress(request[8:]) + payload_str = payload.decode("utf-8") + assert "Hello world with unicode" in payload_str + assert "\\u4e2d\\u6587" in payload_str or "中文" in payload_str + assert "\\u043f\\u0440\\u0438\\u0432\\u0435\\u0442" in payload_str or "привет" in payload_str + + def test_prepare_request_long_text(self, volc_model): + """Test _prepare_request with long text.""" + long_text = "A" * 10000 + request = volc_model._prepare_request(long_text) + assert isinstance(request, bytes) + assert len(request) > 0 + + def test_config_cluster_and_resource_id(self): + """Test config with cluster and resource_id fields.""" + config = VolcTTSConfig( + appid="test_appid", + token="test_token", + speed_ratio=1.0, + cluster="speech_tts", + resource_id="my-tts-resource", + ) + model = VolcTTSModel(config) + headers = model.get_auth_headers() + assert headers["X-Api-Resource-Id"] == "my-tts-resource" + + @pytest.mark.asyncio + async def test_generate_speech_non_streaming_with_error_response(self, volc_model): + """Test non-streaming generate_speech handles error response.""" + header = bytearray([0x11, 0xf0, 0x00, 0x00]) + code_bytes = (4000).to_bytes(4, "big", signed=False) + error_msg = b"Server error occurred" + payload = code_bytes + (len(error_msg)).to_bytes(4, "big") + error_msg + response = bytes(header) + payload + + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[response]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + with pytest.raises(Exception, match="Volc TTS Error 4000"): + await volc_model.generate_speech("Hello", stream=False) + + @pytest.mark.asyncio + async def test_generate_speech_streaming_frontend_response_stops(self, volc_model): + """Test streaming stops when frontend response (type 0xc) is received.""" + header = bytearray([0x11, 0xc0, 0x00, 0x00]) + response = bytes(header) + b"\x00" * 8 + + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[response]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + generator = await volc_model.generate_speech("Hello", stream=True) + chunks = [] + async for chunk in generator: + chunks.append(chunk) + assert len(chunks) == 0 + + @pytest.mark.asyncio + async def test_generate_speech_non_streaming_mixed_frontend_and_audio(self, volc_model): + """Test non-streaming handles mix of audio and frontend responses.""" + header1 = bytearray([0x11, 0xb1, 0x00, 0x00]) + seq_bytes1 = (1).to_bytes(4, "big", signed=True) + audio1 = b"audio_" + payload1 = seq_bytes1 + (len(audio1)).to_bytes(4, "big") + audio1 + resp1 = bytes(header1) + payload1 + + header2 = bytearray([0x11, 0xc0, 0x00, 0x00]) + resp2 = bytes(header2) + b"\x00" * 8 + + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.recv = AsyncMock(side_effect=[resp1, resp2]) + + mock_connect = AsyncMock() + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + mock_connect.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_mock_websockets, "connect", return_value=mock_connect): + result = await volc_model.generate_speech("Hello", stream=False) + assert isinstance(result, bytes) + assert result == audio1 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 09a6ed0ee648abdbaeb5c0e0a6d5a6e8343133fa Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Mon, 11 May 2026 12:12:44 +0800 Subject: [PATCH 2/5] add test files --- test/conftest.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 0f116282a..17cc4d606 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -19,15 +19,6 @@ "mem0.configs.embeddings": MagicMock(), "mem0.configs.embeddings.base": MagicMock(), } -_smolagents_stubs = { - "smolagents": MagicMock(), - "smolagents.memory": MagicMock(), - "smolagents.models": MagicMock(), -} -_all_stubs = {**_mem0_stubs, **_smolagents_stubs} -for _mod_name in _all_stubs: - if _mod_name not in sys.modules: - sys.modules[_mod_name] = _all_stubs[_mod_name] # Add backend and sdk directories to sys.path so that modules can be imported # as `from backend.xxx import ...` and `from sdk.xxx import ...` From c288430c526927b4fd5df230780838bad9023271 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Mon, 11 May 2026 12:21:33 +0800 Subject: [PATCH 3/5] add test files --- backend/consts/exceptions.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/backend/consts/exceptions.py b/backend/consts/exceptions.py index a32f0282e..3f51a87ab 100644 --- a/backend/consts/exceptions.py +++ b/backend/consts/exceptions.py @@ -190,6 +190,12 @@ class STTConnectionException(Exception): pass +class TTSConnectionException(Exception): + """Raised when TTS service connection fails.""" + + pass + + class ToolExecutionException(Exception): """Raised when mcp tool execution failed.""" From dc860e4fe14e8f2c73a2c9b5bd17d31c4b48140a Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Mon, 11 May 2026 12:39:36 +0800 Subject: [PATCH 4/5] add test files --- backend/consts/exceptions.py | 6 ++++++ backend/services/voice_service.py | 9 ++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/backend/consts/exceptions.py b/backend/consts/exceptions.py index 3f51a87ab..859ed3316 100644 --- a/backend/consts/exceptions.py +++ b/backend/consts/exceptions.py @@ -184,6 +184,12 @@ class VoiceServiceException(Exception): pass +class VoiceConfigException(Exception): + """Raised when voice configuration is invalid or missing.""" + + pass + + class STTConnectionException(Exception): """Raised when STT service connection fails.""" diff --git a/backend/services/voice_service.py b/backend/services/voice_service.py index 7d274ff23..ff13054e3 100644 --- a/backend/services/voice_service.py +++ b/backend/services/voice_service.py @@ -553,12 +553,15 @@ async def check_tts_connectivity( connected = await tts_model.check_connectivity() if not connected: - logger.warning("TTS service connectivity check returned False") - return False + msg = "TTS service connectivity check returned False" + logger.warning(msg) + raise TTSConnectionException(msg) return connected + except TTSConnectionException: + raise except Exception as e: logger.error(f"TTS connectivity check failed: {str(e)}") - return False + raise TTSConnectionException(f"TTS connectivity check failed: {str(e)}") from e async def check_voice_connectivity( self, From 8bb2d53f79b15bd757d30246fcd4d0a945d4b987 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Mon, 11 May 2026 15:18:43 +0800 Subject: [PATCH 5/5] add test files --- backend/services/voice_service.py | 5 +++- .../services/test_voice_service_tts.py | 26 +++++++++---------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/backend/services/voice_service.py b/backend/services/voice_service.py index ff13054e3..5a08e1f8b 100644 --- a/backend/services/voice_service.py +++ b/backend/services/voice_service.py @@ -611,7 +611,7 @@ async def check_voice_connectivity( base_url = stt_config.get("base_url") if stt_config else None model = stt_config.get("model", "qwen3-tts-flash") if stt_config else "qwen3-tts-flash" - return await self.check_tts_connectivity( + connected = await self.check_tts_connectivity( model_factory=model_factory, api_key=api_key, model_appid=model_appid, @@ -620,6 +620,9 @@ async def check_voice_connectivity( base_url=base_url, model=model ) + if not connected: + raise TTSConnectionException("TTS service connectivity check returned False") + return connected else: logger.error(f"Unknown model type: {model_type}") raise VoiceServiceException(f"Unknown model type: {model_type}") diff --git a/test/backend/services/test_voice_service_tts.py b/test/backend/services/test_voice_service_tts.py index fcacd4255..4b8cd86a3 100644 --- a/test/backend/services/test_voice_service_tts.py +++ b/test/backend/services/test_voice_service_tts.py @@ -605,26 +605,26 @@ async def test_success_returns_true(self): p.stop() @pytest.mark.asyncio - async def test_failure_returns_false(self): - """Test check_tts_connectivity returns False when connectivity check fails.""" + async def test_failure_raises(self): + """Test check_tts_connectivity raises TTSConnectionException when connectivity check fails.""" _reset_singleton() patches, _, _ = _mock_all_models(tts_success=False) for p in patches: p.start() try: service = VoiceService() - result = await service.check_tts_connectivity( - api_key="test_key", - model="qwen3-tts-flash" - ) - assert result is False + with pytest.raises(TTSConnectionException, match="TTS service connectivity check returned False"): + await service.check_tts_connectivity( + api_key="test_key", + model="qwen3-tts-flash" + ) finally: for p in reversed(patches): p.stop() @pytest.mark.asyncio - async def test_exception_returns_false(self): - """Test check_tts_connectivity returns False when an exception occurs.""" + async def test_exception_raises(self): + """Test check_tts_connectivity raises TTSConnectionException when an exception occurs.""" _reset_singleton() exc = RuntimeError("connection timeout") patches, _, _ = _mock_all_models(tts_exc=exc) @@ -632,10 +632,10 @@ async def test_exception_returns_false(self): p.start() try: service = VoiceService() - result = await service.check_tts_connectivity( - api_key="test_key" - ) - assert result is False + with pytest.raises(TTSConnectionException, match="connection timeout"): + await service.check_tts_connectivity( + api_key="test_key" + ) finally: for p in reversed(patches): p.stop()