diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000..92919c8 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,63 @@ +# Copilot Instructions — Satellite + +## Build & Test + +```bash +# Install +pip install -r requirements.txt +pip install -r requirements-dev.txt # test deps (pytest, httpx, etc.) + +# Run all tests (with coverage) +pytest + +# Run a single test file / single test +pytest tests/test_api.py +pytest tests/test_api.py::test_get_transcription_success -k "test_get_transcription_success" + +# Run the app +python main.py +``` + +Python 3.12+. No linter configured in CI — only `pytest` runs in the build workflow. +Container image uses `Containerfile` (multi-stage, python:slim base). + +## Architecture + +Satellite bridges Asterisk PBX ↔ transcription providers (Deepgram or VoxTral), publishing results over MQTT. + +### Runtime components (all in one process) + +| Module | Role | +|---|---| +| `main.py` | Entrypoint — starts the asyncio event loop for the real-time pipeline and a background thread running the FastAPI/Uvicorn HTTP server | +| `asterisk_bridge.py` | ARI WebSocket client — listens for Stasis events, creates snoop channels + external media, manages per-call lifecycle | +| `rtp_server.py` | UDP server — receives RTP audio, strips headers, routes packets to per-channel async queues by source port | +| `deepgram_connector.py` | Streams audio to Deepgram via WebSocket — interleaves two RTP channels into stereo for multichannel transcription; aggregates final transcript on hangup (real-time path only, Deepgram-only for now) | +| `mqtt_client.py` | Publishes interim/final transcription JSON to MQTT topics (`{prefix}/transcription`, `{prefix}/final`) | +| `transcription/` | **Provider abstraction** — `base.py` defines interface; `deepgram.py` and `voxtral.py` implement REST API clients; `__init__.py` factory selects provider via env var or per-request override | +| `api.py` | FastAPI app — `POST /api/get_transcription` accepts WAV uploads, calls transcription provider REST API, optionally persists to Postgres | +| `call_processor.py` | **Runs as a subprocess** (invoked from api.py via `subprocess.run`) — reads JSON from stdin, calls AI enrichment, writes results to DB | +| `ai.py` | LangChain + OpenAI — cleans transcript, generates summary + sentiment score (0-10) | +| `db.py` | PostgreSQL + pgvector — schema auto-init with threading lock; stores transcripts, state machine (`progress` → `summarizing` → `done` / `failed`), and text-embedding-3-small chunks | + +### Key data flows + +1. **Real-time path:** Asterisk → ARI WebSocket → snoop channel → RTP → `rtp_server` → `deepgram_connector` (stereo WebSocket stream) → Deepgram → `mqtt_client` (Deepgram-only for now) +2. **REST/batch path:** WAV upload → `api.py` → `transcription/` REST API (Deepgram or VoxTral) → (optionally) `db.py` persist → (optionally) `call_processor.py` subprocess → `ai.py` → `db.py` update + +### Non-obvious details + +- Two RTP streams per call (one per direction) are interleaved into a single stereo buffer for Deepgram's multichannel mode (real-time path only). +- `asterisk_bridge` detects if Asterisk swapped the RTP source ports and adjusts speaker labels accordingly. +- `call_processor` is deliberately a **subprocess** (not async task) — isolates OpenAI calls with independent timeout/logging, avoids blocking the event loop. +- DB schema initialization is guarded by a **threading lock** (not asyncio lock) because `psycopg` sync connections are used alongside the async FastAPI server. +- **Multi-provider support:** REST/batch path supports Deepgram and VoxTral. Select provider via `TRANSCRIPTION_PROVIDER` env var (default: `deepgram`) or per-request `provider=` parameter. Real-time path remains Deepgram-only. + +## Conventions + +- **Config:** Exclusively via environment variables (loaded from `.env` by `python-dotenv`). No config files or CLI args. +- **Logging:** One logger per module (`logging.getLogger(__name__)`), level controlled by `LOG_LEVEL` env var. +- **Async:** `asyncio` throughout the real-time pipeline; `asyncio.Lock` for connector close logic, `asyncio.Queue` for RTP buffer routing. Reconnection uses exponential backoff. +- **Testing:** `pytest-asyncio` with `asyncio_mode = auto`. Tests monkeypatch env vars and mock external services (Deepgram, MQTT, psycopg). A conftest auto-fixture resets `db._schema_initialized` between tests. +- **Auth:** Optional static bearer token (`API_TOKEN` env var) for `/api/*` endpoints. Accepts `Authorization: Bearer ` or `X-API-Token: `. +- **Validation:** `uniqueid` must match `\d+\.\d+` (Asterisk format). diff --git a/Containerfile b/Containerfile index c4d8c3d..266b008 100644 --- a/Containerfile +++ b/Containerfile @@ -15,6 +15,7 @@ COPY requirements.txt /tmp/requirements.txt # Copy application files COPY *.py /tmp/ COPY README.md /tmp/ +COPY transcription /tmp/transcription # Install dependencies RUN pip install --no-cache-dir --no-warn-script-location --user -r /tmp/requirements.txt @@ -36,6 +37,7 @@ COPY --from=builder /root/.local /root/.local # Copy application files COPY --from=builder /tmp/*.py /app/ COPY --from=builder /tmp/README.md /app/ +COPY --from=builder /tmp/transcription /app/transcription # Make sure scripts in .local are usable ENV PATH=/root/.local/bin:$PATH @@ -55,7 +57,9 @@ ENV ASTERISK_URL="http://127.0.0.1:8088" \ MQTT_USERNAME="satellite" \ SATELLITE_MQTT_PASSWORD="dummypassword" \ HTTP_PORT="8000" \ + TRANSCRIPTION_PROVIDER="deepgram" \ DEEPGRAM_API_KEY="" \ + MISTRAL_API_KEY="" \ LOG_LEVEL="INFO" \ PYTHONUNBUFFERED="1" diff --git a/README.md b/README.md index 12dc9d9..ce5422f 100644 --- a/README.md +++ b/README.md @@ -49,9 +49,16 @@ RTP_HEADER_SIZE=12 MQTT_URL=mqtt://127.0.0.1:1883 MQTT_TOPIC_PREFIX=satellite -# Deepgram API Key +# Transcription Provider (optional, default: deepgram) +# Options: deepgram, voxtral +TRANSCRIPTION_PROVIDER=deepgram + +# Deepgram API Key (required for Deepgram provider) DEEPGRAM_API_KEY=your_deepgram_api_key +# Mistral API Key (required for VoxTral provider) +MISTRAL_API_KEY=your_mistral_api_key + # REST API (optional) HTTP_PORT=8000 @@ -92,8 +99,10 @@ PGVECTOR_DATABASE=satellite - `MQTT_URL`: URL of the MQTT broker - `MQTT_TOPIC_PREFIX`: Prefix for MQTT topics -#### Deepgram Configuration -- `DEEPGRAM_API_KEY`: Your Deepgram API key +#### Transcription Configuration +- `TRANSCRIPTION_PROVIDER`: Choose the transcription provider (`deepgram` or `voxtral`, default: `deepgram`) +- `DEEPGRAM_API_KEY`: Your Deepgram API key (required for Deepgram provider) +- `MISTRAL_API_KEY`: Your Mistral API key (required for VoxTral provider) #### Rest API Configuration - `HTTP_PORT`: Port for the HTTP server (default: 8000) @@ -125,28 +134,38 @@ This requires the `vector` extension (pgvector) in your Postgres instance. #### `POST /api/get_transcription` -Accepts a WAV upload and returns a Deepgram transcription. +Accepts a WAV upload and returns a transcription from the configured provider (Deepgram or VoxTral). Request requirements: - Content type: multipart form upload with a `file` field (`audio/wav` or `audio/x-wav`) Optional fields (query string or multipart form fields): +- `provider`: Override the transcription provider (`deepgram` or `voxtral`). If not set, uses `TRANSCRIPTION_PROVIDER` env var (default: `deepgram`) - `uniqueid`: Asterisk-style uniqueid like `1234567890.1234` (required only when `persist=true`) - `persist`: `true|false` (default `false`) — persist raw transcript to Postgres (requires `PGVECTOR_*` env vars) - `summary`: `true|false` (default `false`) — run AI enrichment (requires `OPENAI_API_KEY` and also `persist=true` so there is a DB record to update) -- `channel0_name`, `channel1_name`: rename diarization labels in the returned transcript (replaces `Channel 0:` / `Channel 1:`) +- `channel0_name`, `channel1_name`: rename diarization labels in the returned transcript (replaces `Channel 0:` / `Channel 1:` or `Speaker 0:` / `Speaker 1:`) -Deepgram parameters: -- Most Deepgram `/v1/listen` parameters may be provided as query/form fields and are passed through to Deepgram. +Provider-specific parameters: +- **Deepgram**: Most Deepgram `/v1/listen` parameters may be provided as query/form fields (e.g., `model`, `language`, `diarize`, `punctuate`) +- **VoxTral**: Supports `model` (default: `voxtral-mini-latest`), `language`, `diarize`, `temperature`, `context_bias`, `timestamp_granularities` Example: ``` +# Using default provider (from TRANSCRIPTION_PROVIDER env var) curl -X POST http://127.0.0.1:8000/api/get_transcription \ -H 'Authorization: Bearer YOUR_TOKEN' \ -F uniqueid=1234567890.1234 \ -F persist=true \ -F summary=true \ -F file=@call.wav;type=audio/wav + +# Override provider to use VoxTral +curl -X POST http://127.0.0.1:8000/api/get_transcription \ + -H 'Authorization: Bearer YOUR_TOKEN' \ + -F provider=voxtral \ + -F diarize=true \ + -F file=@call.wav;type=audio/wav ``` Authentication: diff --git a/api.py b/api.py index 138f9c0..b7da20d 100644 --- a/api.py +++ b/api.py @@ -8,12 +8,11 @@ import sys import db +from transcription import get_provider app = FastAPI() logger = logging.getLogger("api") -DEEPGRAM_API_KEY = os.getenv("DEEPGRAM_API_KEY") # Ensure this environment variable is set - def _require_api_token_if_configured(request: Request) -> None: configured_token = (os.getenv("API_TOKEN") or "").strip() @@ -75,14 +74,6 @@ def _run_call_processor( stdout_preview = (proc.stdout or b"")[:2000].decode("utf-8", errors="replace") raise RuntimeError(f"call_processor failed rc={proc.returncode} stdout={stdout_preview!r} stderr={stderr_preview!r}") -def _get_deepgram_timeout_seconds() -> float: - raw = os.getenv("DEEPGRAM_TIMEOUT_SECONDS", "300").strip() - try: - return float(raw) - except ValueError: - logger.warning("Invalid DEEPGRAM_TIMEOUT_SECONDS=%r; defaulting to 300", raw) - return 300.0 - @api_router.post('/get_transcription') async def get_transcription( request: Request, @@ -113,6 +104,7 @@ async def get_transcription( uniqueid = (input_params.get("uniqueid") or "").strip() channel0_name = (input_params.get("channel0_name") or "").strip() channel1_name = (input_params.get("channel1_name") or "").strip() + provider_name = (input_params.get("provider") or "").strip().lower() or None # Persist only when explicitly requested. persist = (input_params.get("persist") or "false").lower() in ("1", "true", "yes") summary = (input_params.get("summary") or "false").lower() in ("1", "true", "yes") @@ -126,7 +118,7 @@ async def get_transcription( transcript_id = None if db.is_configured() and persist: - # Create/mark a DB row immediately so we can track state even if Deepgram fails. + # Create/mark a DB row immediately so we can track state even if transcription fails. try: transcript_id = await run_in_threadpool( db.upsert_transcript_progress, @@ -136,84 +128,27 @@ async def get_transcription( logger.exception("Failed to initialize transcript row for state tracking") raise HTTPException(status_code=500, detail="Failed to initialize transcript persistence") - # Valid Deepgram REST API parameters for /v1/listen endpoint - deepgram_params = { - "callback": "", - "callback_method": "", - "custom_topic": "", - "custom_topic_mode": "", - "custom_intent": "", - "custom_intent_mode": "", - "detect_entities": "", - "detect_language": "true", - "diarize": "", - "dictation": "", - "encoding": "", - "extra": "", - "filler_words": "", - "intents": "", - "keyterm": "", - "keywords": "", - "language": "", - "measurements": "", - "mip_opt_out": "", # Opts out requests from the Deepgram Model Improvement Program - "model": "nova-3", - "multichannel": "", - "numerals": "true", - "paragraphs": "true", - "profanity_filter": "", - "punctuate": "true", - "redact": "", - "replace": "", - "search": "", - "sentiment": "false", - "smart_format": "true", - "summarize": "", - "tag": "", - "topics": "", - "utterances": "", - "utt_split": "", - "version": "", - } - - headers = { - "Authorization": f"Token {DEEPGRAM_API_KEY}", - "Content-Type": file.content_type - } - - params = {} - for k, v in deepgram_params.items(): - if k in input_params and input_params[k].strip(): - params[k] = input_params[k] - elif v.strip(): - params[k] = v - + # Get transcription provider try: - deepgram_timeout_seconds = _get_deepgram_timeout_seconds() - timeout = httpx.Timeout( - connect=10.0, - read=deepgram_timeout_seconds, - write=deepgram_timeout_seconds, - pool=10.0, - ) - async with httpx.AsyncClient(timeout=timeout) as client: - response = await client.post( - "https://api.deepgram.com/v1/listen", - headers=headers, - params=params, - content=audio_bytes, - ) - # Debug: log response meta and preview + provider = get_provider(provider_name) + except ValueError as e: + logger.error("Failed to get transcription provider: %s", str(e)) + if transcript_id is not None: try: - logger.debug( - "Deepgram response: status=%s content_type=%s body_preview=%s", - response.status_code, - response.headers.get("Content-Type"), - (response.text[:500] if response is not None and hasattr(response, "text") and response.text else ""), - ) + await run_in_threadpool(db.set_transcript_state, transcript_id=transcript_id, state="failed") except Exception: - logger.debug("Failed to log Deepgram response preview") - response.raise_for_status() + logger.exception("Failed to update transcript state=failed") + raise HTTPException(status_code=400, detail=str(e)) + + # Call transcription provider + try: + result = await provider.transcribe( + audio_bytes=audio_bytes, + content_type=file.content_type, + params=input_params, + ) + raw_transcription = result.raw_transcription + detected_language = result.detected_language except httpx.HTTPStatusError as e: if transcript_id is not None: try: @@ -223,72 +158,50 @@ async def get_transcription( try: status = e.response.status_code if e.response is not None else "unknown" body_preview = e.response.text[:500] if e.response is not None and hasattr(e.response, "text") and e.response.text else "" - logger.error("Deepgram API error: status=%s body_preview=%s", status, body_preview) + logger.error("Transcription API error: status=%s body_preview=%s", status, body_preview) except Exception: - logger.error("Deepgram API error (logging failed)") - raise HTTPException(status_code=e.response.status_code, detail=f"Deepgram API error: {e.response.text}") + logger.error("Transcription API error (logging failed)") + raise HTTPException(status_code=e.response.status_code, detail=f"Transcription API error: {e.response.text}") except httpx.TimeoutException: - logger.warning("Deepgram request timed out (uniqueid=%s)", uniqueid) + logger.warning("Transcription request timed out (uniqueid=%s)", uniqueid) if transcript_id is not None: try: await run_in_threadpool(db.set_transcript_state, transcript_id=transcript_id, state="failed") except Exception: logger.exception("Failed to update transcript state=failed") - raise HTTPException(status_code=504, detail="Deepgram request timed out") + raise HTTPException(status_code=504, detail="Transcription request timed out") except httpx.RequestError as e: - logger.error("Deepgram request failed (uniqueid=%s): %s", uniqueid, str(e)) + logger.error("Transcription request failed (uniqueid=%s): %s", uniqueid, str(e)) if transcript_id is not None: try: await run_in_threadpool(db.set_transcript_state, transcript_id=transcript_id, state="failed") except Exception: logger.exception("Failed to update transcript state=failed") - raise HTTPException(status_code=502, detail="Failed to reach Deepgram") - except Exception as e: - logger.exception("Unexpected error while calling Deepgram") + raise HTTPException(status_code=502, detail="Failed to reach transcription service") + except ValueError as e: + logger.error("Failed to parse transcription response: %s", str(e)) if transcript_id is not None: try: await run_in_threadpool(db.set_transcript_state, transcript_id=transcript_id, state="failed") except Exception: logger.exception("Failed to update transcript state=failed") - raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") - - result = response.json() - detected_language = None # always define; mocks may omit this field - try: - if "paragraphs" in result["results"] and "transcript" in result["results"]["paragraphs"]: - raw_transcription = result["results"]["paragraphs"]["transcript"].strip() - elif ( - "channels" in result["results"] - and result["results"]["channels"] - and "alternatives" in result["results"]["channels"][0] - and result["results"]["channels"][0]["alternatives"] - and "paragraphs" in result["results"]["channels"][0]["alternatives"][0] - and "transcript" in result["results"]["channels"][0]["alternatives"][0]["paragraphs"] - ): - raw_transcription = ( - result["results"]["channels"][0]["alternatives"][0]["paragraphs"]["transcript"].strip() - ) - else: - logger.debug("failed to get paragraphs transcript") - logger.debug(result) - raise KeyError("paragraphs transcript not found") - if "channels" in result["results"] and "detected_language" in result["results"]["channels"][0]: - detected_language = result["results"]["channels"][0]["detected_language"] - else: - logger.debug("failed to get detected_language") - logger.debug(result) - if channel0_name: - raw_transcription = raw_transcription.replace("Channel 0:", f"{channel0_name}:") - if channel1_name: - raw_transcription = raw_transcription.replace("Channel 1:", f"{channel1_name}:") - except (KeyError, IndexError): - logger.error("Failed to parse Deepgram transcription response: %s", response.text) + raise HTTPException(status_code=500, detail=str(e)) + except Exception as e: + logger.exception("Unexpected error while calling transcription service") if transcript_id is not None: try: await run_in_threadpool(db.set_transcript_state, transcript_id=transcript_id, state="failed") except Exception: logger.exception("Failed to update transcript state=failed") - raise HTTPException(status_code=500, detail="Failed to parse transcription response.") + raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") + + # Apply channel name replacements (provider-agnostic post-processing) + if channel0_name: + raw_transcription = raw_transcription.replace("Channel 0:", f"{channel0_name}:") + raw_transcription = raw_transcription.replace("Speaker 0:", f"{channel0_name}:") + if channel1_name: + raw_transcription = raw_transcription.replace("Channel 1:", f"{channel1_name}:") + raw_transcription = raw_transcription.replace("Speaker 1:", f"{channel1_name}:") # Persist raw transcript when Postgres config is present (default) unless disabled per request. if transcript_id is not None: diff --git a/tests/test_api.py b/tests/test_api.py index db7a282..68d39a9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -66,9 +66,12 @@ def test_auth_enabled_wrong_token_returns_401(self, client, valid_wav_content): assert response.status_code == 401 - @patch('httpx.AsyncClient') - def test_auth_enabled_valid_token_allows_request(self, mock_client_class, client, valid_wav_content): + @patch('transcription.deepgram.httpx.AsyncClient') + @patch('api.get_provider') + def test_auth_enabled_valid_token_allows_request(self, mock_get_provider, mock_client_class, client, valid_wav_content, monkeypatch): """When API_TOKEN is set, /api endpoints require a matching token.""" + monkeypatch.setenv("DEEPGRAM_API_KEY", "test_key") + # Mock the Deepgram API response mock_response = Mock() mock_response.json.return_value = { @@ -78,12 +81,16 @@ def test_auth_enabled_valid_token_allows_request(self, mock_client_class, client { "alternatives": [ {"transcript": "Hello world"} - ] + ], + "detected_language": "en" } ] } } mock_response.raise_for_status = Mock() + mock_response.status_code = 200 + mock_response.text = "mock" + mock_response.headers.get.return_value = "application/json" mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) @@ -91,7 +98,11 @@ def test_auth_enabled_valid_token_allows_request(self, mock_client_class, client mock_client.__aexit__ = AsyncMock() mock_client_class.return_value = mock_client - with patch.dict(os.environ, {"API_TOKEN": "secret"}): + # Use actual provider + from transcription import get_provider as real_get_provider + mock_get_provider.side_effect = real_get_provider + + with patch.dict(os.environ, {"API_TOKEN": "secret", "DEEPGRAM_API_KEY": "test_key"}): response = client.post( "/api/get_transcription", headers={"Authorization": "Bearer secret"}, @@ -119,9 +130,12 @@ def test_missing_uniqueid(self, mock_db_configured, client, valid_wav_content): assert response.status_code == 400 assert "uniqueid" in response.json()["detail"] - @patch('httpx.AsyncClient') - def test_valid_wav_file(self, mock_client_class, client, valid_wav_content): + @patch('transcription.deepgram.httpx.AsyncClient') + @patch('api.get_provider') + def test_valid_wav_file(self, mock_get_provider, mock_client_class, client, valid_wav_content, monkeypatch): """Test transcription with a valid WAV file.""" + monkeypatch.setenv("DEEPGRAM_API_KEY", "test_key") + # Mock the Deepgram API response mock_response = Mock() mock_response.json.return_value = { @@ -131,12 +145,16 @@ def test_valid_wav_file(self, mock_client_class, client, valid_wav_content): { "alternatives": [ {"transcript": "Hello world"} - ] + ], + "detected_language": "en" } ] } } mock_response.raise_for_status = Mock() + mock_response.status_code = 200 + mock_response.text = "mock" + mock_response.headers.get.return_value = "application/json" mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) @@ -144,6 +162,10 @@ def test_valid_wav_file(self, mock_client_class, client, valid_wav_content): mock_client.__aexit__ = AsyncMock() mock_client_class.return_value = mock_client + # Use actual provider + from transcription import get_provider as real_get_provider + mock_get_provider.side_effect = real_get_provider + # Make the request response = client.post( "/api/get_transcription", @@ -157,9 +179,11 @@ def test_valid_wav_file(self, mock_client_class, client, valid_wav_content): assert "transcript" in data assert data["transcript"] == "SPEAKER 1: Hello world" - @patch('httpx.AsyncClient') - def test_persists_raw_transcript_via_threadpool(self, mock_client_class, client, valid_wav_content): + @patch('transcription.deepgram.httpx.AsyncClient') + @patch('api.get_provider') + def test_persists_raw_transcript_via_threadpool(self, mock_get_provider, mock_client_class, client, valid_wav_content, monkeypatch): """Ensure persistence path uses threadpool helper and forwards kwargs to db layer.""" + monkeypatch.setenv("DEEPGRAM_API_KEY", "test_key") # Mock the Deepgram API response mock_response = Mock() @@ -170,12 +194,16 @@ def test_persists_raw_transcript_via_threadpool(self, mock_client_class, client, { "alternatives": [ {"transcript": "Hello world"} - ] + ], + "detected_language": "en" } ] } } mock_response.raise_for_status = Mock() + mock_response.status_code = 200 + mock_response.text = "mock" + mock_response.headers.get.return_value = "application/json" mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) @@ -183,10 +211,14 @@ def test_persists_raw_transcript_via_threadpool(self, mock_client_class, client, mock_client.__aexit__ = AsyncMock() mock_client_class.return_value = mock_client + # Use actual provider + from transcription import get_provider as real_get_provider + mock_get_provider.side_effect = real_get_provider + async def fake_run_in_threadpool(func, *args, **kwargs): return func(*args, **kwargs) - with patch.dict(os.environ, {"OPENAI_API_KEY": ""}), \ + with patch.dict(os.environ, {"OPENAI_API_KEY": "", "DEEPGRAM_API_KEY": "test_key"}), \ patch("api.db.is_configured", return_value=True), \ patch("api.db.upsert_transcript_progress", return_value=123) as progress_mock, \ patch("api.db.upsert_transcript_raw", return_value=123) as upsert_mock, \ @@ -217,9 +249,12 @@ def test_invalid_file_type(self, client): assert response.status_code == 400 assert "Invalid file type" in response.json()["detail"] - @patch('httpx.AsyncClient') - def test_deepgram_api_error(self, mock_client_class, client, valid_wav_content): + @patch('transcription.deepgram.httpx.AsyncClient') + @patch('api.get_provider') + def test_deepgram_api_error(self, mock_get_provider, mock_client_class, client, valid_wav_content, monkeypatch): """Test handling of Deepgram API errors.""" + monkeypatch.setenv("DEEPGRAM_API_KEY", "test_key") + # Mock an HTTP error from Deepgram mock_response = Mock() mock_response.status_code = 401 @@ -237,6 +272,10 @@ def test_deepgram_api_error(self, mock_client_class, client, valid_wav_content): mock_client.__aexit__ = AsyncMock(return_value=False) mock_client_class.return_value = mock_client + # Use actual provider + from transcription import get_provider as real_get_provider + mock_get_provider.side_effect = real_get_provider + response = client.post( "/api/get_transcription", files={"file": ("test.wav", valid_wav_content, "audio/wav")}, @@ -244,11 +283,14 @@ def test_deepgram_api_error(self, mock_client_class, client, valid_wav_content): ) assert response.status_code == 401 - assert "Deepgram API error" in response.json()["detail"] + assert "API error" in response.json()["detail"] - @patch('httpx.AsyncClient') - def test_deepgram_timeout_returns_504(self, mock_client_class, client, valid_wav_content): + @patch('transcription.deepgram.httpx.AsyncClient') + @patch('api.get_provider') + def test_deepgram_timeout_returns_504(self, mock_get_provider, mock_client_class, client, valid_wav_content, monkeypatch): """Test that Deepgram timeouts are mapped to 504 Gateway Timeout.""" + monkeypatch.setenv("DEEPGRAM_API_KEY", "test_key") + mock_client = AsyncMock() mock_client.post = AsyncMock( side_effect=httpx.ReadTimeout("Timed out", request=Mock()) @@ -257,6 +299,10 @@ def test_deepgram_timeout_returns_504(self, mock_client_class, client, valid_wav mock_client.__aexit__ = AsyncMock(return_value=False) mock_client_class.return_value = mock_client + # Use actual provider + from transcription import get_provider as real_get_provider + mock_get_provider.side_effect = real_get_provider + response = client.post( "/api/get_transcription", files={"file": ("test.wav", valid_wav_content, "audio/wav")}, @@ -266,13 +312,19 @@ def test_deepgram_timeout_returns_504(self, mock_client_class, client, valid_wav assert response.status_code == 504 assert "timed out" in response.json()["detail"].lower() - @patch('httpx.AsyncClient') - def test_malformed_deepgram_response(self, mock_client_class, client, valid_wav_content): + @patch('transcription.deepgram.httpx.AsyncClient') + @patch('api.get_provider') + def test_malformed_deepgram_response(self, mock_get_provider, mock_client_class, client, valid_wav_content, monkeypatch): """Test handling of malformed responses from Deepgram.""" + monkeypatch.setenv("DEEPGRAM_API_KEY", "test_key") + # Mock a response with missing fields mock_response = Mock() mock_response.json.return_value = {"results": {}} mock_response.raise_for_status = Mock() + mock_response.status_code = 200 + mock_response.text = "bad response" + mock_response.headers.get.return_value = "application/json" mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) @@ -280,6 +332,10 @@ def test_malformed_deepgram_response(self, mock_client_class, client, valid_wav_ mock_client.__aexit__ = AsyncMock() mock_client_class.return_value = mock_client + # Use actual provider + from transcription import get_provider as real_get_provider + mock_get_provider.side_effect = real_get_provider + response = client.post( "/api/get_transcription", files={"file": ("test.wav", valid_wav_content, "audio/wav")}, @@ -289,9 +345,12 @@ def test_malformed_deepgram_response(self, mock_client_class, client, valid_wav_ assert response.status_code == 500 assert "Failed to parse transcription response" in response.json()["detail"] - @patch('httpx.AsyncClient') - def test_missing_paragraphs_transcript_is_error(self, mock_client_class, client, valid_wav_content): + @patch('transcription.deepgram.httpx.AsyncClient') + @patch('api.get_provider') + def test_missing_paragraphs_transcript_is_error(self, mock_get_provider, mock_client_class, client, valid_wav_content, monkeypatch): """Diarized-only: missing paragraphs transcript returns 500.""" + monkeypatch.setenv("DEEPGRAM_API_KEY", "test_key") + # Mock response without paragraphs transcript mock_response = Mock() mock_response.json.return_value = { @@ -306,6 +365,9 @@ def test_missing_paragraphs_transcript_is_error(self, mock_client_class, client, } } mock_response.raise_for_status = Mock() + mock_response.status_code = 200 + mock_response.text = "bad" + mock_response.headers.get.return_value = "application/json" mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) @@ -313,6 +375,10 @@ def test_missing_paragraphs_transcript_is_error(self, mock_client_class, client, mock_client.__aexit__ = AsyncMock() mock_client_class.return_value = mock_client + # Use actual provider + from transcription import get_provider as real_get_provider + mock_get_provider.side_effect = real_get_provider + response = client.post( "/api/get_transcription", files={"file": ("test.wav", valid_wav_content, "audio/wav")}, diff --git a/tests/test_transcription.py b/tests/test_transcription.py new file mode 100644 index 0000000..1a3166c --- /dev/null +++ b/tests/test_transcription.py @@ -0,0 +1,293 @@ +"""Tests for transcription providers.""" + +import pytest +import httpx +from unittest.mock import AsyncMock, MagicMock, patch +from transcription import get_provider, TranscriptionResult +from transcription.deepgram import DeepgramProvider +from transcription.voxtral import VoxtralProvider + + +class TestGetProvider: + """Tests for provider factory.""" + + def test_get_provider_deepgram_default(self, monkeypatch): + """Test default provider is Deepgram.""" + monkeypatch.setenv("DEEPGRAM_API_KEY", "test_key") + monkeypatch.delenv("TRANSCRIPTION_PROVIDER", raising=False) + provider = get_provider() + assert isinstance(provider, DeepgramProvider) + + def test_get_provider_deepgram_explicit(self, monkeypatch): + """Test explicit Deepgram provider.""" + monkeypatch.setenv("DEEPGRAM_API_KEY", "test_key") + provider = get_provider("deepgram") + assert isinstance(provider, DeepgramProvider) + + def test_get_provider_voxtral(self, monkeypatch): + """Test VoxTral provider.""" + monkeypatch.setenv("MISTRAL_API_KEY", "test_key") + provider = get_provider("voxtral") + assert isinstance(provider, VoxtralProvider) + + def test_get_provider_from_env(self, monkeypatch): + """Test provider selection from env var.""" + monkeypatch.setenv("TRANSCRIPTION_PROVIDER", "voxtral") + monkeypatch.setenv("MISTRAL_API_KEY", "test_key") + provider = get_provider() + assert isinstance(provider, VoxtralProvider) + + def test_get_provider_missing_api_key(self, monkeypatch): + """Test error when API key is missing.""" + monkeypatch.delenv("DEEPGRAM_API_KEY", raising=False) + with pytest.raises(ValueError, match="DEEPGRAM_API_KEY is required"): + get_provider("deepgram") + + def test_get_provider_unknown(self, monkeypatch): + """Test error with unknown provider.""" + with pytest.raises(ValueError, match="Unknown transcription provider"): + get_provider("unknown") + + +class TestDeepgramProvider: + """Tests for Deepgram provider.""" + + @pytest.mark.asyncio + async def test_transcribe_success(self, monkeypatch): + """Test successful Deepgram transcription.""" + monkeypatch.setenv("DEEPGRAM_API_KEY", "test_key") + + # Mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "mock response" + mock_response.json.return_value = { + "results": { + "channels": [{ + "alternatives": [{ + "paragraphs": { + "transcript": "Test transcription" + } + }], + "detected_language": "en" + }] + } + } + mock_response.headers.get.return_value = "application/json" + + # Mock httpx client + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.post.return_value = mock_response + + with patch("transcription.deepgram.httpx.AsyncClient", return_value=mock_client): + provider = DeepgramProvider() + result = await provider.transcribe( + audio_bytes=b"fake audio", + content_type="audio/wav", + params={} + ) + + assert isinstance(result, TranscriptionResult) + assert result.raw_transcription == "Test transcription" + assert result.detected_language == "en" + + @pytest.mark.asyncio + async def test_transcribe_paragraphs_format(self, monkeypatch): + """Test Deepgram transcription with paragraphs at top level.""" + monkeypatch.setenv("DEEPGRAM_API_KEY", "test_key") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "mock" + mock_response.json.return_value = { + "results": { + "paragraphs": { + "transcript": "Top level transcript" + } + } + } + mock_response.headers.get.return_value = "application/json" + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.post.return_value = mock_response + + with patch("transcription.deepgram.httpx.AsyncClient", return_value=mock_client): + provider = DeepgramProvider() + result = await provider.transcribe( + audio_bytes=b"fake audio", + content_type="audio/wav", + params={} + ) + + assert result.raw_transcription == "Top level transcript" + + @pytest.mark.asyncio + async def test_transcribe_missing_transcript(self, monkeypatch): + """Test error when transcript is missing from response.""" + monkeypatch.setenv("DEEPGRAM_API_KEY", "test_key") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "invalid response" + mock_response.json.return_value = {"results": {}} + mock_response.headers.get.return_value = "application/json" + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.post.return_value = mock_response + + with patch("transcription.deepgram.httpx.AsyncClient", return_value=mock_client): + provider = DeepgramProvider() + with pytest.raises(ValueError, match="Failed to parse transcription response"): + await provider.transcribe( + audio_bytes=b"fake audio", + content_type="audio/wav", + params={} + ) + + +class TestVoxtralProvider: + """Tests for VoxTral provider.""" + + @pytest.mark.asyncio + async def test_transcribe_success(self, monkeypatch): + """Test successful VoxTral transcription.""" + monkeypatch.setenv("MISTRAL_API_KEY", "test_key") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "mock response" + mock_response.json.return_value = { + "text": "VoxTral transcription", + "language": "en", + "model": "voxtral-mini-latest" + } + mock_response.headers.get.return_value = "application/json" + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.post.return_value = mock_response + + with patch("transcription.voxtral.httpx.AsyncClient", return_value=mock_client): + provider = VoxtralProvider() + result = await provider.transcribe( + audio_bytes=b"fake audio", + content_type="audio/wav", + params={} + ) + + assert isinstance(result, TranscriptionResult) + assert result.raw_transcription == "VoxTral transcription" + assert result.detected_language == "en" + + @pytest.mark.asyncio + async def test_transcribe_with_diarization(self, monkeypatch): + """Test VoxTral transcription with speaker diarization.""" + monkeypatch.setenv("MISTRAL_API_KEY", "test_key") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "mock" + mock_response.json.return_value = { + "text": "ignored", + "language": "en", + "segments": [ + {"speaker_id": "speaker_1", "text": "Hello", "start": 0.0, "end": 1.0}, + {"speaker_id": "speaker_1", "text": "world", "start": 1.0, "end": 2.0}, + {"speaker_id": "speaker_2", "text": "Hi there", "start": 2.0, "end": 3.0}, + ] + } + mock_response.headers.get.return_value = "application/json" + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.post.return_value = mock_response + + with patch("transcription.voxtral.httpx.AsyncClient", return_value=mock_client): + provider = VoxtralProvider() + result = await provider.transcribe( + audio_bytes=b"fake audio", + content_type="audio/wav", + params={"diarize": "true"} + ) + + # Should format with speaker labels + assert "speaker_1:" in result.raw_transcription + assert "speaker_2:" in result.raw_transcription + assert "Hello" in result.raw_transcription + assert "Hi there" in result.raw_transcription + + @pytest.mark.asyncio + async def test_transcribe_diarization_enabled_by_default(self, monkeypatch): + """Test that VoxTral enables diarization by default (no params passed).""" + monkeypatch.setenv("MISTRAL_API_KEY", "test_key") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "mock" + mock_response.json.return_value = { + "text": "ignored", + "language": "en", + "segments": [ + {"speaker_id": "speaker_1", "text": "First speaker says hello", "start": 0.0, "end": 2.0}, + {"speaker_id": "speaker_2", "text": "Second speaker responds", "start": 2.5, "end": 4.0}, + ] + } + mock_response.headers.get.return_value = "application/json" + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.post.return_value = mock_response + + with patch("transcription.voxtral.httpx.AsyncClient", return_value=mock_client) as mock_http: + provider = VoxtralProvider() + result = await provider.transcribe( + audio_bytes=b"fake audio", + content_type="audio/wav", + params={} # No params - diarization should be enabled by default + ) + + # Verify diarization was requested + call_args = mock_http.return_value.__aenter__.return_value.post.call_args + assert call_args[1]["data"]["diarize"] is True + + # Should format with speaker labels by default + assert "speaker_1:" in result.raw_transcription + assert "speaker_2:" in result.raw_transcription + assert "First speaker says hello" in result.raw_transcription + assert "Second speaker responds" in result.raw_transcription + + @pytest.mark.asyncio + async def test_transcribe_empty_response(self, monkeypatch): + """Test that VoxTral handles empty transcription (silence) gracefully.""" + monkeypatch.setenv("MISTRAL_API_KEY", "test_key") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "empty" + mock_response.json.return_value = {"text": "", "language": "en"} + mock_response.headers.get.return_value = "application/json" + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.post.return_value = mock_response + + with patch("transcription.voxtral.httpx.AsyncClient", return_value=mock_client): + provider = VoxtralProvider() + # Empty transcription is valid (no speech detected) + result = await provider.transcribe( + audio_bytes=b"fake audio", + content_type="audio/wav", + params={} + ) + assert result.raw_transcription == "" + assert result.detected_language == "en" diff --git a/transcription/__init__.py b/transcription/__init__.py new file mode 100644 index 0000000..2ae0bda --- /dev/null +++ b/transcription/__init__.py @@ -0,0 +1,41 @@ +"""Transcription provider abstraction.""" + +import logging +import os +from .base import TranscriptionProvider, TranscriptionResult +from .deepgram import DeepgramProvider +from .voxtral import VoxtralProvider + +logger = logging.getLogger(__name__) + +__all__ = ["TranscriptionProvider", "TranscriptionResult", "get_provider"] + + +def get_provider(name: str | None = None) -> TranscriptionProvider: + """ + Get a transcription provider instance. + + Args: + name: Provider name ("deepgram" or "voxtral"). If None, uses TRANSCRIPTION_PROVIDER env var. + + Returns: + TranscriptionProvider instance + + Raises: + ValueError: If provider name is unknown or required API key is missing + """ + if name is None: + name = os.getenv("TRANSCRIPTION_PROVIDER", "deepgram").strip().lower() + + if name == "deepgram": + api_key = os.getenv("DEEPGRAM_API_KEY", "").strip() + if not api_key: + raise ValueError("DEEPGRAM_API_KEY is required for Deepgram provider") + return DeepgramProvider(api_key=api_key) + elif name == "voxtral": + api_key = os.getenv("MISTRAL_API_KEY", "").strip() + if not api_key: + raise ValueError("MISTRAL_API_KEY is required for VoxTral provider") + return VoxtralProvider(api_key=api_key) + else: + raise ValueError(f"Unknown transcription provider: {name}. Valid options: deepgram, voxtral") diff --git a/transcription/base.py b/transcription/base.py new file mode 100644 index 0000000..7541a7e --- /dev/null +++ b/transcription/base.py @@ -0,0 +1,35 @@ +"""Base class for transcription providers.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass + + +@dataclass +class TranscriptionResult: + """Result from a transcription provider.""" + raw_transcription: str + detected_language: str | None = None + + +class TranscriptionProvider(ABC): + """Base class for transcription providers.""" + + @abstractmethod + async def transcribe( + self, + audio_bytes: bytes, + content_type: str, + params: dict[str, str], + ) -> TranscriptionResult: + """ + Transcribe audio bytes. + + Args: + audio_bytes: Raw audio data + content_type: MIME type (e.g., "audio/wav") + params: Provider-specific parameters + + Returns: + TranscriptionResult with raw_transcription and detected_language + """ + pass diff --git a/transcription/deepgram.py b/transcription/deepgram.py new file mode 100644 index 0000000..b08c443 --- /dev/null +++ b/transcription/deepgram.py @@ -0,0 +1,155 @@ +"""Deepgram transcription provider.""" + +import httpx +import logging +import os +from .base import TranscriptionProvider, TranscriptionResult + +logger = logging.getLogger(__name__) + + +class DeepgramProvider(TranscriptionProvider): + """Deepgram transcription provider using REST API.""" + + def __init__(self, api_key: str | None = None): + self.api_key = api_key or os.getenv("DEEPGRAM_API_KEY", "") + if not self.api_key: + raise ValueError("DEEPGRAM_API_KEY is required for Deepgram provider") + + async def transcribe( + self, + audio_bytes: bytes, + content_type: str, + params: dict[str, str], + ) -> TranscriptionResult: + """Transcribe audio using Deepgram REST API.""" + + # Valid Deepgram REST API parameters for /v1/listen endpoint + deepgram_params = { + "callback": "", + "callback_method": "", + "custom_topic": "", + "custom_topic_mode": "", + "custom_intent": "", + "custom_intent_mode": "", + "detect_entities": "", + "detect_language": "true", + "diarize": "", + "dictation": "", + "encoding": "", + "extra": "", + "filler_words": "", + "intents": "", + "keyterm": "", + "keywords": "", + "language": "", + "measurements": "", + "mip_opt_out": "", + "model": "nova-3", + "multichannel": "", + "numerals": "true", + "paragraphs": "true", + "profanity_filter": "", + "punctuate": "true", + "redact": "", + "replace": "", + "search": "", + "sentiment": "false", + "smart_format": "true", + "summarize": "", + "tag": "", + "topics": "", + "utterances": "", + "utt_split": "", + "version": "", + } + + headers = { + "Authorization": f"Token {self.api_key}", + "Content-Type": content_type + } + + # Build request params from defaults + user overrides + request_params = {} + for k, v in deepgram_params.items(): + if k in params and params[k].strip(): + request_params[k] = params[k] + elif v.strip(): + request_params[k] = v + + # Get timeout from env var + timeout_seconds = self._get_timeout_seconds() + timeout = httpx.Timeout( + connect=10.0, + read=timeout_seconds, + write=timeout_seconds, + pool=10.0, + ) + + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post( + "https://api.deepgram.com/v1/listen", + headers=headers, + params=request_params, + content=audio_bytes, + ) + + # Debug logging + try: + logger.debug( + "Deepgram response: status=%s content_type=%s body_preview=%s", + response.status_code, + response.headers.get("Content-Type"), + (response.text[:500] if response.text else ""), + ) + except Exception: + logger.debug("Failed to log Deepgram response preview") + + response.raise_for_status() + + result = response.json() + detected_language = None + + # Parse transcription from response + try: + if "paragraphs" in result["results"] and "transcript" in result["results"]["paragraphs"]: + raw_transcription = result["results"]["paragraphs"]["transcript"].strip() + elif ( + "channels" in result["results"] + and result["results"]["channels"] + and "alternatives" in result["results"]["channels"][0] + and result["results"]["channels"][0]["alternatives"] + and "paragraphs" in result["results"]["channels"][0]["alternatives"][0] + and "transcript" in result["results"]["channels"][0]["alternatives"][0]["paragraphs"] + ): + raw_transcription = ( + result["results"]["channels"][0]["alternatives"][0]["paragraphs"]["transcript"].strip() + ) + else: + logger.debug("failed to get paragraphs transcript") + logger.debug(result) + raise KeyError("paragraphs transcript not found") + + if "channels" in result["results"] and "detected_language" in result["results"]["channels"][0]: + detected_language = result["results"]["channels"][0]["detected_language"] + else: + logger.debug("failed to get detected_language") + logger.debug(result) + + except (KeyError, IndexError) as e: + logger.error("Failed to parse Deepgram transcription response: %s", response.text) + raise ValueError(f"Failed to parse transcription response: {e}") + + return TranscriptionResult( + raw_transcription=raw_transcription, + detected_language=detected_language + ) + + def _get_timeout_seconds(self) -> float: + """Get timeout from environment variable.""" + raw = os.getenv("DEEPGRAM_TIMEOUT_SECONDS", "300").strip() + try: + return float(raw) + except ValueError: + logger.warning("Invalid DEEPGRAM_TIMEOUT_SECONDS=%r; defaulting to 300", raw) + return 300.0 diff --git a/transcription/voxtral.py b/transcription/voxtral.py new file mode 100644 index 0000000..2477ce3 --- /dev/null +++ b/transcription/voxtral.py @@ -0,0 +1,164 @@ +"""VoxTral (Mistral) transcription provider.""" + +import httpx +import logging +import os +from .base import TranscriptionProvider, TranscriptionResult + +logger = logging.getLogger(__name__) + + +class VoxtralProvider(TranscriptionProvider): + """VoxTral (Mistral) transcription provider using REST API.""" + + def __init__(self, api_key: str | None = None): + self.api_key = api_key or os.getenv("MISTRAL_API_KEY", "") + if not self.api_key: + raise ValueError("MISTRAL_API_KEY is required for VoxTral provider") + + async def transcribe( + self, + audio_bytes: bytes, + content_type: str, + params: dict[str, str], + ) -> TranscriptionResult: + """Transcribe audio using Mistral VoxTral REST API.""" + + # Build multipart form data + files = { + "file": ("audio.wav", audio_bytes, content_type), + } + + # VoxTral parameters + data = { + "model": params.get("model", "voxtral-mini-latest"), + } + + # Optional parameters + if "language" in params and params["language"].strip(): + data["language"] = params["language"] + + # Enable diarization by default (for speaker labels), unless explicitly disabled + diarize_disabled = "diarize" in params and params["diarize"].strip().lower() in ("false", "0", "no") + if not diarize_disabled: + data["diarize"] = True # Boolean, not string + # VoxTral requires timestamp_granularities when diarize is enabled + if "timestamp_granularities" not in params or not params.get("timestamp_granularities", "").strip(): + data["timestamp_granularities"] = ["segment"] + + if "temperature" in params and params["temperature"].strip(): + try: + data["temperature"] = float(params["temperature"]) + except ValueError: + pass # Skip invalid temperature values + + # Context biasing (up to 100 words/phrases) + if "context_bias" in params and params["context_bias"].strip(): + # Split comma-separated list if provided + context_items = [item.strip() for item in params["context_bias"].split(",") if item.strip()] + if context_items: + # VoxTral expects multiple "context_bias" fields in the form data + for item in context_items[:100]: # limit to 100 + data.setdefault("context_bias", []) + if isinstance(data["context_bias"], list): + data["context_bias"].append(item) + + # Timestamp granularities (user-provided or set by diarize logic above) + if "timestamp_granularities" in params and params["timestamp_granularities"].strip(): + granularities = [g.strip() for g in params["timestamp_granularities"].split(",") if g.strip()] + valid_granularities = [g for g in granularities if g in ("segment", "word")] + if valid_granularities: + data["timestamp_granularities"] = valid_granularities + + headers = { + "Authorization": f"Bearer {self.api_key}", + } + + # Get timeout from env var + timeout_seconds = self._get_timeout_seconds() + timeout = httpx.Timeout( + connect=10.0, + read=timeout_seconds, + write=timeout_seconds, + pool=10.0, + ) + + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post( + "https://api.mistral.ai/v1/audio/transcriptions", + headers=headers, + files=files, + data=data, + ) + + # Debug logging + try: + logger.debug( + "VoxTral response: status=%s content_type=%s body_preview=%s", + response.status_code, + response.headers.get("Content-Type"), + (response.text[:500] if response.text else ""), + ) + except Exception: + logger.debug("Failed to log VoxTral response preview") + + response.raise_for_status() + + result = response.json() + + # Parse VoxTral response + # Response format: { "text": "...", "language": "...", "segments": [...], "model": "..." } + raw_transcription = result.get("text", "").strip() + detected_language = result.get("language") + + # If diarization is enabled and we have segments with speaker info, + # reconstruct a speaker-labeled transcript + segments = result.get("segments", []) + if segments and any("speaker_id" in seg or "speaker" in seg for seg in segments): + raw_transcription = self._format_diarized_transcript(segments) + + if not raw_transcription: + # Empty transcription is valid for silence/no speech + logger.debug("VoxTral returned empty transcription (no speech detected)") + + return TranscriptionResult( + raw_transcription=raw_transcription or "", # Return empty string instead of raising + detected_language=detected_language + ) + + def _format_diarized_transcript(self, segments: list[dict]) -> str: + """Format segments with speaker diarization into a readable transcript.""" + lines = [] + last_speaker = None + + for seg in segments: + # VoXtral uses "speaker_id" field (e.g., "speaker_1", "speaker_2") + # Fall back to "speaker" for backward compatibility with test mocks + speaker = seg.get("speaker_id") or seg.get("speaker") + text = seg.get("text", "").strip() + + if not text: + continue + + # Add speaker label when speaker changes + if speaker is not None and speaker != last_speaker: + # Format as "Speaker N:" to match common convention + lines.append(f"\n{speaker}: {text}") + last_speaker = speaker + else: + # Continue current speaker's text + if lines: + lines.append(text) + else: + lines.append(text) + + return "\n".join(lines).strip() + + def _get_timeout_seconds(self) -> float: + """Get timeout from environment variable.""" + raw = os.getenv("VOXTRAL_TIMEOUT_SECONDS", os.getenv("DEEPGRAM_TIMEOUT_SECONDS", "300")).strip() + try: + return float(raw) + except ValueError: + logger.warning("Invalid timeout value=%r; defaulting to 300", raw) + return 300.0