Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/ai_api/proxy/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,16 +529,24 @@ async def handle_realtime(


def _model_from_session_update(raw: str) -> str | None:
"""Extract the requested model from the client's first frame. Tolerant of both
realtime shapes: a conversation `session.update` (model in `session.model`) and a
transcription `(transcription_)session.update` (model in
`session.input_audio_transcription.model`)."""
try:
ev = json.loads(raw)
except (ValueError, TypeError):
return None
if not isinstance(ev, dict) or ev.get("type") != "session.update":
if not isinstance(ev, dict) or ev.get("type") not in ("session.update", "transcription_session.update"):
return None
session = ev.get("session")
if not isinstance(session, dict):
return None
model = session.get("model")
if not model:
iat = session.get("input_audio_transcription")
if isinstance(iat, dict):
model = iat.get("model")
return model if isinstance(model, str) and model else None


Expand Down
115 changes: 67 additions & 48 deletions src/ai_api/proxy/upstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from __future__ import annotations

import asyncio
import base64
import contextlib
import json
import os
from typing import Any

import litellm
Expand Down Expand Up @@ -206,22 +206,48 @@ async def asearch(
)


def _build_realtime_url(api_base: str | None, model: str, api_version: str | None) -> str:
"""Build the Azure Foundry realtime WS URL from the resolved credential.
# Azure realtime TRANSCRIPTION needs a realtime-preview api-version (distinct from
# the chat credential's version) — 2025-04-01-preview is the current preview that
# carries gpt-realtime-whisper. Overridable via env for a region/version bump
# without a redeploy-of-code (set AZURE_REALTIME_API_VERSION).
_AZURE_REALTIME_API_VERSION = os.environ.get("AZURE_REALTIME_API_VERSION", "2025-04-01-preview")

Azure OpenAI realtime: wss://<resource>.openai.azure.com/openai/realtime?
api-version=<v>&deployment=<deployment>. We derive the wss scheme from the
https api_base and carry the bare model (deployment) name. Validated against a
real Azure realtime endpoint in quickstart (T027) — CI uses a fake upstream.

def _build_realtime_url(api_base: str | None, model: str, *, provider: str = "azure") -> str:
"""Build the upstream realtime *transcription* WS URL.

Azure: wss://<resource>/openai/realtime?api-version=<v>&deployment=<dep>&intent=transcription.
OpenAI: wss://api.openai.com/v1/realtime?intent=transcription (model goes in session.update).

`intent=transcription` is REQUIRED: without it Azure treats the socket as a
conversation session and rejects a transcription-only deployment with HTTP 400.
The exact Azure URL is validated by the admin "test model" WS smoke / quickstart.
"""
base = (api_base or "").rstrip("/")
if base.startswith("https://"):
base = "wss://" + base[len("https://"):]
elif base.startswith("http://"):
base = "ws://" + base[len("http://"):]
deployment = model.split("/", 1)[-1]
version = api_version or "2024-10-01-preview"
return f"{base}/openai/realtime?api-version={version}&deployment={deployment}"
if provider == "openai":
return f"{base or 'wss://api.openai.com'}/v1/realtime?intent=transcription"
# Azure transcription: intent=transcription and NO deployment= — verified against
# the live resource. With deployment= Azure routes to a *conversation* realtime
# session, which a transcription model can't do → HTTP 400 "OperationNotSupported".
# The model is selected by the client's session.update (input_audio_transcription.model).
return f"{base}/openai/realtime?api-version={_AZURE_REALTIME_API_VERSION}&intent=transcription"


def _realtime_reject_detail(exc: Exception) -> str | None:
"""Surface an upstream WS-handshake rejection (status + body) so the admin test /
relay reports Azure's actual complaint (e.g. unsupported api-version, deployment
not found) instead of a bare 'HTTP 400'. Returns None if exc isn't a rejection."""
resp = getattr(exc, "response", None) # websockets>=14 InvalidStatus
status = getattr(resp, "status_code", None) or getattr(exc, "status_code", None)
if status is None:
return None
body = getattr(resp, "body", b"") if resp is not None else b""
text = body.decode("utf-8", "replace").strip()[:400] if body else ""
return f"upstream realtime handshake rejected: HTTP {status}{(' — ' + text) if text else ''}"


async def open_realtime_ws(
Expand All @@ -230,22 +256,29 @@ async def open_realtime_ws(
model: str,
api_key: str,
api_base: str | None = None,
api_version: str | None = None,
api_version: str | None = None, # chat-tuned; NOT used for the realtime URL
) -> Any:
"""Open a WebSocket to the upstream provider's realtime endpoint and return the
connection (has async `send`/`recv`/`close`). Injects the credential as the
`api-key` header (Azure) — the key/endpoint never reach the downstream client
(FR-006). Phase 32 (043): /v1/realtime live transcription relay.
(FR-006). On a handshake rejection, raises a RuntimeError carrying the upstream
status + body for diagnosis. Phase 32 (043): /v1/realtime live transcription relay.
"""
import websockets

url = _build_realtime_url(api_base, model, api_version)
url = _build_realtime_url(api_base, model, provider=provider)
# Azure uses the `api-key` header; OpenAI-style uses Authorization: Bearer.
if provider == "openai":
headers = {"Authorization": f"Bearer {api_key}"}
else:
headers = {"api-key": api_key}
return await websockets.connect(url, additional_headers=headers)
try:
return await websockets.connect(url, additional_headers=headers)
except Exception as e:
detail = _realtime_reject_detail(e)
if detail is not None:
raise RuntimeError(detail) from e
raise


async def realtime_smoke(
Expand All @@ -258,47 +291,33 @@ async def realtime_smoke(
) -> dict[str, Any]:
"""Phase 32 (043): minimal realtime WS smoke for the admin "test model" button.

Opens the upstream realtime WS, runs the session handshake + a tiny silent-audio
append, and waits for the first server event. A structured non-error event proves
egress (wss:443) + key + deployment + protocol are all good — i.e. the T027
protocol-reachability check, now runnable straight from the UI. Raises on any
`error` event, connect failure, or timeout, so the test honestly reports failure.
Billable: only a couple seconds of audio.
Opens the upstream realtime transcription WS and awaits the first server event.
Azure emits `transcription_session.created` immediately on connect (no send
needed), so a structured non-error first event proves egress (wss:443) + key +
api-version + intent + the realtime-transcription capability are all good — i.e.
the T027 reachability check, runnable straight from the UI. Raises on an `error`
event, handshake rejection (status+body surfaced), or timeout. Billable: a hair.
"""
provider = model.split("/", 1)[0] if "/" in model else "azure"
deployment = model.split("/", 1)[-1]
ws = await open_realtime_ws(
provider=provider, model=model, api_key=api_key,
api_base=api_base, api_version=api_version,
)
try:
await ws.send(json.dumps({
"type": "session.update",
"session": {
"type": "transcription", "model": deployment,
"audio": {"input": {"format": {"type": "audio/pcm", "rate": 16000}}},
},
}))
pcm = b"\x00\x00" * int(16000 * 0.2) # 0.2s silence, pcm16 mono 16 kHz
await ws.send(json.dumps({
"type": "input_audio_buffer.append",
"audio": base64.b64encode(pcm).decode(),
}))
try:
async with asyncio.timeout(timeout):
while True:
raw = await ws.recv()
ev = json.loads(raw) if isinstance(raw, str) else {}
etype = ev.get("type")
if etype == "error":
msg = (ev.get("error") or {}).get("message") or "(no message)"
raise RuntimeError(f"realtime upstream error: {msg}")
# Any structured server event ⇒ the handshake/protocol works.
return {"ok": True, "first_event": etype}
except TimeoutError as e:
raise RuntimeError(
f"realtime smoke timed out after {timeout}s with no server event"
) from e
async with asyncio.timeout(timeout):
while True:
raw = await ws.recv()
ev = json.loads(raw) if isinstance(raw, str) else {}
etype = ev.get("type")
if etype == "error":
msg = (ev.get("error") or {}).get("message") or "(no message)"
raise RuntimeError(f"realtime upstream error: {msg}")
# e.g. transcription_session.created ⇒ protocol/auth/capability OK.
return {"ok": True, "first_event": etype}
except TimeoutError as e:
raise RuntimeError(
f"realtime smoke timed out after {timeout}s with no server event"
) from e
finally:
with contextlib.suppress(Exception):
await ws.close()
Expand Down
29 changes: 25 additions & 4 deletions tests/unit/test_upstream_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,28 @@ async def test_aocr_leaves_non_azure_provider_untouched():
assert m.call_args.kwargs["model"] == "mistral/mistral-ocr-latest"


# --- Phase 32 (043): realtime WS smoke (admin "test model" recipe) -----------
# --- Phase 32 (043): realtime WS URL + smoke (admin "test model" recipe) -----
def test_build_realtime_url_azure_has_intent_and_apiversion():
    """The Azure URL is wss://, carries intent + api-version, and omits deployment=."""
    from ai_api.proxy.upstream import _AZURE_REALTIME_API_VERSION, _build_realtime_url

    url = _build_realtime_url("https://my-foundry.openai.azure.com", "azure/gpt-realtime-whisper")
    assert url.startswith("wss://my-foundry.openai.azure.com/openai/realtime?")
    assert "intent=transcription" in url  # REQUIRED or Azure -> HTTP 400
    # NO deployment= : with it, Azure routes to a conversation session the
    # transcription model can't do. The model comes via session.update.
    assert "deployment=" not in url
    assert f"api-version={_AZURE_REALTIME_API_VERSION}" in url


def test_build_realtime_url_openai_form():
    """With no api_base, the OpenAI form targets api.openai.com and carries only the intent."""
    from ai_api.proxy.upstream import _build_realtime_url

    url = _build_realtime_url(None, "gpt-realtime-whisper", provider="openai")
    # OpenAI: model goes in session.update, not the URL; just the intent.
    assert url == "wss://api.openai.com/v1/realtime?intent=transcription"


# --- realtime WS smoke (admin "test model" recipe) ---------------------------
class _FakeSmokeWS:
"""A scripted upstream realtime WS for the smoke test (sent frames + recv queue)."""

Expand Down Expand Up @@ -92,10 +113,10 @@ async def test_realtime_smoke_ok_on_first_server_event():
api_base="https://x", api_version="2024-10-01-preview",
)
assert out["ok"] is True and out["first_event"] == "transcription_session.created"
# provider derived from the slug prefix; handshake + audio append were sent.
# provider derived from the slug prefix; smoke just awaits the auto-created
# session event (Azure emits it on connect — no client send needed).
assert opener.call_args.kwargs["provider"] == "azure"
assert any("session.update" in s for s in ws.sent)
assert any("input_audio_buffer.append" in s for s in ws.sent)
assert ws.sent == []
assert ws.closed is True # always closes the upstream WS


Expand Down