Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions vox_box/backends/stt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,16 @@ def transcribe(
**kwargs
):
pass

def is_stream_supported(self) -> bool:
    """Report whether streaming transcription is available.

    Returns:
        Always ``False`` in this implementation.
    """
    return False

def transcribe_stream(
    self,
    audio: bytes,
    language: Optional[str] = None,
    prompt: Optional[str] = None,
    temperature: float = 0.2,
    **kwargs
):
    """Placeholder for streaming transcription.

    Every argument is accepted and ignored; this implementation only
    signals that streaming is unavailable.

    Raises:
        NotImplementedError: Always.
    """
    message = "Streaming is not supported for this backend"
    raise NotImplementedError(message)
28 changes: 28 additions & 0 deletions vox_box/backends/stt/faster_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,34 @@ def transcribe(

return response

def is_stream_supported(self) -> bool:
    """Report whether streaming transcription is available.

    Returns:
        Always ``True`` in this implementation.
    """
    return True

def transcribe_stream(
    self,
    audio: bytes,
    language: Optional[str] = None,
    prompt: Optional[str] = None,
    temperature: Optional[float] = 0.2,
    **kwargs,
):
    """Yield one JSON string per non-empty transcribed segment.

    This is a generator: nothing is executed until the caller starts
    iterating.

    Args:
        audio: Raw audio bytes; wrapped in an in-memory file object before
            being handed to ``self._model.transcribe``.
        language: Language passed to the model. The value ``"auto"`` is
            replaced by ``None`` before the call.
        prompt: Forwarded to the model as ``initial_prompt``.
        temperature: Forwarded to the model unchanged.
        **kwargs: Accepted but not used here.

    Yields:
        ``json.dumps({"text": <segment text>})`` for every segment whose
        text is non-empty after stripping surrounding whitespace.
    """
    # "auto" is this API's spelling; the model receives None instead.
    model_language = None if language == "auto" else language

    segments, _info = self._model.transcribe(
        io.BytesIO(audio),
        language=model_language,
        initial_prompt=prompt,
        temperature=temperature,
        without_timestamps=True,
    )

    for segment in segments:
        stripped = segment.text.strip()
        if not stripped:
            # Whitespace-only segments produce no event.
            continue
        yield json.dumps({"text": stripped})

def _get_languages(self) -> List[Dict]:
return [
{"auto": "auto"},
Expand Down
8 changes: 8 additions & 0 deletions vox_box/backends/tts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,11 @@ def speech(
**kwargs
):
pass

def is_stream_supported(self) -> bool:
    """Report whether streaming speech synthesis is available.

    Returns:
        Always ``False`` in this implementation.
    """
    return False

def speech_stream(
    self, input: str, voice: Optional[str] = None, speed: float = 1, **kwargs
):
    """Placeholder for streaming speech synthesis.

    ``voice`` defaults to ``None`` so that the parameter, already annotated
    as optional, can actually be omitted by callers.

    Args:
        input: Text to synthesize (ignored here).
        voice: Voice name (ignored here).
        speed: Playback speed factor (ignored here).
        **kwargs: Accepted but not used.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("Streaming is not supported for this backend")
23 changes: 23 additions & 0 deletions vox_box/backends/tts/cosyvoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,29 @@ def speech(
output_file_path = convert(wav_file_path, reponse_format, speed)
return output_file_path

def is_stream_supported(self) -> bool:
    """Report whether streaming speech synthesis is available.

    Returns:
        Always ``True`` in this implementation.
    """
    return True

def speech_stream(
    self,
    input: str,
    voice: Optional[str] = None,
    speed: float = 1,
    **kwargs,
):
    """Synthesize ``input`` and return an iterator of 16-bit PCM byte chunks.

    The voice is validated when this method is called, not when the result
    is first iterated, so an unsupported voice is reported before the
    caller starts sending a streamed response.

    Args:
        input: Text to synthesize.
        voice: Voice name; must be a member of ``self._voices``.
        speed: Speed factor forwarded to the model.
        **kwargs: Accepted but not used here.

    Returns:
        A generator yielding ``bytes``; each item is one model chunk
        converted to int16 samples.

    Raises:
        ValueError: If ``voice`` is not in ``self._voices``.
    """
    if voice not in self._voices:
        raise ValueError(f"Voice {voice} not supported")

    original_voice = self._get_original_voice(voice)

    def _pcm_chunks():
        # Inference starts only when the consumer begins iterating.
        model_output = self._model.inference_sft(
            input, original_voice, stream=True, speed=speed
        )
        for chunk in model_output:
            # NOTE(review): assumes chunk["tts_speech"] holds float samples
            # roughly in [-1, 1] — confirm against the model's output.
            scaled = chunk["tts_speech"].numpy() * (2**15)
            # Clip before the cast: +1.0 scales to 32768, which does not
            # fit in int16 and would otherwise overflow.
            yield np.clip(scaled, -32768, 32767).astype(np.int16).tobytes()

    return _pcm_chunks()

def _get_voices(self) -> List[str]:
voices = self._model.list_available_spks()
return [self.language_map.get(voice, voice) for voice in voices]
Expand Down
145 changes: 107 additions & 38 deletions vox_box/server/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import mimetypes
from fastapi import APIRouter, HTTPException, Request, UploadFile
from pydantic import BaseModel
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, StreamingResponse

from vox_box.backends.stt.base import STTBackend
from vox_box.backends.tts.base import TTSBackend
Expand All @@ -30,6 +30,7 @@ class SpeechRequest(BaseModel):
voice: str
response_format: str = "mp3"
speed: float = 1.0
stream: bool = False


@router.post("/v1/audio/speech")
Expand All @@ -55,6 +56,27 @@ async def speech(request: SpeechRequest):
status_code=400, detail="Model instance does not support speech API"
)

if request.stream:
if not model_instance.is_stream_supported():
raise HTTPException(
status_code=400,
detail="Streaming is not supported for this model",
)
if request.response_format != "pcm":
raise HTTPException(
status_code=400,
detail="Streaming only supports pcm response format",
)
gen = model_instance.speech_stream(
request.input, request.voice, request.speed
)
headers = {
"X-Audio-Sample-Rate": "22050",
"X-Audio-Channels": "1",
"X-Audio-Sample-Width": "16",
}
return StreamingResponse(gen, media_type="audio/pcm", headers=headers)

func = functools.partial(
model_instance.speech,
request.input,
Expand All @@ -71,6 +93,8 @@ async def speech(request: SpeechRequest):

media_type = get_media_type(request.response_format)
return FileResponse(audio_file, media_type=media_type)
except HTTPException:
raise
except Exception as e:
return HTTPException(status_code=500, detail=f"Failed to generate speech, {e}")

Expand Down Expand Up @@ -110,39 +134,59 @@ async def speech(request: SpeechRequest):
ALLOWED_TRANSCRIPTIONS_OUTPUT_FORMATS = {"json", "text", "srt", "vtt", "verbose_json"}


@router.post("/v1/audio/transcriptions")
async def transcribe(request: Request):
async def _parse_transcription_form(request: Request):
form = await request.form()
keys = form.keys()
if "file" not in keys:
raise HTTPException(status_code=400, detail="Field file is required")

file: UploadFile = form[
"file"
] # flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm
file_content_type = file.content_type or mimetypes.guess_type(file.filename)[0]
if file_content_type not in ALLOWED_TRANSCRIPTIONS_INPUT_AUDIO_FORMATS:
raise HTTPException(
status_code=400,
detail=f"Unsupported file format: {file_content_type}",
)

audio_bytes = await file.read()
language = form.get("language")
prompt = form.get("prompt")
try:
form = await request.form()
keys = form.keys()
if "file" not in keys:
return HTTPException(status_code=400, detail="Field file is required")

file: UploadFile = form[
"file"
] # flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm
file_content_type = file.content_type or mimetypes.guess_type(file.filename)[0]
if file_content_type not in ALLOWED_TRANSCRIPTIONS_INPUT_AUDIO_FORMATS:
return HTTPException(
status_code=400,
detail=f"Unsupported file format: {file_content_type}",
)
temperature = float(form.get("temperature", 0.2))
except ValueError:
raise HTTPException(status_code=400, detail="Invalid temperature value. It must be a float.")
if not (0 <= temperature <= 1):
raise HTTPException(
status_code=400, detail="Temperature must be between 0 and 1"
)

audio_bytes = await file.read()
language = form.get("language")
prompt = form.get("prompt")
temperature = float(form.get("temperature", 0))
if not (0 <= temperature <= 1):
return HTTPException(
status_code=400, detail="Temperature must be between 0 and 1"
)
stream = str(form.get("stream", "false")).lower() == "true"
timestamp_granularities = form.getlist("timestamp_granularities")
response_format = form.get("response_format", "json")
if response_format not in ALLOWED_TRANSCRIPTIONS_OUTPUT_FORMATS:
raise HTTPException(
status_code=400,
detail=f"Unsupported response_format: {response_format}",
)

timestamp_granularities = form.getlist("timestamp_granularities")
response_format = form.get("response_format", "json")
if response_format not in ALLOWED_TRANSCRIPTIONS_OUTPUT_FORMATS:
return HTTPException(
status_code=400, detail="Unsupported response_format: {response_format}"
)
return {
"audio_bytes": audio_bytes,
"language": language,
"prompt": prompt,
"temperature": temperature,
"stream": stream,
"timestamp_granularities": timestamp_granularities,
"response_format": response_format,
"content_type": file_content_type,
}


@router.post("/v1/audio/transcriptions")
async def transcribe(request: Request):
try:
params = await _parse_transcription_form(request)

model_instance: STTBackend = get_model_instance()
if not isinstance(model_instance, STTBackend):
Expand All @@ -151,17 +195,33 @@ async def transcribe(request: Request):
detail="Model instance does not support transcriptions API",
)

if params["stream"]:
if not model_instance.is_stream_supported():
raise HTTPException(
status_code=400,
detail="Streaming is not supported for this model",
)
kwargs = {"content_type": params["content_type"]}
gen = model_instance.transcribe_stream(
params["audio_bytes"],
params["language"],
params["prompt"],
params["temperature"],
**kwargs,
)
return StreamingResponse(sse_stream(gen), media_type="text/event-stream")

kwargs = {
"content_type": file_content_type,
"content_type": params["content_type"],
}
func = functools.partial(
model_instance.transcribe,
audio_bytes,
language,
prompt,
temperature,
timestamp_granularities,
response_format,
params["audio_bytes"],
params["language"],
params["prompt"],
params["temperature"],
params["timestamp_granularities"],
params["response_format"],
**kwargs,
)

Expand All @@ -171,12 +231,15 @@ async def transcribe(request: Request):
func,
)

response_format = params["response_format"]
if response_format == "json":
return {"text": data}
elif response_format == "text":
return data
else:
return data
except HTTPException:
raise
except Exception as e:
return HTTPException(status_code=500, detail=f"Failed to transcribe audio, {e}")

Expand Down Expand Up @@ -244,3 +307,9 @@ def get_media_type(response_format) -> str:
)

return media_type


def sse_stream(sync_gen):
    """Wrap an iterable of payloads as Server-Sent Events text frames.

    Each item becomes one event. A payload that contains line breaks is
    emitted with one ``data:`` line per payload line, so the event stays
    well-formed; single-line payloads produce ``data: <item>`` followed by
    a blank line. After the source is exhausted a final
    ``data: [DONE]`` event is emitted.

    Args:
        sync_gen: Iterable of items; each is converted with ``str()``.

    Yields:
        SSE-formatted strings, each terminated by a blank line.
    """
    for item in sync_gen:
        # Normalize CRLF / CR to LF, then give every line its own prefix.
        text = str(item).replace("\r\n", "\n").replace("\r", "\n")
        framed = "\n".join(f"data: {line}" for line in text.split("\n"))
        yield f"{framed}\n\n"
    yield "data: [DONE]\n\n"