diff --git a/agent_cli/server/whisper/api.py b/agent_cli/server/whisper/api.py
index f827c71d..5151b84a 100644
--- a/agent_cli/server/whisper/api.py
+++ b/agent_cli/server/whisper/api.py
@@ -3,23 +3,59 @@
 from __future__ import annotations
 
 import contextlib
-import io
 import logging
-import wave
 from typing import TYPE_CHECKING, Annotated, Any, Literal
 
 from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile, WebSocket
 from fastapi.responses import PlainTextResponse
 from pydantic import BaseModel
 
-from agent_cli.server.common import configure_app, create_lifespan, setup_wav_file
+from agent_cli.server.common import configure_app, create_lifespan
 from agent_cli.server.whisper.backends.base import InvalidAudioError
+from agent_cli.services import pcm_to_wav
 
 if TYPE_CHECKING:
+    from agent_cli.core.vad import VoiceActivityDetector
     from agent_cli.server.whisper.model_registry import WhisperModelRegistry
 
 logger = logging.getLogger(__name__)
 
+_EOS_MARKER = b"EOS"
+
+
+def _parse_eos(data: bytes) -> tuple[bytes, bool]:
+    """Parse data for EOS marker, returning (audio_chunk, is_eos)."""
+    if data == _EOS_MARKER:
+        return b"", True
+    if data.endswith(_EOS_MARKER):
+        return data[: -len(_EOS_MARKER)], True
+    return data, False
+
+
+def _create_vad(
+    threshold: float,
+    silence_threshold_ms: int,
+    min_speech_duration_ms: int,
+) -> VoiceActivityDetector:
+    """Create a VoiceActivityDetector instance.
+
+    Raises ImportError if onnxruntime is not available.
+    """
+    from agent_cli.core.vad import VoiceActivityDetector  # noqa: PLC0415
+
+    try:
+        return VoiceActivityDetector(
+            threshold=threshold,
+            silence_threshold_ms=silence_threshold_ms,
+            min_speech_duration_ms=min_speech_duration_ms,
+        )
+    except ImportError as e:
+        msg = (
+            "VAD requires onnxruntime. Install it with: "
+            "`pip install agent-cli[vad]` or `uv sync --extra vad`"
+        )
+        raise ImportError(msg) from e
+
 
 def _split_seconds(seconds: float) -> tuple[int, int, int, int]:
     """Split seconds into (hours, minutes, seconds, milliseconds)."""
@@ -316,17 +352,41 @@ async def stream_transcription(
         websocket: WebSocket,
         model: Annotated[str | None, Query(description="Model to use")] = None,
         language: Annotated[str | None, Query(description="Language code")] = None,
+        use_vad: Annotated[
+            bool,
+            Query(description="Enable VAD for streaming partial results"),
+        ] = True,
+        vad_threshold: Annotated[
+            float,
+            Query(description="Speech detection threshold (0.0-1.0)", ge=0.0, le=1.0),
+        ] = 0.3,
+        vad_silence_ms: Annotated[
+            int,
+            Query(description="Silence duration (ms) to end speech segment", ge=100, le=5000),
+        ] = 1000,
+        vad_min_speech_ms: Annotated[
+            int,
+            Query(
+                description="Minimum speech duration (ms) to trigger transcription",
+                ge=50,
+                le=2000,
+            ),
+        ] = 250,
     ) -> None:
-        """WebSocket endpoint for streaming transcription.
+        """WebSocket endpoint for streaming transcription with optional VAD.
 
         Protocol:
         - Client sends binary audio chunks (16kHz, 16-bit, mono PCM)
        - Client sends b"EOS" to signal end of audio
        - Server sends JSON messages with transcription results
 
+        When use_vad=True (default):
+        - Partial transcriptions are sent as speech segments complete
+        - Final message contains combined text from all segments
+
         Message format from server:
-        {"type": "partial", "text": "...", "is_final": false}
-        {"type": "final", "text": "...", "is_final": true, "segments": [...]}
+        {"type": "partial", "text": "...", "is_final": false, "language": "..."}
+        {"type": "final", "text": "...", "is_final": true, "language": "...", ...}
         {"type": "error", "message": "..."}
         """
         await websocket.accept()
@@ -340,74 +400,146 @@
             await websocket.close()
             return
 
-        # Collect audio data
-        audio_buffer = io.BytesIO()
-        wav_file: wave.Wave_write | None = None
-
-        try:
-            while True:
-                data = await websocket.receive_bytes()
-
-                # Initialize WAV file on first chunk (before EOS check)
-                if wav_file is None:
-                    wav_file = wave.open(audio_buffer, "wb")  # noqa: SIM115
-                    setup_wav_file(wav_file)
-
-                # Check for end of stream (EOS marker)
-                eos_marker = b"EOS"
-                eos_len = len(eos_marker)
-                if data == eos_marker:
-                    break
-                if data[-eos_len:] == eos_marker:
-                    # Write remaining data before EOS marker
-                    if len(data) > eos_len:
-                        wav_file.writeframes(data[:-eos_len])
-                    break
-
-                wav_file.writeframes(data)
-
-            # Close WAV file
-            if wav_file is not None:
-                wav_file.close()
-
-            # Get audio data
-            audio_buffer.seek(0)
-            audio_data = audio_buffer.read()
-
-            if not audio_data:
-                await websocket.send_json({"type": "error", "message": "No audio received"})
+        # Initialize VAD if requested
+        vad = None
+        if use_vad:
+            try:
+                vad = _create_vad(
+                    threshold=vad_threshold,
+                    silence_threshold_ms=vad_silence_ms,
+                    min_speech_duration_ms=vad_min_speech_ms,
+                )
+            except ImportError as e:
+                await websocket.send_json({"type": "error", "message": str(e)})
                 await websocket.close()
                 return
 
-            # Transcribe
-            try:
-                result = await manager.transcribe(
-                    audio_data,
-                    language=language,
-                    task="transcribe",
-                )
+        try:
+            if vad is not None:
+                # VAD-enabled streaming mode
+                await _stream_with_vad(websocket, manager, vad, language)
+            else:
+                # Legacy buffered mode (no VAD)
+                await _stream_buffered(websocket, manager, language)
+        except Exception as e:
+            logger.exception("WebSocket error")
+            with contextlib.suppress(Exception):
+                await websocket.send_json({"type": "error", "message": str(e)})
+        finally:
+            with contextlib.suppress(Exception):
+                await websocket.close()
+
+    async def _stream_with_vad(
+        websocket: WebSocket,
+        manager: Any,
+        vad: VoiceActivityDetector,
+        language: str | None,
+    ) -> None:
+        """Handle streaming transcription with VAD-based segmentation."""
+        all_segments_text: list[str] = []
+        total_duration: float = 0.0
+        final_language: str | None = None
+
+        async def process_segment(segment: bytes) -> None:
+            """Transcribe segment and send partial result."""
+            nonlocal final_language, total_duration
+            result = await _transcribe_segment(manager, segment, language)
+            if result and result.text.strip():
+                all_segments_text.append(result.text.strip())
+                final_language = result.language
+                total_duration += result.duration
                 await websocket.send_json(
                     {
-                        "type": "final",
-                        "text": result.text,
-                        "is_final": True,
+                        "type": "partial",
+                        "text": result.text.strip(),
+                        "is_final": False,
                         "language": result.language,
-                        "duration": result.duration,
-                        "segments": result.segments,
                     },
                 )
-            except Exception as e:
-                await websocket.send_json({"type": "error", "message": str(e)})
+
+        while True:
+            data = await websocket.receive_bytes()
+            audio_chunk, is_eos = _parse_eos(data)
+
+            # Process audio chunk through VAD
+            if audio_chunk:
+                _is_speaking, segment = vad.process_chunk(audio_chunk)
+                if segment:
+                    await process_segment(segment)
+
+            if is_eos:
+                # Flush any remaining audio in VAD buffer
+                if remaining := vad.flush():
+                    await process_segment(remaining)
+                break
+
+        # Send final combined result
+        final_text = " ".join(all_segments_text)
+        await websocket.send_json(
+            {
+                "type": "final",
+                "text": final_text,
+                "is_final": True,
+                "language": final_language,
+                "duration": total_duration,
+            },
+        )
+
+    async def _stream_buffered(
+        websocket: WebSocket,
+        manager: Any,
+        language: str | None,
+    ) -> None:
+        """Handle streaming transcription with buffered mode (no VAD)."""
+        pcm_chunks: list[bytes] = []
+
+        while True:
+            data = await websocket.receive_bytes()
+            audio_chunk, is_eos = _parse_eos(data)
+            if audio_chunk:
+                pcm_chunks.append(audio_chunk)
+            if is_eos:
+                break
+
+        if not pcm_chunks:
+            await websocket.send_json({"type": "error", "message": "No audio received"})
+            return
+
+        # Transcribe
+        audio_data = pcm_to_wav(b"".join(pcm_chunks))
+        try:
+            result = await manager.transcribe(
+                audio_data,
+                language=language,
+                task="transcribe",
+            )
+            await websocket.send_json(
+                {
+                    "type": "final",
+                    "text": result.text,
+                    "is_final": True,
+                    "language": result.language,
+                    "duration": result.duration,
+                    "segments": result.segments,
+                },
+            )
         except Exception as e:
-            logger.exception("WebSocket error")
-            with contextlib.suppress(Exception):
-                await websocket.send_json({"type": "error", "message": str(e)})
+            await websocket.send_json({"type": "error", "message": str(e)})
-
-        finally:
-            with contextlib.suppress(Exception):
-                await websocket.close()
+
+    async def _transcribe_segment(
+        manager: Any,
+        segment: bytes,
+        language: str | None,
+    ) -> Any | None:
+        """Transcribe a raw PCM audio segment."""
+        try:
+            return await manager.transcribe(
+                pcm_to_wav(segment),
+                language=language,
+                task="transcribe",
+            )
+        except Exception:
+            logger.exception("Failed to transcribe segment")
+            return None
 
     return app
diff --git a/tests/test_server_whisper.py b/tests/test_server_whisper.py
index b5af9ece..7edb52ec 100644
--- a/tests/test_server_whisper.py
+++ b/tests/test_server_whisper.py
@@ -620,7 +620,7 @@ def test_websocket_streaming_transcription(
                 return_value=mock_result,
             ),
             client.websocket_connect(
-                "/v1/audio/transcriptions/stream?model=whisper-1",
+                "/v1/audio/transcriptions/stream?model=whisper-1&use_vad=false",
             ) as websocket,
         ):
             for chunk in chunks:
@@ -657,7 +657,7 @@ def test_websocket_streaming_transcribe_error(
                 side_effect=RuntimeError("boom"),
             ),
             client.websocket_connect(
-                "/v1/audio/transcriptions/stream?model=whisper-1",
+                "/v1/audio/transcriptions/stream?model=whisper-1&use_vad=false",
             ) as websocket,
         ):
             websocket.send_bytes(b"\x00\x00" * 160)
@@ -666,3 +666,149 @@
 
         assert data["type"] == "error"
         assert data["message"] == "boom"
+
+    def test_websocket_vad_mode_single_segment(
+        self,
+        client: TestClient,
+        mock_registry: WhisperModelRegistry,
+    ) -> None:
+        """Test VAD mode sends partial then final for a single segment."""
+        mock_result = TranscriptionResult(
+            text="Hello world",
+            language="en",
+            language_probability=0.95,
+            duration=1.5,
+            segments=[],
+        )
+
+        # Create a mock VAD that returns one segment on flush
+        mock_vad = type(
+            "MockVAD",
+            (),
+            {
+                "process_chunk": lambda _self, _chunk: (False, None),
+                "flush": lambda _self: b"audio_data",
+            },
+        )()
+
+        manager = mock_registry.get_manager()
+        with (
+            patch.object(
+                manager,
+                "transcribe",
+                new_callable=AsyncMock,
+                return_value=mock_result,
+            ),
+            patch(
+                "agent_cli.server.whisper.api._create_vad",
+                return_value=mock_vad,
+            ),
+            client.websocket_connect(
+                "/v1/audio/transcriptions/stream?model=whisper-1&use_vad=true",
+            ) as websocket,
+        ):
+            websocket.send_bytes(b"\x00\x00" * 160)
+            websocket.send_bytes(b"EOS")
+
+            # Should receive partial then final
+            partial = websocket.receive_json()
+            final = websocket.receive_json()
+
+        assert partial["type"] == "partial"
+        assert partial["text"] == "Hello world"
+        assert partial["is_final"] is False
+
+        assert final["type"] == "final"
+        assert final["text"] == "Hello world"
+        assert final["is_final"] is True
+        assert final["duration"] == 1.5
+
+    def test_websocket_vad_mode_multiple_segments(
+        self,
+        client: TestClient,
+        mock_registry: WhisperModelRegistry,
+    ) -> None:
+        """Test VAD mode sends multiple partials for multiple segments."""
+        results = [
+            TranscriptionResult(
+                text="First segment",
+                language="en",
+                language_probability=0.95,
+                duration=1.0,
+                segments=[],
+            ),
+            TranscriptionResult(
+                text="Second segment",
+                language="en",
+                language_probability=0.95,
+                duration=0.8,
+                segments=[],
+            ),
+        ]
+
+        # Mock VAD that returns segment on first chunk, another on flush
+        call_count = [0]
+
+        def mock_process_chunk(_self: object, _chunk: bytes) -> tuple[bool, bytes | None]:
+            call_count[0] += 1
+            if call_count[0] == 1:
+                return (False, b"segment1")
+            return (False, None)
+
+        mock_vad = type(
+            "MockVAD",
+            (),
+            {
+                "process_chunk": mock_process_chunk,
+                "flush": lambda _self: b"segment2",
+            },
+        )()
+
+        manager = mock_registry.get_manager()
+        transcribe_mock = AsyncMock(side_effect=results)
+
+        with (
+            patch.object(manager, "transcribe", transcribe_mock),
+            patch(
+                "agent_cli.server.whisper.api._create_vad",
+                return_value=mock_vad,
+            ),
+            client.websocket_connect(
+                "/v1/audio/transcriptions/stream?model=whisper-1&use_vad=true",
+            ) as websocket,
+        ):
+            websocket.send_bytes(b"\x00\x00" * 160)
+            websocket.send_bytes(b"EOS")
+
+            partial1 = websocket.receive_json()
+            partial2 = websocket.receive_json()
+            final = websocket.receive_json()
+
+        assert partial1["type"] == "partial"
+        assert partial1["text"] == "First segment"
+
+        assert partial2["type"] == "partial"
+        assert partial2["text"] == "Second segment"
+
+        assert final["type"] == "final"
+        assert final["text"] == "First segment Second segment"
+        assert final["duration"] == pytest.approx(1.8)
+
+    def test_websocket_vad_not_available(
+        self,
+        client: TestClient,
+    ) -> None:
+        """Test error message when VAD is requested but not available."""
+        with (
+            patch(
+                "agent_cli.server.whisper.api._create_vad",
+                side_effect=ImportError("VAD not available"),
+            ),
+            client.websocket_connect(
+                "/v1/audio/transcriptions/stream?model=whisper-1&use_vad=true",
+            ) as websocket,
+        ):
+            data = websocket.receive_json()
+
+        assert data["type"] == "error"
+        assert "VAD not available" in data["message"]