Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 108 additions & 5 deletions backend/apps/voice_app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from http import HTTPStatus

Expand All @@ -7,6 +8,8 @@
from consts.exceptions import (
VoiceServiceException,
STTConnectionException,
TTSConnectionException,
VoiceConfigException,
)
from consts.model import VoiceConnectivityRequest, VoiceConnectivityResponse
from services.voice_service import get_voice_service
Expand Down Expand Up @@ -56,12 +59,97 @@ async def stt_websocket(websocket: WebSocket):
logger.info("STT WebSocket connection closed")


@voice_runtime_router.websocket("/tts/ws")
async def tts_websocket(websocket: WebSocket):
    """WebSocket endpoint for streaming TTS.

    Protocol: the client sends one JSON message (text or binary frame)
    carrying the text to synthesize plus optional model configuration
    (tenant_id, model_factory, model_name, api_key, model_appid,
    access_token, base_url). Synthesized audio is streamed back over the
    same socket by the voice service. Errors are reported to the client
    as a JSON ``{"error": ...}`` message when the socket is still open.
    """
    import json

    logger.info("TTS WebSocket connection attempt...")
    await websocket.accept()
    logger.info("TTS WebSocket connection accepted")

    try:
        # Receive config and text from client (single JSON payload,
        # accepted either as a text frame or a UTF-8 encoded binary frame).
        msg = await websocket.receive()
        client_config = {}
        text = None

        if msg["type"] == "websocket.receive":
            if "text" in msg:
                client_config = json.loads(msg["text"])
                text = client_config.get("text")
            elif "bytes" in msg:
                try:
                    client_config = json.loads(msg["bytes"].decode('utf-8'))
                    text = client_config.get("text")
                except (UnicodeDecodeError, ValueError) as e:
                    # ValueError covers json.JSONDecodeError; anything else
                    # is unexpected and should propagate to the outer handler.
                    logger.warning(f"Failed to parse bytes as JSON: {e}")

        if not text:
            if websocket.client_state.name == "CONNECTED":
                await websocket.send_json({"error": "No text provided"})
            return

        # Extract per-request model configuration supplied by the client.
        tenant_id = client_config.get("tenant_id")
        model_factory = client_config.get("model_factory")
        model_name = client_config.get("model_name")
        api_key = client_config.get("api_key")
        model_appid = client_config.get("model_appid")
        access_token = client_config.get("access_token")
        base_url = client_config.get("base_url")

        logger.info(f"TTS request - model_name: {model_name}, model_factory: {model_factory}, "
                    f"has_api_key: {bool(api_key)}")

        # Build tts_config dict for voice service
        tts_config = {
            "model_factory": model_factory,
            "api_key": api_key,
            "model_appid": model_appid,
            "access_token": access_token,
            "base_url": base_url,
            "model_name": model_name,
        }

        # Stream TTS audio to WebSocket
        voice_service = get_voice_service()
        await voice_service.stream_tts_to_websocket(
            websocket,
            text,
            tenant_id=tenant_id,
            model_name=model_name,
            tts_config=tts_config
        )

    except Exception as e:
        # TTSConnectionException and unexpected errors were previously handled
        # by two identical blocks; a single handler covers both. Guard the
        # send in case the client has already disconnected — send_json on a
        # closed socket would raise and mask the original error.
        logger.error(f"TTS WebSocket error: {str(e)}")
        if websocket.client_state.name == "CONNECTED":
            await websocket.send_json({"error": str(e)})
    finally:
        logger.info("TTS WebSocket connection closed")
        # Ensure connection is properly closed
        if websocket.client_state.name == "CONNECTED":
            await websocket.close()


@voice_config_router.post("/connectivity")
async def check_voice_connectivity(request: VoiceConnectivityRequest):
"""Check voice service connectivity."""
"""
Check voice service connectivity

Args:
request: VoiceConnectivityRequest containing model_type

Returns:
VoiceConnectivityResponse with connectivity status
"""
try:
voice_service = get_voice_service()
connected = await voice_service.check_voice_connectivity(request.model_type)

return JSONResponse(
status_code=HTTPStatus.OK,
content=VoiceConnectivityResponse(
Expand All @@ -72,10 +160,25 @@ async def check_voice_connectivity(request: VoiceConnectivityRequest):
)
except VoiceServiceException as e:
logger.error(f"Voice service error: {str(e)}")
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e))
except STTConnectionException as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=str(e)
)
except (STTConnectionException, TTSConnectionException) as e:
logger.error(f"Voice connectivity error: {str(e)}")
raise HTTPException(status_code=HTTPStatus.SERVICE_UNAVAILABLE, detail=str(e))
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
detail=str(e)
)
except VoiceConfigException as e:
logger.error(f"Voice configuration error: {str(e)}")
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=str(e)
)
except Exception as e:
logger.error(f"Unexpected voice service error: {str(e)}")
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Voice service error")
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Voice service error"
)
12 changes: 12 additions & 0 deletions backend/consts/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,24 @@ class VoiceServiceException(Exception):
pass


class VoiceConfigException(Exception):
    """Signals an invalid or missing voice configuration."""


class STTConnectionException(Exception):
    """Signals a failure to connect to the STT service."""


class TTSConnectionException(Exception):
    """Signals a failure to connect to the TTS service."""


class ToolExecutionException(Exception):
"""Raised when mcp tool execution failed."""

Expand Down
13 changes: 12 additions & 1 deletion backend/consts/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,24 @@ class STTModelConfig(BaseModel):
accessToken: Optional[str] = None


class TTSModelConfig(BaseModel):
    """TTS model specific configuration with factory, appid, and access token fields"""
    # Identifier of the TTS model as known to the backend.
    modelName: str
    # Human-readable name shown in the UI.
    displayName: str
    # Generic API connection settings (key/url); optional for providers
    # that authenticate via appid + access token instead.
    apiConfig: Optional[ModelApiConfig] = None
    # Provider/factory name — presumably selects e.g. Volcano vs. Ali TTS;
    # TODO confirm against voice_service.
    modelFactory: Optional[str] = None
    # Application ID required by appid/token-based providers.
    modelAppid: Optional[str] = None
    # Access token paired with modelAppid for token-based providers.
    accessToken: Optional[str] = None


class ModelConfig(BaseModel):
    """Aggregated model configuration: one entry per model role used by the app."""
    # Large language model used for chat/agent reasoning.
    llm: SingleModelConfig
    # Text embedding model.
    embedding: SingleModelConfig
    # Multimodal embedding model.
    multiEmbedding: SingleModelConfig
    # Reranking model for retrieval results.
    rerank: SingleModelConfig
    # Vision-language model.
    vlm: SingleModelConfig
    # Speech-to-text model (voice-specific config shape).
    stt: STTModelConfig
    # Text-to-speech model (voice-specific config shape).
    tts: TTSModelConfig


class AppConfig(BaseModel):
Expand Down Expand Up @@ -504,7 +515,7 @@ def default(cls) -> "MemoryAgentShareMode":
class VoiceConnectivityRequest(BaseModel):
"""Request model for voice service connectivity check"""
model_type: str = Field(...,
description="Type of model to check ('stt')")
description="Type of model to check ('stt' or 'tts')")


class VoiceConnectivityResponse(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions backend/services/config_sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ def build_model_config(model_config: dict) -> dict:
if "embedding" in model_config.get("model_type", ""):
config["dimension"] = model_config.get("max_tokens", 0)

# Add STT model specific fields
# Add voice model specific fields (STT and TTS)
model_type = model_config.get("model_type", "")
if model_type == "stt":
if model_type == "stt" or model_type == "tts":
config["modelFactory"] = model_config.get("model_factory", "")
config["modelAppid"] = model_config.get("model_appid", "")
config["accessToken"] = model_config.get("access_token", "")
Expand Down
28 changes: 27 additions & 1 deletion backend/services/model_health_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
raise ValueError(f"Unsupported model type: {model_type}")


async def _perform_connectivity_check(

Check failure on line 64 in backend/services/model_health_service.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Refactor this function to reduce its Cognitive Complexity from 18 to the 15 allowed.

See more on https://sonarcloud.io/project/issues?id=ModelEngine-Group_nexent&issues=AZ4VL-gdIgguPulXBK90&open=AZ4VL-gdIgguPulXBK90&pullRequest=2959
model_name: str,
model_type: str,
model_base_url: str,
Expand Down Expand Up @@ -139,7 +139,6 @@
elif model_type == 'stt':
voice_service = get_voice_service()


# Determine STT provider based on model_factory
use_volc = model_factory and model_factory.lower() in ["volcengine", "volcano", "volcengine", "火山引擎"]

Expand All @@ -164,6 +163,33 @@
"model": model_name
}
)
elif model_type == 'tts':
voice_service = get_voice_service()

# Determine TTS provider based on model_factory
use_volc = model_factory and model_factory.lower() in ["volcengine", "volcano", "volcengine", "火山引擎"]

if use_volc:
# Use Volcano TTS with appid and access_token
connectivity = await voice_service.check_voice_connectivity(
model_type="tts",
stt_config={
"model_factory": model_factory,
"model_appid": model_appid,
"access_token": access_token,
"base_url": model_base_url
}
)
else:
# Use Ali TTS (default) with api_key and model name
connectivity = await voice_service.check_voice_connectivity(
model_type="tts",
stt_config={
"api_key": model_api_key,
"base_url": model_base_url,
"model": model_name
}
)
else:
raise ValueError(f"Unsupported model type: {model_type}")

Expand Down
Loading
Loading