From 1d9c1972e4b3cd327b28c0db0db12bb1d79f4296 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Thu, 29 Jan 2026 05:41:15 -0800 Subject: [PATCH 1/7] refactor: extract duplicate code, add VAD tests, add duration to response - Extract duplicate segment processing into inner process_segment() function - Add duration tracking to VAD mode for consistent response format - Add 3 new tests: single segment, multiple segments, VAD not available - Remove unused .claude/REPORT.md from commit --- agent_cli/server/whisper/api.py | 288 +++++++++++++++++++++++++------- tests/test_server_whisper.py | 150 ++++++++++++++++- 2 files changed, 377 insertions(+), 61 deletions(-) diff --git a/agent_cli/server/whisper/api.py b/agent_cli/server/whisper/api.py index f827c71de..ce5acd537 100644 --- a/agent_cli/server/whisper/api.py +++ b/agent_cli/server/whisper/api.py @@ -16,10 +16,42 @@ from agent_cli.server.whisper.backends.base import InvalidAudioError if TYPE_CHECKING: + from agent_cli.core.vad import VoiceActivityDetector from agent_cli.server.whisper.model_registry import WhisperModelRegistry logger = logging.getLogger(__name__) +# VAD availability check - the vad extra may not be installed +_VAD_AVAILABLE = False +try: + from agent_cli.core.vad import VoiceActivityDetector as _VoiceActivityDetector + + _VAD_AVAILABLE = True +except ImportError: + _VoiceActivityDetector = None # type: ignore[misc, assignment] + + +def _create_vad( + threshold: float, + silence_threshold_ms: int, + min_speech_duration_ms: int, +) -> VoiceActivityDetector: + """Create a VoiceActivityDetector instance. + + Raises ImportError if VAD is not available. + """ + if not _VAD_AVAILABLE: + msg = ( + "VAD is not available. Install it with: " + "`pip install agent-cli[vad]` or `uv sync --extra vad`" + ) + raise ImportError(msg) + return _VoiceActivityDetector( + threshold=threshold, + silence_threshold_ms=silence_threshold_ms, + min_speech_duration_ms=min_speech_duration_ms, + ) + def _split_seconds(seconds: float) -> tuple[int, int, int, int]: """Split seconds into (hours, minutes, seconds, milliseconds).""" @@ -316,17 +348,41 @@ async def stream_transcription( websocket: WebSocket, model: Annotated[str | None, Query(description="Model to use")] = None, language: Annotated[str | None, Query(description="Language code")] = None, + use_vad: Annotated[ + bool, + Query(description="Enable VAD for streaming partial results"), + ] = True, + vad_threshold: Annotated[ + float, + Query(description="Speech detection threshold (0.0-1.0)", ge=0.0, le=1.0), + ] = 0.3, + vad_silence_ms: Annotated[ + int, + Query(description="Silence duration (ms) to end speech segment", ge=100, le=5000), + ] = 1000, + vad_min_speech_ms: Annotated[ + int, + Query( + description="Minimum speech duration (ms) to trigger transcription", + ge=50, + le=2000, + ), + ] = 250, ) -> None: - """WebSocket endpoint for streaming transcription. + """WebSocket endpoint for streaming transcription with optional VAD. Protocol: - Client sends binary audio chunks (16kHz, 16-bit, mono PCM) - Client sends b"EOS" to signal end of audio - Server sends JSON messages with transcription results + When use_vad=True (default): + - Partial transcriptions are sent as speech segments complete + - Final message contains combined text from all segments + Message format from server: - {"type": "partial", "text": "...", "is_final": false} - {"type": "final", "text": "...", "is_final": true, "segments": [...]} + {"type": "partial", "text": "...", "is_final": false, "language": "..."} + {"type": "final", "text": "...", "is_final": true, "language": "...", ...} {"type": "error", "message": "..."} """ await websocket.accept() @@ -340,74 +396,188 @@ async def stream_transcription( await websocket.close() return - # Collect audio data - audio_buffer = io.BytesIO() - wav_file: wave.Wave_write | None = None - - try: - while True: - data = await websocket.receive_bytes() - - # Initialize WAV file on first chunk (before EOS check) - if wav_file is None: - wav_file = wave.open(audio_buffer, "wb") # noqa: SIM115 - setup_wav_file(wav_file) - - # Check for end of stream (EOS marker) - eos_marker = b"EOS" - eos_len = len(eos_marker) - if data == eos_marker: - break - if data[-eos_len:] == eos_marker: - # Write remaining data before EOS marker - if len(data) > eos_len: - wav_file.writeframes(data[:-eos_len]) - break - - wav_file.writeframes(data) - - # Close WAV file - if wav_file is not None: - wav_file.close() - - # Get audio data - audio_buffer.seek(0) - audio_data = audio_buffer.read() - - if not audio_data: - await websocket.send_json({"type": "error", "message": "No audio received"}) + # Initialize VAD if requested + vad = None + if use_vad: + try: + vad = _create_vad( + threshold=vad_threshold, + silence_threshold_ms=vad_silence_ms, + min_speech_duration_ms=vad_min_speech_ms, + ) + except ImportError as e: + await websocket.send_json({"type": "error", "message": str(e)}) await websocket.close() return - # Transcribe - try: - result = await manager.transcribe( - audio_data, - language=language, - task="transcribe", - ) + try: + if vad is not None: + # VAD-enabled streaming mode + await _stream_with_vad(websocket, manager, vad, language) + else: + # Legacy buffered mode (no VAD) + await _stream_buffered(websocket, manager, language) + except Exception as e: + logger.exception("WebSocket error") + with contextlib.suppress(Exception): + await websocket.send_json({"type": "error", "message": str(e)}) + finally: + with contextlib.suppress(Exception): + await websocket.close() + async def _stream_with_vad( + websocket: WebSocket, + manager: Any, + vad: VoiceActivityDetector, + language: str | None, + ) -> None: + """Handle streaming transcription with VAD-based segmentation.""" + all_segments_text: list[str] = [] + total_duration: float = 0.0 + final_language: str | None = None + eos_marker = b"EOS" + eos_len = len(eos_marker) + + async def process_segment(segment: bytes) -> None: + """Transcribe segment and send partial result.""" + nonlocal final_language, total_duration + result = await _transcribe_segment(manager, segment, language) + if result and result.text.strip(): + all_segments_text.append(result.text.strip()) + final_language = result.language + total_duration += result.duration await websocket.send_json( { - "type": "final", - "text": result.text, - "is_final": True, + "type": "partial", + "text": result.text.strip(), + "is_final": False, "language": result.language, - "duration": result.duration, - "segments": result.segments, }, ) - except Exception as e: - await websocket.send_json({"type": "error", "message": str(e)}) + while True: + data = await websocket.receive_bytes() + + # Check for end of stream + is_eos = data == eos_marker + audio_chunk = b"" + + if is_eos: + pass # No audio to process + elif data[-eos_len:] == eos_marker: + # Audio followed by EOS marker + audio_chunk = data[:-eos_len] + is_eos = True + else: + audio_chunk = data + + # Process audio chunk through VAD + if audio_chunk: + _is_speaking, segment = vad.process_chunk(audio_chunk) + if segment: + await process_segment(segment) + + if is_eos: + # Flush any remaining audio in VAD buffer + if remaining := vad.flush(): + await process_segment(remaining) + break + + # Send final combined result + final_text = " ".join(all_segments_text) + await websocket.send_json( + { + "type": "final", + "text": final_text, + "is_final": True, + "language": final_language, + "duration": total_duration, + }, + ) + + async def _stream_buffered( + websocket: WebSocket, + manager: Any, + language: str | None, + ) -> None: + """Handle streaming transcription with buffered mode (no VAD).""" + audio_buffer = io.BytesIO() + wav_file: wave.Wave_write | None = None + eos_marker = b"EOS" + eos_len = len(eos_marker) + + while True: + data = await websocket.receive_bytes() + + # Initialize WAV file on first chunk (before EOS check) + if wav_file is None: + wav_file = wave.open(audio_buffer, "wb") # noqa: SIM115 + setup_wav_file(wav_file) + + # Check for end of stream + if data == eos_marker: + break + if data[-eos_len:] == eos_marker: + # Write remaining data before EOS marker + if len(data) > eos_len: + wav_file.writeframes(data[:-eos_len]) + break + + wav_file.writeframes(data) + + # Close WAV file + if wav_file is not None: + wav_file.close() + + # Get audio data + audio_buffer.seek(0) + audio_data = audio_buffer.read() + + if not audio_data: + await websocket.send_json({"type": "error", "message": "No audio received"}) + return + + # Transcribe + try: + result = await manager.transcribe( + audio_data, + language=language, + task="transcribe", + ) + await websocket.send_json( + { + "type": "final", + "text": result.text, + "is_final": True, + "language": result.language, + "duration": result.duration, + "segments": result.segments, + }, + ) except Exception as e: - logger.exception("WebSocket error") - with contextlib.suppress(Exception): - await websocket.send_json({"type": "error", "message": str(e)}) + await websocket.send_json({"type": "error", "message": str(e)}) - finally: - with contextlib.suppress(Exception): - await websocket.close() + async def _transcribe_segment( + manager: Any, + segment: bytes, + language: str | None, + ) -> Any | None: + """Transcribe a raw PCM audio segment by wrapping it in WAV format.""" + try: + # Wrap raw PCM in WAV format for transcription + wav_buffer = io.BytesIO() + with wave.open(wav_buffer, "wb") as wav_file: + setup_wav_file(wav_file) + wav_file.writeframes(segment) + wav_buffer.seek(0) + return await manager.transcribe( + wav_buffer.read(), + language=language, + task="transcribe", + ) + except Exception: + logger.exception("Failed to transcribe segment") + return None return app diff --git a/tests/test_server_whisper.py b/tests/test_server_whisper.py index b5af9ecea..7edb52ec9 100644 --- a/tests/test_server_whisper.py +++ b/tests/test_server_whisper.py @@ -620,7 +620,7 @@ def test_websocket_streaming_transcription( return_value=mock_result, ), client.websocket_connect( - "/v1/audio/transcriptions/stream?model=whisper-1", + "/v1/audio/transcriptions/stream?model=whisper-1&use_vad=false", ) as websocket, ): for chunk in chunks: @@ -657,7 +657,7 @@ def test_websocket_streaming_transcribe_error( side_effect=RuntimeError("boom"), ), client.websocket_connect( - "/v1/audio/transcriptions/stream?model=whisper-1", + "/v1/audio/transcriptions/stream?model=whisper-1&use_vad=false", ) as websocket, ): websocket.send_bytes(b"\x00\x00" * 160) @@ -666,3 +666,149 @@ def test_websocket_streaming_transcribe_error( assert data["type"] == "error" assert data["message"] == "boom" + + def test_websocket_vad_mode_single_segment( + self, + client: TestClient, + mock_registry: WhisperModelRegistry, + ) -> None: + """Test VAD mode sends partial then final for a single segment.""" + mock_result = TranscriptionResult( + text="Hello world", + language="en", + language_probability=0.95, + duration=1.5, + segments=[], + ) + + # Create a mock VAD that returns one segment on flush + mock_vad = type( + "MockVAD", + (), + { + "process_chunk": lambda _self, _chunk: (False, None), + "flush": lambda _self: b"audio_data", + }, + )() + + manager = mock_registry.get_manager() + with ( + patch.object( + manager, + "transcribe", + new_callable=AsyncMock, + return_value=mock_result, + ), + patch( + "agent_cli.server.whisper.api._create_vad", + return_value=mock_vad, + ), + client.websocket_connect( + "/v1/audio/transcriptions/stream?model=whisper-1&use_vad=true", + ) as websocket, + ): + websocket.send_bytes(b"\x00\x00" * 160) + websocket.send_bytes(b"EOS") + + # Should receive partial then final + partial = websocket.receive_json() + final = websocket.receive_json() + + assert partial["type"] == "partial" + assert partial["text"] == "Hello world" + assert partial["is_final"] is False + + assert final["type"] == "final" + assert final["text"] == "Hello world" + assert final["is_final"] is True + assert final["duration"] == 1.5 + + def test_websocket_vad_mode_multiple_segments( + self, + client: TestClient, + mock_registry: WhisperModelRegistry, + ) -> None: + """Test VAD mode sends multiple partials for multiple segments.""" + results = [ + TranscriptionResult( + text="First segment", + language="en", + language_probability=0.95, + duration=1.0, + segments=[], + ), + TranscriptionResult( + text="Second segment", + language="en", + language_probability=0.95, + duration=0.8, + segments=[], + ), + ] + + # Mock VAD that returns segment on first chunk, another on flush + call_count = [0] + + def mock_process_chunk(_self: object, _chunk: bytes) -> tuple[bool, bytes | None]: + call_count[0] += 1 + if call_count[0] == 1: + return (False, b"segment1") + return (False, None) + + mock_vad = type( + "MockVAD", + (), + { + "process_chunk": mock_process_chunk, + "flush": lambda _self: b"segment2", + }, + )() + + manager = mock_registry.get_manager() + transcribe_mock = AsyncMock(side_effect=results) + + with ( + patch.object(manager, "transcribe", transcribe_mock), + patch( + "agent_cli.server.whisper.api._create_vad", + return_value=mock_vad, + ), + client.websocket_connect( + "/v1/audio/transcriptions/stream?model=whisper-1&use_vad=true", + ) as websocket, + ): + websocket.send_bytes(b"\x00\x00" * 160) + websocket.send_bytes(b"EOS") + + partial1 = websocket.receive_json() + partial2 = websocket.receive_json() + final = websocket.receive_json() + + assert partial1["type"] == "partial" + assert partial1["text"] == "First segment" + + assert partial2["type"] == "partial" + assert partial2["text"] == "Second segment" + + assert final["type"] == "final" + assert final["text"] == "First segment Second segment" + assert final["duration"] == pytest.approx(1.8) + + def test_websocket_vad_not_available( + self, + client: TestClient, + ) -> None: + """Test error message when VAD is requested but not available.""" + with ( + patch( + "agent_cli.server.whisper.api._create_vad", + side_effect=ImportError("VAD not available"), + ), + client.websocket_connect( + "/v1/audio/transcriptions/stream?model=whisper-1&use_vad=true", + ) as websocket, + ): + data = websocket.receive_json() + + assert data["type"] == "error" + assert "VAD not available" in data["message"] From e912450514d5265def32bb925b793bdf14fe0d86 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Wed, 4 Feb 2026 08:18:50 -0800 Subject: [PATCH 2/7] fix: move VAD import to instantiation time where onnxruntime actually loads The module-level try/except was useless since VoiceActivityDetector imports onnxruntime lazily in __init__, not at module load time. --- agent_cli/server/whisper/api.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/agent_cli/server/whisper/api.py b/agent_cli/server/whisper/api.py index ce5acd537..dfc61600a 100644 --- a/agent_cli/server/whisper/api.py +++ b/agent_cli/server/whisper/api.py @@ -21,15 +21,6 @@ logger = logging.getLogger(__name__) -# VAD availability check - the vad extra may not be installed -_VAD_AVAILABLE = False -try: - from agent_cli.core.vad import VoiceActivityDetector as _VoiceActivityDetector - - _VAD_AVAILABLE = True -except ImportError: - _VoiceActivityDetector = None # type: ignore[misc, assignment] - def _create_vad( threshold: float, @@ -38,19 +29,22 @@ def _create_vad( ) -> VoiceActivityDetector: """Create a VoiceActivityDetector instance. - Raises ImportError if VAD is not available. + Raises ImportError if onnxruntime is not available. """ - if not _VAD_AVAILABLE: + from agent_cli.core.vad import VoiceActivityDetector as _VoiceActivityDetector # noqa: PLC0415 + + try: + return _VoiceActivityDetector( + threshold=threshold, + silence_threshold_ms=silence_threshold_ms, + min_speech_duration_ms=min_speech_duration_ms, + ) + except ImportError as e: msg = ( - "VAD is not available. Install it with: " + "VAD requires onnxruntime. Install it with: " "`pip install agent-cli[vad]` or `uv sync --extra vad`" ) - raise ImportError(msg) - return _VoiceActivityDetector( - threshold=threshold, - silence_threshold_ms=silence_threshold_ms, - min_speech_duration_ms=min_speech_duration_ms, - ) + raise ImportError(msg) from e def _split_seconds(seconds: float) -> tuple[int, int, int, int]: From 1c37785fcad95f5ce61cc629ad0af36862d553c1 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Wed, 4 Feb 2026 08:26:23 -0800 Subject: [PATCH 3/7] fix: remove unnecessary try-except blocks that hide errors - Remove ImportError handling since onnxruntime is a transitive dep of faster-whisper - Let _transcribe_segment exceptions propagate instead of silently returning None - Remove test for VAD unavailability (no longer a valid scenario) --- agent_cli/server/whisper/api.py | 66 ++++++++++++--------------------- tests/test_server_whisper.py | 19 ---------- 2 files changed, 23 insertions(+), 62 deletions(-) diff --git a/agent_cli/server/whisper/api.py b/agent_cli/server/whisper/api.py index dfc61600a..91f6c12d0 100644 --- a/agent_cli/server/whisper/api.py +++ b/agent_cli/server/whisper/api.py @@ -27,24 +27,14 @@ def _create_vad( silence_threshold_ms: int, min_speech_duration_ms: int, ) -> VoiceActivityDetector: - """Create a VoiceActivityDetector instance. - - Raises ImportError if onnxruntime is not available. - """ + """Create a VoiceActivityDetector instance.""" from agent_cli.core.vad import VoiceActivityDetector as _VoiceActivityDetector # noqa: PLC0415 - try: - return _VoiceActivityDetector( - threshold=threshold, - silence_threshold_ms=silence_threshold_ms, - min_speech_duration_ms=min_speech_duration_ms, - ) - except ImportError as e: - msg = ( - "VAD requires onnxruntime. Install it with: " - "`pip install agent-cli[vad]` or `uv sync --extra vad`" - ) - raise ImportError(msg) from e + return _VoiceActivityDetector( + threshold=threshold, + silence_threshold_ms=silence_threshold_ms, + min_speech_duration_ms=min_speech_duration_ms, + ) def _split_seconds(seconds: float) -> tuple[int, int, int, int]: @@ -393,16 +383,11 @@ async def stream_transcription( # Initialize VAD if requested vad = None if use_vad: - try: - vad = _create_vad( - threshold=vad_threshold, - silence_threshold_ms=vad_silence_ms, - min_speech_duration_ms=vad_min_speech_ms, - ) - except ImportError as e: - await websocket.send_json({"type": "error", "message": str(e)}) - await websocket.close() - return + vad = _create_vad( + threshold=vad_threshold, + silence_threshold_ms=vad_silence_ms, + min_speech_duration_ms=vad_min_speech_ms, + ) try: if vad is not None: @@ -436,7 +421,7 @@ async def process_segment(segment: bytes) -> None: """Transcribe segment and send partial result.""" nonlocal final_language, total_duration result = await _transcribe_segment(manager, segment, language) - if result and result.text.strip(): + if result.text.strip(): all_segments_text.append(result.text.strip()) final_language = result.language total_duration += result.duration @@ -556,22 +541,17 @@ async def _transcribe_segment( manager: Any, segment: bytes, language: str | None, - ) -> Any | None: + ) -> Any: """Transcribe a raw PCM audio segment by wrapping it in WAV format.""" - try: - # Wrap raw PCM in WAV format for transcription - wav_buffer = io.BytesIO() - with wave.open(wav_buffer, "wb") as wav_file: - setup_wav_file(wav_file) - wav_file.writeframes(segment) - wav_buffer.seek(0) - return await manager.transcribe( - wav_buffer.read(), - language=language, - task="transcribe", - ) - except Exception: - logger.exception("Failed to transcribe segment") - return None + wav_buffer = io.BytesIO() + with wave.open(wav_buffer, "wb") as wav_file: + setup_wav_file(wav_file) + wav_file.writeframes(segment) + wav_buffer.seek(0) + return await manager.transcribe( + wav_buffer.read(), + language=language, + task="transcribe", + ) return app diff --git a/tests/test_server_whisper.py b/tests/test_server_whisper.py index 7edb52ec9..23b2cb92d 100644 --- a/tests/test_server_whisper.py +++ b/tests/test_server_whisper.py @@ -793,22 +793,3 @@ def mock_process_chunk(_self: object, _chunk: bytes) -> tuple[bool, bytes | None assert final["type"] == "final" assert final["text"] == "First segment Second segment" assert final["duration"] == pytest.approx(1.8) - - def test_websocket_vad_not_available( - self, - client: TestClient, - ) -> None: - """Test error message when VAD is requested but not available.""" - with ( - patch( - "agent_cli.server.whisper.api._create_vad", - side_effect=ImportError("VAD not available"), - ), - client.websocket_connect( - "/v1/audio/transcriptions/stream?model=whisper-1&use_vad=true", - ) as websocket, - ): - data = websocket.receive_json() - - assert data["type"] == "error" - assert "VAD not available" in data["message"] From 6ad9ed1f19d390d47f8b33ea4268667f91371a83 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Wed, 4 Feb 2026 08:33:25 -0800 Subject: [PATCH 4/7] feat: add VAD as dependency for all whisper backends Ensures onnxruntime (via vad extra) is available for all whisper backends (faster-whisper, mlx-whisper, whisper-transformers), making VAD streaming work out of the box. --- agent_cli/_requirements/faster-whisper.txt | 4 ++- agent_cli/_requirements/mlx-whisper.txt | 29 +++++++++++++++---- .../_requirements/whisper-transformers.txt | 21 ++++++++++++-- pyproject.toml | 3 ++ uv.lock | 6 ++++ 5 files changed, 54 insertions(+), 9 deletions(-) diff --git a/agent_cli/_requirements/faster-whisper.txt b/agent_cli/_requirements/faster-whisper.txt index 694ae52f6..6404f3375 100644 --- a/agent_cli/_requirements/faster-whisper.txt +++ b/agent_cli/_requirements/faster-whisper.txt @@ -100,7 +100,9 @@ numpy==2.3.5 # ctranslate2 # onnxruntime onnxruntime==1.20.1 - # via faster-whisper + # via + # agent-cli + # faster-whisper packaging==25.0 # via # huggingface-hub diff --git a/agent_cli/_requirements/mlx-whisper.txt b/agent_cli/_requirements/mlx-whisper.txt index 3b8de140c..f8b91b568 100644 --- a/agent_cli/_requirements/mlx-whisper.txt +++ b/agent_cli/_requirements/mlx-whisper.txt @@ -28,6 +28,8 @@ colorama==0.4.6 ; sys_platform == 'win32' # click # tqdm # uvicorn +coloredlogs==15.0.1 + # via onnxruntime dnspython==2.8.0 # via email-validator dotenv==0.9.9 @@ -48,6 +50,8 @@ filelock==3.20.3 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via # huggingface-hub # torch +flatbuffers==25.12.19 + # via onnxruntime fsspec==2026.1.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via # huggingface-hub @@ -69,6 +73,8 @@ httpx==0.28.1 # fastapi-cloud-cli huggingface-hub==0.36.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via mlx-whisper +humanfriendly==10.0 + # via coloredlogs idna==3.11 # via # anyio @@ -95,19 +101,26 @@ mlx-whisper==0.4.3 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via agent-cli more-itertools==10.8.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via mlx-whisper -mpmath==1.3.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' +mpmath==1.3.0 # via sympy networkx==3.6.1 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via torch numba==0.63.1 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via mlx-whisper -numpy==2.3.5 ; platform_machine == 'arm64' and sys_platform == 'darwin' +numpy==2.3.5 # via # mlx-whisper # numba + # onnxruntime # scipy -packaging==25.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' - # via huggingface-hub +onnxruntime==1.20.1 + # via agent-cli +packaging==25.0 + # via + # huggingface-hub + # onnxruntime +protobuf==6.33.4 + # via onnxruntime psutil==7.2.1 ; sys_platform == 'win32' # via agent-cli pydantic==2.12.5 @@ -127,6 +140,8 @@ pygments==2.19.2 # via rich pyperclip==1.11.0 # via agent-cli +pyreadline3==3.5.4 ; sys_platform == 'win32' + # via humanfriendly python-dotenv==1.2.1 # via # dotenv @@ -170,8 +185,10 @@ shellingham==1.5.4 # typer-slim starlette==0.50.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via fastapi -sympy==1.14.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' - # via torch +sympy==1.14.0 + # via + # onnxruntime + # torch tiktoken==0.12.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via mlx-whisper torch==2.9.1 ; platform_machine == 'arm64' and sys_platform == 'darwin' diff --git a/agent_cli/_requirements/whisper-transformers.txt b/agent_cli/_requirements/whisper-transformers.txt index a9d0037c7..54d7a6008 100644 --- a/agent_cli/_requirements/whisper-transformers.txt +++ b/agent_cli/_requirements/whisper-transformers.txt @@ -28,6 +28,8 @@ colorama==0.4.6 ; sys_platform == 'win32' # click # tqdm # uvicorn +coloredlogs==15.0.1 + # via onnxruntime dnspython==2.8.0 # via email-validator dotenv==0.9.9 @@ -49,6 +51,8 @@ filelock==3.20.3 # huggingface-hub # torch # transformers +flatbuffers==25.12.19 + # via onnxruntime fsspec==2026.1.0 # via # huggingface-hub @@ -72,6 +76,8 @@ huggingface-hub==0.36.0 # via # tokenizers # transformers +humanfriendly==10.0 + # via coloredlogs idna==3.11 # via # anyio @@ -93,7 +99,9 @@ mpmath==1.3.0 networkx==3.6.1 # via torch numpy==2.3.5 - # via transformers + # via + # onnxruntime + # transformers nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via # nvidia-cudnn-cu12 @@ -133,10 +141,15 @@ nvidia-nvshmem-cu12==3.3.20 ; platform_machine == 'x86_64' and sys_platform == ' # via torch nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch +onnxruntime==1.20.1 + # via agent-cli packaging==25.0 # via # huggingface-hub + # onnxruntime # transformers +protobuf==6.33.4 + # via onnxruntime psutil==7.2.1 ; sys_platform == 'win32' # via agent-cli pydantic==2.12.5 @@ -156,6 +169,8 @@ pygments==2.19.2 # via rich pyperclip==1.11.0 # via agent-cli +pyreadline3==3.5.4 ; sys_platform == 'win32' + # via humanfriendly python-dotenv==1.2.1 # via # dotenv @@ -201,7 +216,9 @@ shellingham==1.5.4 starlette==0.50.0 # via fastapi sympy==1.14.0 - # via torch + # via + # onnxruntime + # torch tokenizers==0.22.2 # via transformers torch==2.9.1 diff --git a/pyproject.toml b/pyproject.toml index 2349c34bf..d0b3d176a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,15 +69,18 @@ vad = [ faster-whisper = [ "fastapi[standard]", "faster-whisper>=1.0.0", + "agent-cli[vad]", # VAD for streaming transcription ] mlx-whisper = [ "fastapi[standard]; sys_platform == 'darwin' and platform_machine == 'arm64'", "mlx-whisper>=0.4.0; sys_platform == 'darwin' and platform_machine == 'arm64'", + "agent-cli[vad]", # VAD for streaming transcription ] whisper-transformers = [ "fastapi[standard]", "transformers>=4.30.0", "torch>=2.0.0", + "agent-cli[vad]", # VAD for streaming transcription ] piper = [ "fastapi[standard]", diff --git a/uv.lock b/uv.lock index 7cb9dc132..0dad25f66 100644 --- a/uv.lock +++ b/uv.lock @@ -59,6 +59,7 @@ dev = [ faster-whisper = [ { name = "fastapi", extra = ["standard"] }, { name = "faster-whisper" }, + { name = "onnxruntime" }, ] kokoro = [ { name = "fastapi", extra = ["standard"] }, @@ -85,6 +86,7 @@ memory = [ mlx-whisper = [ { name = "fastapi", extra = ["standard"], marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, { name = "mlx-whisper", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "onnxruntime" }, ] piper = [ { name = "fastapi", extra = ["standard"] }, @@ -119,6 +121,7 @@ vad = [ ] whisper-transformers = [ { name = "fastapi", extra = ["standard"] }, + { name = "onnxruntime" }, { name = "torch" }, { name = "transformers" }, ] @@ -170,9 +173,12 @@ requires-dist = [ { name = "mlx-whisper", marker = "platform_machine == 'arm64' and sys_platform == 'darwin' and extra == 'mlx-whisper'", specifier = ">=0.4.0" }, { name = "notebook", marker = "extra == 'dev'" }, { name = "numpy", marker = "extra == 'audio'" }, + { name = "onnxruntime", marker = "extra == 'faster-whisper'", specifier = ">=1.16.0" }, { name = "onnxruntime", marker = "extra == 'memory'", specifier = ">=1.17.0" }, + { name = "onnxruntime", marker = "extra == 'mlx-whisper'", specifier = ">=1.16.0" }, { name = "onnxruntime", marker = "extra == 'rag'", specifier = ">=1.17.0" }, { name = "onnxruntime", marker = "extra == 'vad'", specifier = ">=1.16.0" }, + { name = "onnxruntime", marker = "extra == 'whisper-transformers'", specifier = ">=1.16.0" }, { name = "openai", marker = "extra == 'memory'", specifier = ">=1.0.0" }, { name = "openai", marker = "extra == 'rag'", specifier = ">=1.0.0" }, { name = "pip", marker = "extra == 'kokoro'" }, From 9f529931043b57868a64dfecefcc0b9d2b2b0269 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Wed, 4 Feb 2026 09:05:52 -0800 Subject: [PATCH 5/7] revert: keep VAD as optional dependency Reverts adding VAD as a dependency for all whisper backends. onnxruntime adds ~138MB which is significant for mlx-whisper users. Users who want VAD streaming should install agent-cli[vad] separately. --- agent_cli/_requirements/faster-whisper.txt | 4 +- agent_cli/_requirements/mlx-whisper.txt | 29 ++------ .../_requirements/whisper-transformers.txt | 21 +----- agent_cli/server/whisper/api.py | 68 ++++++++++++------- pyproject.toml | 3 - tests/test_server_whisper.py | 19 ++++++ uv.lock | 6 -- 7 files changed, 72 insertions(+), 78 deletions(-) diff --git a/agent_cli/_requirements/faster-whisper.txt b/agent_cli/_requirements/faster-whisper.txt index 6404f3375..694ae52f6 100644 --- a/agent_cli/_requirements/faster-whisper.txt +++ b/agent_cli/_requirements/faster-whisper.txt @@ -100,9 +100,7 @@ numpy==2.3.5 # ctranslate2 # onnxruntime onnxruntime==1.20.1 - # via - # agent-cli - # faster-whisper + # via faster-whisper packaging==25.0 # via # huggingface-hub diff --git a/agent_cli/_requirements/mlx-whisper.txt b/agent_cli/_requirements/mlx-whisper.txt index f8b91b568..3b8de140c 100644 --- a/agent_cli/_requirements/mlx-whisper.txt +++ b/agent_cli/_requirements/mlx-whisper.txt @@ -28,8 +28,6 @@ colorama==0.4.6 ; sys_platform == 'win32' # click # tqdm # uvicorn -coloredlogs==15.0.1 - # via onnxruntime dnspython==2.8.0 # via email-validator dotenv==0.9.9 @@ -50,8 +48,6 @@ filelock==3.20.3 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via # huggingface-hub # torch -flatbuffers==25.12.19 - # via onnxruntime fsspec==2026.1.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via # huggingface-hub @@ -73,8 +69,6 @@ httpx==0.28.1 # fastapi-cloud-cli huggingface-hub==0.36.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via mlx-whisper -humanfriendly==10.0 - # via coloredlogs idna==3.11 # via # anyio @@ -101,26 +95,19 @@ mlx-whisper==0.4.3 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via agent-cli more-itertools==10.8.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via mlx-whisper -mpmath==1.3.0 +mpmath==1.3.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via sympy networkx==3.6.1 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via torch numba==0.63.1 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via mlx-whisper -numpy==2.3.5 +numpy==2.3.5 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via # mlx-whisper # numba - # onnxruntime # scipy -onnxruntime==1.20.1 - # via agent-cli -packaging==25.0 - # via - # huggingface-hub - # onnxruntime -protobuf==6.33.4 - # via onnxruntime +packaging==25.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' + # via huggingface-hub psutil==7.2.1 ; sys_platform == 'win32' # via agent-cli pydantic==2.12.5 @@ -140,8 +127,6 @@ pygments==2.19.2 # via rich pyperclip==1.11.0 # via agent-cli -pyreadline3==3.5.4 ; sys_platform == 'win32' - # via humanfriendly python-dotenv==1.2.1 # via # dotenv @@ -185,10 +170,8 @@ shellingham==1.5.4 # typer-slim starlette==0.50.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via fastapi -sympy==1.14.0 - # via - # onnxruntime - # torch +sympy==1.14.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' + # via torch tiktoken==0.12.0 ; platform_machine == 'arm64' and sys_platform == 'darwin' # via mlx-whisper torch==2.9.1 ; platform_machine == 'arm64' and sys_platform == 'darwin' diff --git a/agent_cli/_requirements/whisper-transformers.txt b/agent_cli/_requirements/whisper-transformers.txt index 54d7a6008..a9d0037c7 100644 --- a/agent_cli/_requirements/whisper-transformers.txt +++ b/agent_cli/_requirements/whisper-transformers.txt @@ -28,8 +28,6 @@ colorama==0.4.6 ; sys_platform == 'win32' # click # tqdm # uvicorn -coloredlogs==15.0.1 - # via onnxruntime dnspython==2.8.0 # via email-validator dotenv==0.9.9 @@ -51,8 +49,6 @@ filelock==3.20.3 # huggingface-hub # torch # transformers -flatbuffers==25.12.19 - # via onnxruntime fsspec==2026.1.0 # via # huggingface-hub @@ -76,8 +72,6 @@ huggingface-hub==0.36.0 # via # tokenizers # transformers -humanfriendly==10.0 - # via coloredlogs idna==3.11 # via # anyio @@ -99,9 +93,7 @@ mpmath==1.3.0 networkx==3.6.1 # via torch numpy==2.3.5 - # via - # onnxruntime - # transformers + # via transformers nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via # nvidia-cudnn-cu12 @@ -141,15 +133,10 @@ nvidia-nvshmem-cu12==3.3.20 ; platform_machine == 'x86_64' and sys_platform == ' # via torch nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux' # via torch -onnxruntime==1.20.1 - # via agent-cli packaging==25.0 # via # huggingface-hub - # onnxruntime # transformers -protobuf==6.33.4 - # via onnxruntime psutil==7.2.1 ; sys_platform == 'win32' # via agent-cli pydantic==2.12.5 @@ -169,8 +156,6 @@ pygments==2.19.2 # via rich pyperclip==1.11.0 # via agent-cli -pyreadline3==3.5.4 ; sys_platform == 'win32' - # via humanfriendly python-dotenv==1.2.1 # via # dotenv @@ -216,9 +201,7 @@ shellingham==1.5.4 starlette==0.50.0 # via fastapi sympy==1.14.0 - # via - # onnxruntime - # torch + # via torch tokenizers==0.22.2 # via transformers torch==2.9.1 diff --git a/agent_cli/server/whisper/api.py b/agent_cli/server/whisper/api.py index 91f6c12d0..205b76b25 100644 --- a/agent_cli/server/whisper/api.py +++ b/agent_cli/server/whisper/api.py @@ -27,14 +27,24 @@ def _create_vad( silence_threshold_ms: int, min_speech_duration_ms: int, ) -> VoiceActivityDetector: - """Create a VoiceActivityDetector instance.""" - from agent_cli.core.vad import VoiceActivityDetector as _VoiceActivityDetector # noqa: PLC0415 + """Create a VoiceActivityDetector instance. - return _VoiceActivityDetector( - threshold=threshold, - silence_threshold_ms=silence_threshold_ms, - min_speech_duration_ms=min_speech_duration_ms, - ) + Raises ImportError if onnxruntime is not available. + """ + from agent_cli.core.vad import VoiceActivityDetector # noqa: PLC0415 + + try: + return VoiceActivityDetector( + threshold=threshold, + silence_threshold_ms=silence_threshold_ms, + min_speech_duration_ms=min_speech_duration_ms, + ) + except ImportError as e: + msg = ( + "VAD requires onnxruntime. Install it with: " + "`pip install agent-cli[vad]` or `uv sync --extra vad`" + ) + raise ImportError(msg) from e def _split_seconds(seconds: float) -> tuple[int, int, int, int]: @@ -383,11 +393,16 @@ async def stream_transcription( # Initialize VAD if requested vad = None if use_vad: - vad = _create_vad( - threshold=vad_threshold, - silence_threshold_ms=vad_silence_ms, - min_speech_duration_ms=vad_min_speech_ms, - ) + try: + vad = _create_vad( + threshold=vad_threshold, + silence_threshold_ms=vad_silence_ms, + min_speech_duration_ms=vad_min_speech_ms, + ) + except ImportError as e: + await websocket.send_json({"type": "error", "message": str(e)}) + await websocket.close() + return try: if vad is not None: @@ -421,7 +436,7 @@ async def process_segment(segment: bytes) -> None: """Transcribe segment and send partial result.""" nonlocal final_language, total_duration result = await _transcribe_segment(manager, segment, language) - if result.text.strip(): + if result and result.text.strip(): all_segments_text.append(result.text.strip()) final_language = result.language total_duration += result.duration @@ -541,17 +556,22 @@ async def _transcribe_segment( manager: Any, segment: bytes, language: str | None, - ) -> Any: + ) -> Any | None: """Transcribe a raw PCM audio segment by wrapping it in WAV format.""" - wav_buffer = io.BytesIO() - with wave.open(wav_buffer, "wb") as wav_file: - setup_wav_file(wav_file) - wav_file.writeframes(segment) - wav_buffer.seek(0) - return await manager.transcribe( - wav_buffer.read(), - language=language, - task="transcribe", - ) + try: + # Wrap raw PCM in WAV format for transcription + wav_buffer = io.BytesIO() + with wave.open(wav_buffer, "wb") as wav_file: + setup_wav_file(wav_file) + wav_file.writeframes(segment) + wav_buffer.seek(0) + return await manager.transcribe( + wav_buffer.read(), + language=language, + task="transcribe", + ) + except Exception: + logger.exception("Failed to transcribe segment") + return None return app diff --git a/pyproject.toml b/pyproject.toml index d0b3d176a..2349c34bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,18 +69,15 @@ vad = [ faster-whisper = [ "fastapi[standard]", "faster-whisper>=1.0.0", - "agent-cli[vad]", # VAD for streaming transcription ] mlx-whisper = [ "fastapi[standard]; sys_platform == 'darwin' and platform_machine == 'arm64'", "mlx-whisper>=0.4.0; sys_platform == 'darwin' and platform_machine == 'arm64'", - "agent-cli[vad]", # VAD for streaming transcription ] whisper-transformers = [ "fastapi[standard]", "transformers>=4.30.0", "torch>=2.0.0", - "agent-cli[vad]", # VAD for streaming transcription ] piper = [ "fastapi[standard]", diff --git a/tests/test_server_whisper.py b/tests/test_server_whisper.py index 23b2cb92d..7edb52ec9 100644 --- a/tests/test_server_whisper.py +++ b/tests/test_server_whisper.py @@ -793,3 +793,22 @@ def mock_process_chunk(_self: object, _chunk: bytes) -> tuple[bool, bytes | None assert final["type"] == "final" assert final["text"] == "First segment Second segment" assert final["duration"] == pytest.approx(1.8) + + def test_websocket_vad_not_available( + self, + client: TestClient, + ) -> None: + """Test error message when VAD is requested but not available.""" + with ( + patch( + "agent_cli.server.whisper.api._create_vad", + side_effect=ImportError("VAD not available"), + ), + client.websocket_connect( + "/v1/audio/transcriptions/stream?model=whisper-1&use_vad=true", + ) as websocket, + ): + data = websocket.receive_json() + + assert data["type"] == "error" + assert "VAD not available" in data["message"] diff --git a/uv.lock b/uv.lock index 0dad25f66..7cb9dc132 100644 --- a/uv.lock +++ b/uv.lock @@ -59,7 +59,6 @@ dev = [ faster-whisper = [ { name = "fastapi", extra = ["standard"] }, { name = "faster-whisper" }, - { name = "onnxruntime" }, ] kokoro = [ { name = "fastapi", extra = ["standard"] }, @@ -86,7 +85,6 @@ memory = [ mlx-whisper = [ { name = "fastapi", extra = ["standard"], marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, { name = "mlx-whisper", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, - { name = "onnxruntime" }, ] piper = [ { name = "fastapi", extra = ["standard"] }, @@ -121,7 +119,6 @@ vad = [ ] whisper-transformers = [ { name = "fastapi", extra = ["standard"] }, - { name = "onnxruntime" }, { name = "torch" }, { name = "transformers" }, ] @@ -173,12 +170,9 @@ requires-dist = [ { name = "mlx-whisper", marker = "platform_machine == 'arm64' and sys_platform == 'darwin' and extra == 'mlx-whisper'", specifier = ">=0.4.0" }, { name = "notebook", marker = "extra == 'dev'" }, { name = "numpy", marker = "extra == 'audio'" }, - { name = "onnxruntime", marker = "extra == 'faster-whisper'", specifier = ">=1.16.0" }, { name = "onnxruntime", marker = "extra == 'memory'", specifier = ">=1.17.0" }, - { name = "onnxruntime", marker = "extra == 'mlx-whisper'", specifier = ">=1.16.0" }, { name = "onnxruntime", marker = "extra == 'rag'", specifier = ">=1.17.0" }, { name = "onnxruntime", marker = "extra == 'vad'", specifier = ">=1.16.0" }, - { name = "onnxruntime", marker = "extra == 'whisper-transformers'", specifier = ">=1.16.0" }, { name = "openai", marker = "extra == 'memory'", specifier = ">=1.0.0" }, { name = "openai", marker = "extra == 'rag'", specifier = ">=1.0.0" }, { name = "pip", marker = "extra == 'kokoro'" }, From fc1e85048e4b525911f9056127591a2629ecb1da Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Wed, 4 Feb 2026 09:10:53 -0800 Subject: [PATCH 6/7] refactor: extract common helpers for DRY code - _parse_eos(): parse EOS marker from data, returns (audio_chunk, is_eos) - _wrap_pcm_as_wav(): wrap raw PCM in WAV format Simplifies both _stream_with_vad and _stream_buffered. --- agent_cli/server/whisper/api.py | 83 ++++++++++++--------------------- 1 file changed, 31 insertions(+), 52 deletions(-) diff --git a/agent_cli/server/whisper/api.py b/agent_cli/server/whisper/api.py index 205b76b25..877757060 100644 --- a/agent_cli/server/whisper/api.py +++ b/agent_cli/server/whisper/api.py @@ -21,6 +21,27 @@ logger = logging.getLogger(__name__) +_EOS_MARKER = b"EOS" + + +def _parse_eos(data: bytes) -> tuple[bytes, bool]: + """Parse data for EOS marker, returning (audio_chunk, is_eos).""" + if data == _EOS_MARKER: + return b"", True + if data.endswith(_EOS_MARKER): + return data[: -len(_EOS_MARKER)], True + return data, False + + +def _wrap_pcm_as_wav(pcm_data: bytes) -> bytes: + """Wrap raw PCM audio data in WAV format.""" + wav_buffer = io.BytesIO() + with wave.open(wav_buffer, "wb") as wav_file: + setup_wav_file(wav_file) + wav_file.writeframes(pcm_data) + wav_buffer.seek(0) + return wav_buffer.read() + def _create_vad( threshold: float, @@ -429,8 +450,6 @@ async def _stream_with_vad( all_segments_text: list[str] = [] total_duration: float = 0.0 final_language: str | None = None - eos_marker = b"EOS" - eos_len = len(eos_marker) async def process_segment(segment: bytes) -> None: """Transcribe segment and send partial result.""" @@ -451,19 +470,7 @@ async def process_segment(segment: bytes) -> None: while True: data = await websocket.receive_bytes() - - # Check for end of stream - is_eos = data == eos_marker - audio_chunk = b"" - - if is_eos: - pass # No audio to process - elif data[-eos_len:] == eos_marker: - # Audio followed by EOS marker - audio_chunk = data[:-eos_len] - is_eos = True - else: - audio_chunk = data + audio_chunk, is_eos = _parse_eos(data) # Process audio chunk through VAD if audio_chunk: @@ -495,50 +502,28 @@ async def _stream_buffered( language: str | None, ) -> None: """Handle streaming transcription with buffered mode (no VAD).""" - audio_buffer = io.BytesIO() - wav_file: wave.Wave_write | None = None - eos_marker = b"EOS" - eos_len = len(eos_marker) + pcm_chunks: list[bytes] = [] while True: data = await websocket.receive_bytes() - - # Initialize WAV file on first chunk (before EOS check) - if wav_file is None: - wav_file = wave.open(audio_buffer, "wb") # noqa: SIM115 - setup_wav_file(wav_file) - - # Check for end of stream - if data == eos_marker: - break - if data[-eos_len:] == eos_marker: - # Write remaining data before EOS marker - if len(data) > eos_len: - wav_file.writeframes(data[:-eos_len]) + audio_chunk, is_eos = _parse_eos(data) + if audio_chunk: + pcm_chunks.append(audio_chunk) + if is_eos: break - wav_file.writeframes(data) - - # Close WAV file - if wav_file is not None: - wav_file.close() - - # Get audio data - audio_buffer.seek(0) - audio_data = audio_buffer.read() - - if not audio_data: + if not pcm_chunks: await websocket.send_json({"type": "error", "message": "No audio received"}) return # Transcribe + audio_data = _wrap_pcm_as_wav(b"".join(pcm_chunks)) try: result = await manager.transcribe( audio_data, language=language, task="transcribe", ) - await websocket.send_json( { "type": "final", @@ -557,16 +542,10 @@ async def _transcribe_segment( segment: bytes, language: str | None, ) -> Any | None: - """Transcribe a raw PCM audio segment by wrapping it in WAV format.""" + """Transcribe a raw PCM audio segment.""" try: - # Wrap raw PCM in WAV format for transcription - wav_buffer = io.BytesIO() - with wave.open(wav_buffer, "wb") as wav_file: - setup_wav_file(wav_file) - wav_file.writeframes(segment) - wav_buffer.seek(0) return await manager.transcribe( - wav_buffer.read(), + _wrap_pcm_as_wav(segment), language=language, task="transcribe", ) From cb6de4b3218cc924f35725021534d34f734f6eb4 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Wed, 4 Feb 2026 09:12:33 -0800 Subject: [PATCH 7/7] refactor: reuse existing pcm_to_wav from services Replace local _wrap_pcm_as_wav with pcm_to_wav from agent_cli.services. Removes duplicate code and unused wave/io imports. --- agent_cli/server/whisper/api.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/agent_cli/server/whisper/api.py b/agent_cli/server/whisper/api.py index 877757060..5151b84a9 100644 --- a/agent_cli/server/whisper/api.py +++ b/agent_cli/server/whisper/api.py @@ -3,17 +3,16 @@ from __future__ import annotations import contextlib -import io import logging -import wave from typing import TYPE_CHECKING, Annotated, Any, Literal from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile, WebSocket from fastapi.responses import PlainTextResponse from pydantic import BaseModel -from agent_cli.server.common import configure_app, create_lifespan, setup_wav_file +from agent_cli.server.common import configure_app, create_lifespan from agent_cli.server.whisper.backends.base import InvalidAudioError +from agent_cli.services import pcm_to_wav if TYPE_CHECKING: from agent_cli.core.vad import VoiceActivityDetector @@ -33,16 +32,6 @@ def _parse_eos(data: bytes) -> tuple[bytes, bool]: return data, False -def _wrap_pcm_as_wav(pcm_data: bytes) -> bytes: - """Wrap raw PCM audio data in WAV format.""" - wav_buffer = io.BytesIO() - with wave.open(wav_buffer, "wb") as wav_file: - setup_wav_file(wav_file) - wav_file.writeframes(pcm_data) - wav_buffer.seek(0) - return wav_buffer.read() - - def _create_vad( threshold: float, silence_threshold_ms: int, @@ -517,7 +506,7 @@ async def _stream_buffered( return # Transcribe - audio_data = _wrap_pcm_as_wav(b"".join(pcm_chunks)) + audio_data = pcm_to_wav(b"".join(pcm_chunks)) try: result = await manager.transcribe( audio_data, @@ -545,7 +534,7 @@ async def _transcribe_segment( """Transcribe a raw PCM audio segment.""" try: return await manager.transcribe( - _wrap_pcm_as_wav(segment), + pcm_to_wav(segment), language=language, task="transcribe", )