Skip to content
256 changes: 194 additions & 62 deletions agent_cli/server/whisper/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,59 @@
from __future__ import annotations

import contextlib
import io
import logging
import wave
from typing import TYPE_CHECKING, Annotated, Any, Literal

from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile, WebSocket
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel

from agent_cli.server.common import configure_app, create_lifespan, setup_wav_file
from agent_cli.server.common import configure_app, create_lifespan
from agent_cli.server.whisper.backends.base import InvalidAudioError
from agent_cli.services import pcm_to_wav

if TYPE_CHECKING:
from agent_cli.core.vad import VoiceActivityDetector
from agent_cli.server.whisper.model_registry import WhisperModelRegistry

logger = logging.getLogger(__name__)

_EOS_MARKER = b"EOS"


def _parse_eos(data: bytes) -> tuple[bytes, bool]:
"""Parse data for EOS marker, returning (audio_chunk, is_eos)."""
if data == _EOS_MARKER:
return b"", True
if data.endswith(_EOS_MARKER):
return data[: -len(_EOS_MARKER)], True
return data, False


def _create_vad(
threshold: float,
silence_threshold_ms: int,
min_speech_duration_ms: int,
) -> VoiceActivityDetector:
"""Create a VoiceActivityDetector instance.

Raises ImportError if onnxruntime is not available.
"""
from agent_cli.core.vad import VoiceActivityDetector # noqa: PLC0415

try:
return VoiceActivityDetector(
threshold=threshold,
silence_threshold_ms=silence_threshold_ms,
min_speech_duration_ms=min_speech_duration_ms,
)
except ImportError as e:
msg = (
"VAD requires onnxruntime. Install it with: "
"`pip install agent-cli[vad]` or `uv sync --extra vad`"
)
raise ImportError(msg) from e


def _split_seconds(seconds: float) -> tuple[int, int, int, int]:
"""Split seconds into (hours, minutes, seconds, milliseconds)."""
Expand Down Expand Up @@ -316,17 +352,41 @@ async def stream_transcription(
websocket: WebSocket,
model: Annotated[str | None, Query(description="Model to use")] = None,
language: Annotated[str | None, Query(description="Language code")] = None,
use_vad: Annotated[
bool,
Query(description="Enable VAD for streaming partial results"),
] = True,
vad_threshold: Annotated[
float,
Query(description="Speech detection threshold (0.0-1.0)", ge=0.0, le=1.0),
] = 0.3,
vad_silence_ms: Annotated[
int,
Query(description="Silence duration (ms) to end speech segment", ge=100, le=5000),
] = 1000,
vad_min_speech_ms: Annotated[
int,
Query(
description="Minimum speech duration (ms) to trigger transcription",
ge=50,
le=2000,
),
] = 250,
) -> None:
"""WebSocket endpoint for streaming transcription.
"""WebSocket endpoint for streaming transcription with optional VAD.

Protocol:
- Client sends binary audio chunks (16kHz, 16-bit, mono PCM)
- Client sends b"EOS" to signal end of audio
- Server sends JSON messages with transcription results

When use_vad=True (default):
- Partial transcriptions are sent as speech segments complete
- Final message contains combined text from all segments

Message format from server:
{"type": "partial", "text": "...", "is_final": false}
{"type": "final", "text": "...", "is_final": true, "segments": [...]}
{"type": "partial", "text": "...", "is_final": false, "language": "..."}
{"type": "final", "text": "...", "is_final": true, "language": "...", ...}
{"type": "error", "message": "..."}
"""
await websocket.accept()
Expand All @@ -340,74 +400,146 @@ async def stream_transcription(
await websocket.close()
return

# Collect audio data
audio_buffer = io.BytesIO()
wav_file: wave.Wave_write | None = None

try:
while True:
data = await websocket.receive_bytes()

# Initialize WAV file on first chunk (before EOS check)
if wav_file is None:
wav_file = wave.open(audio_buffer, "wb") # noqa: SIM115
setup_wav_file(wav_file)

# Check for end of stream (EOS marker)
eos_marker = b"EOS"
eos_len = len(eos_marker)
if data == eos_marker:
break
if data[-eos_len:] == eos_marker:
# Write remaining data before EOS marker
if len(data) > eos_len:
wav_file.writeframes(data[:-eos_len])
break

wav_file.writeframes(data)

# Close WAV file
if wav_file is not None:
wav_file.close()

# Get audio data
audio_buffer.seek(0)
audio_data = audio_buffer.read()

if not audio_data:
await websocket.send_json({"type": "error", "message": "No audio received"})
# Initialize VAD if requested
vad = None
if use_vad:
try:
vad = _create_vad(
threshold=vad_threshold,
silence_threshold_ms=vad_silence_ms,
min_speech_duration_ms=vad_min_speech_ms,
)
except ImportError as e:
await websocket.send_json({"type": "error", "message": str(e)})
await websocket.close()
return

# Transcribe
try:
result = await manager.transcribe(
audio_data,
language=language,
task="transcribe",
)
try:
if vad is not None:
# VAD-enabled streaming mode
await _stream_with_vad(websocket, manager, vad, language)
else:
# Legacy buffered mode (no VAD)
await _stream_buffered(websocket, manager, language)
except Exception as e:
logger.exception("WebSocket error")
with contextlib.suppress(Exception):
await websocket.send_json({"type": "error", "message": str(e)})
finally:
with contextlib.suppress(Exception):
await websocket.close()

# VAD streaming handler: every received chunk is fed to the detector, each
# completed speech segment is transcribed, a "partial" message is sent for
# every non-empty transcription, and one "final" message is sent after EOS.
async def _stream_with_vad(
websocket: WebSocket,
manager: Any,
vad: VoiceActivityDetector,
language: str | None,
) -> None:
"""Handle streaming transcription with VAD-based segmentation."""
# Accumulators for the closing "final" message; the last two are updated by
# process_segment via nonlocal.
all_segments_text: list[str] = []
total_duration: float = 0.0
final_language: str | None = None

async def process_segment(segment: bytes) -> None:
"""Transcribe segment and send partial result."""
nonlocal final_language, total_duration
result = await _transcribe_segment(manager, segment, language)
# A falsy result or whitespace-only text is dropped: nothing is
# accumulated and no message is sent for that segment.
if result and result.text.strip():
all_segments_text.append(result.text.strip())
# The language reported in the final message is that of the most recent
# non-empty segment; duration is summed over all non-empty segments.
final_language = result.language
total_duration += result.duration
await websocket.send_json(
{
# NOTE(review): "type", "text" and "is_final" each appear twice in this
# literal. In a Python dict literal the later entry wins, so the "partial"
# values are the effective ones. The first three entries look like leftover
# lines from the previous message format — confirm and remove.
"type": "final",
"text": result.text,
"is_final": True,
"type": "partial",
"text": result.text.strip(),
"is_final": False,
"language": result.language,
"duration": result.duration,
# NOTE(review): unclear whether "segments" is meant to be part of the
# partial payload — verify against the documented message format.
"segments": result.segments,
},
)

# NOTE(review): this except clause has no matching try in the function as
# shown; it looks like residue from the previous implementation — confirm
# against the real file before relying on it.
except Exception as e:
await websocket.send_json({"type": "error", "message": str(e)})
# Receive loop: runs until a frame carrying the EOS marker arrives.
while True:
data = await websocket.receive_bytes()
audio_chunk, is_eos = _parse_eos(data)

# Process audio chunk through VAD
# process_chunk returns a pair; only the second item is used here, and
# it is transcribed whenever it is truthy.
if audio_chunk:
_is_speaking, segment = vad.process_chunk(audio_chunk)
if segment:
await process_segment(segment)

if is_eos:
# Flush any remaining audio in VAD buffer
if remaining := vad.flush():
await process_segment(remaining)
break

# Send final combined result
# The text is the stripped per-segment texts joined with single spaces;
# language stays None and duration 0.0 if no segment produced text.
final_text = " ".join(all_segments_text)
await websocket.send_json(
{
"type": "final",
"text": final_text,
"is_final": True,
"language": final_language,
"duration": total_duration,
},
)

# Buffered handler: collects every audio chunk until EOS, then transcribes
# the whole recording in one call and sends a single "final" message.
async def _stream_buffered(
websocket: WebSocket,
manager: Any,
language: str | None,
) -> None:
"""Handle streaming transcription with buffered mode (no VAD)."""
# Raw PCM payloads in arrival order; joined once after EOS.
pcm_chunks: list[bytes] = []

# Receive loop: empty payloads (e.g. a bare EOS frame) are not stored.
while True:
data = await websocket.receive_bytes()
audio_chunk, is_eos = _parse_eos(data)
if audio_chunk:
pcm_chunks.append(audio_chunk)
if is_eos:
break

# Nothing but EOS was received: report an error instead of transcribing.
if not pcm_chunks:
await websocket.send_json({"type": "error", "message": "No audio received"})
return

# Transcribe
# pcm_to_wav runs before the try, so an exception raised by it is not
# handled by the except clause below.
audio_data = pcm_to_wav(b"".join(pcm_chunks))
try:
result = await manager.transcribe(
audio_data,
language=language,
task="transcribe",
)
await websocket.send_json(
{
"type": "final",
"text": result.text,
"is_final": True,
"language": result.language,
"duration": result.duration,
"segments": result.segments,
},
)
except Exception as e:
# NOTE(review): the error message is sent twice below — once guarded by
# contextlib.suppress and once unguarded — and a finally then closes the
# socket. This looks like the previous and the new error handling
# interleaved; confirm which path is intended.
logger.exception("WebSocket error")
with contextlib.suppress(Exception):
await websocket.send_json({"type": "error", "message": str(e)})
await websocket.send_json({"type": "error", "message": str(e)})

finally:
with contextlib.suppress(Exception):
await websocket.close()
async def _transcribe_segment(
    manager: Any,
    segment: bytes,
    language: str | None,
) -> Any | None:
    """Transcribe one raw PCM segment, returning ``None`` on failure.

    The segment is wrapped via ``pcm_to_wav`` and handed to
    ``manager.transcribe`` with ``task="transcribe"``. Any exception raised
    by the conversion or the transcription is logged with its traceback and
    swallowed, so a single bad segment does not end the stream.
    """
    try:
        # Conversion stays inside the try so its failures are logged too.
        wav_bytes = pcm_to_wav(segment)
        transcription = await manager.transcribe(
            wav_bytes,
            language=language,
            task="transcribe",
        )
    except Exception:
        logger.exception("Failed to transcribe segment")
        return None
    return transcription

return app
Loading