Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 80 additions & 9 deletions livekit-agents/livekit/agents/inference/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from livekit import rtc

from .. import stt, utils
from .. import stt, utils, vad
from .._exceptions import (
APIConnectionError,
APIStatusError,
Expand Down Expand Up @@ -204,6 +204,31 @@ def _parse_model_string(model: str) -> tuple[str, NotGivenOr[LanguageCode]]:
return model, language


def _resolve_vad_for_model(
model: NotGivenOr[STTModels | str],
vad_instance: vad.VAD | None,
) -> vad.VAD | None:
is_speechmatics = (
is_given(model) and isinstance(model, str) and model.startswith("speechmatics/")
)
if vad_instance is not None and not is_speechmatics:
logger.warning(
"`vad` will be ignored: model %r handles endpointing server-side.",
model,
)
return None
if is_speechmatics and vad_instance is None:
try:
from livekit.plugins.silero import VAD as SileroVAD
except ImportError as e:
raise ImportError(
"livekit-plugins-silero is required: model "
f"{model!r} does not handle endpointing server-side."
) from e
vad_instance = SileroVAD.load()
return vad_instance


Comment on lines +207 to +231
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case where AgentSession has VAD wouldn't this mean we have 2 VAD instances?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, to use just 1 would require the user to store it and pass the same instance

maybe it would be helpful for the user to have separate settings for stt and session level vad, but as of right now the session vad can't be connected to stt

def _normalize_fallback(
fallback: list[FallbackModelType] | FallbackModelType,
) -> list[FallbackModel]:
Expand Down Expand Up @@ -368,6 +393,7 @@ def __init__(
extra_kwargs: NotGivenOr[SpeechmaticsOptions] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
vad: NotGivenOr[vad.VAD | None] = NOT_GIVEN,
) -> None: ...

@overload
Expand Down Expand Up @@ -410,6 +436,7 @@ def __init__(
] = NOT_GIVEN,
fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN,
conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN,
vad: NotGivenOr[vad.VAD | None] = NOT_GIVEN,
) -> None:
"""Livekit Cloud Inference STT

Expand All @@ -426,6 +453,9 @@ def __init__(
fallback (FallbackModelType, optional): Fallback models - either a list of model names,
a list of FallbackModel instances.
conn_options (APIConnectOptions, optional): Connection options for request attempts.
vad (VAD, optional): External Voice Activity Detector. When provided, each audio
frame is forwarded to the VAD and `session.finalize` is sent to the inference
gateway on end of speech. Only applicable to Speechmatics models.
"""
# Infer diarization capability from provider-specific extra_kwargs
# keys (see _DIARIZATION_EXTRA_KEYS). xAI uses "diarize" (same as
Expand All @@ -434,6 +464,15 @@ def __init__(
dict(extra_kwargs) if is_given(extra_kwargs) else None
)

# Parse language from model string if provided: "provider/model:language"
if is_given(model) and isinstance(model, str):
parsed_model, parsed_language = _parse_model_string(model)
model = parsed_model
if is_given(parsed_language) and not is_given(language):
language = parsed_language

vad = _resolve_vad_for_model(model, vad if is_given(vad) else None)

super().__init__(
capabilities=stt.STTCapabilities(
streaming=True,
Expand All @@ -444,13 +483,6 @@ def __init__(
),
)

# Parse language from model string if provided: "provider/model:language"
if is_given(model) and isinstance(model, str):
parsed_model, parsed_language = _parse_model_string(model)
model = parsed_model
if is_given(parsed_language) and not is_given(language):
language = parsed_language

lk_base_url = base_url if is_given(base_url) else get_default_inference_url()

lk_api_key = (
Expand Down Expand Up @@ -490,6 +522,7 @@ def __init__(
)

self._session = http_session
self._vad = vad
self._streams = weakref.WeakSet[SpeechStream]()

@classmethod
Expand Down Expand Up @@ -537,7 +570,12 @@ def stream(
) -> SpeechStream:
"""Create a streaming transcription session."""
options = self._sanitize_options(language=language)
stream = SpeechStream(stt=self, opts=options, conn_options=conn_options)
stream = SpeechStream(
stt=self,
opts=options,
conn_options=conn_options,
vad_instance=self._vad,
)
self._streams.add(stream)
return stream

Expand All @@ -550,7 +588,15 @@ def update_options(
) -> None:
"""Update STT configuration options."""
if is_given(model):
# Mirror __init__: strip ":language" suffix and apply if not overridden.
if isinstance(model, str):
parsed_model, parsed_language = _parse_model_string(model)
model = parsed_model
if is_given(parsed_language) and not is_given(language):
language = parsed_language

self._opts.model = model
self._vad = _resolve_vad_for_model(model, self._vad)
if is_given(language):
self._opts.language = LanguageCode(language)
if is_given(extra):
Expand Down Expand Up @@ -583,6 +629,7 @@ def __init__(
stt: STT,
opts: STTOptions,
conn_options: APIConnectOptions,
vad_instance: vad.VAD | None = None,
) -> None:
super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
self._stt: STT = stt
Expand All @@ -592,6 +639,7 @@ def __init__(
self._speaking = False
self._speech_duration: float = 0
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._vad: vad.VAD | None = vad_instance

def update_options(
self,
Expand Down Expand Up @@ -639,6 +687,7 @@ async def _run(self) -> None:
"""Main loop for streaming transcription."""
closing_ws = False
http_session = self._stt._ensure_session()
vad_stream: vad.VADStream | None = self._vad.stream() if self._vad is not None else None

@utils.log_exceptions(logger=logger)
async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
Expand All @@ -653,6 +702,8 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
async for ev in self._input_ch:
frames: list[rtc.AudioFrame] = []
if isinstance(ev, rtc.AudioFrame):
if vad_stream is not None:
vad_stream.push_frame(ev)
frames.extend(audio_bstream.push(ev.data))
elif isinstance(ev, self._FlushSentinel):
frames.extend(audio_bstream.flush())
Expand All @@ -667,12 +718,28 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
}
await ws.send_str(json.dumps(audio_msg))

if vad_stream is not None:
vad_stream.end_input()

closing_ws = True
finalize_msg = {
"type": "session.finalize",
}
await ws.send_str(json.dumps(finalize_msg))

@utils.log_exceptions(logger=logger)
async def vad_task(ws: aiohttp.ClientWebSocketResponse, stream: vad.VADStream) -> None:
async for ev in stream:
if ev.type != vad.VADEventType.END_OF_SPEECH:
continue
if ws.closed:
return
try:
await ws.send_str(json.dumps({"type": "session.finalize"}))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one questions, what will happen if VAD fires EOS on noise, will the STT return an empty or a random transcript?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i believe the STT will return an empty one ""

except Exception:
logger.debug("failed to send session.finalize from VAD, ws may be closing")
return

@utils.log_exceptions(logger=logger)
async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
nonlocal closing_ws
Expand Down Expand Up @@ -722,6 +789,8 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
asyncio.create_task(send_task(ws)),
asyncio.create_task(recv_task(ws)),
]
if vad_stream is not None:
tasks.append(asyncio.create_task(vad_task(ws, vad_stream)))
try:
await asyncio.gather(*tasks)
finally:
Expand All @@ -730,6 +799,8 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
self._ws = None
if ws is not None:
await ws.close()
if vad_stream is not None:
await vad_stream.aclose()

async def _connect_ws(
self, http_session: aiohttp.ClientSession
Expand Down
1 change: 1 addition & 0 deletions livekit-agents/livekit/agents/voice/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def __init__(
self._vad = vad or None
self._llm = llm or None
self._tts = tts or None

self._turn_detection = raw_turn_detection
self._interruption_detection = interruption.get("mode", NOT_GIVEN)
self._mcp_servers = mcp_servers or None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
LanguageCode,
stt,
utils,
vad,
)
from livekit.agents.types import (
NOT_GIVEN,
Expand Down Expand Up @@ -147,6 +148,7 @@ def __init__(
known_speakers: NotGivenOr[list[SpeakerIdentifier]] = NOT_GIVEN,
sample_rate: int = 16000,
audio_encoding: AudioEncoding = AudioEncoding.PCM_S16LE,
vad: NotGivenOr[vad.VAD | None] = NOT_GIVEN,
**kwargs: Any,
):
"""Create a new instance of Speechmatics STT using the Voice SDK.
Expand Down Expand Up @@ -248,10 +250,44 @@ def __init__(

audio_encoding: Audio encoding format. Defaults to `AudioEncoding.PCM_S16LE`.

vad: Optional external Voice Activity Detector. When provided, the STT
engine's endpointing is replaced by the VAD: each audio frame is
forwarded to the VAD, and `finalize()` is called whenever the VAD
reports end of speech. Providing a VAD implicitly sets
`turn_detection_mode` to `EXTERNAL`. When `turn_detection_mode` is
`EXTERNAL` and `vad` is not provided, Silero is auto-loaded to drive
finalize. Pass `vad=None` to opt out of the auto-load if you intend
to call `finalize()` from your own logic. Defaults to NOT_GIVEN.

**kwargs: Catches deprecated parameters. A warning is logged for any
recognised deprecated name.
"""

# Resolve final turn_detection_mode — a real `vad` forces EXTERNAL.
if is_given(vad) and vad is not None and turn_detection_mode != TurnDetectionMode.EXTERNAL:
logger.info(
"External `vad` provided; overriding turn_detection_mode "
f"{turn_detection_mode.value!r} -> 'external'"
)
turn_detection_mode = TurnDetectionMode.EXTERNAL

# In EXTERNAL mode the STT does not endpoint on its own. Auto-load Silero
# so finalize() is wired up, unless the caller explicitly passed `vad=None`
# to opt out (they'll drive finalize() themselves).
if turn_detection_mode == TurnDetectionMode.EXTERNAL and not is_given(vad):
try:
from livekit.plugins.silero import VAD as SileroVAD
except ImportError as e:
raise ImportError(
"livekit-plugins-silero is required for Speechmatics with "
"turn_detection_mode=EXTERNAL (no server-side endpointing). "
"Pass `vad=None` to opt out and drive finalize() manually."
) from e
vad = SileroVAD.load()

# Normalize NOT_GIVEN -> None for downstream storage.
self._vad = vad if is_given(vad) else None

# Set default values for optional parameters
super().__init__(
capabilities=stt.STTCapabilities(
Expand Down Expand Up @@ -359,6 +395,7 @@ def stream(
conn_options=conn_options,
config=self._prepare_config(language),
id=len(self._streams),
vad_instance=self._vad,
)

# Add to the list of streams
Expand Down Expand Up @@ -575,6 +612,7 @@ def __init__(
conn_options: APIConnectOptions,
config: VoiceAgentConfig,
id: int,
vad_instance: vad.VAD | None = None,
) -> None:
super().__init__(
stt=stt,
Expand All @@ -589,6 +627,9 @@ def __init__(
self._msg_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
self._speech_duration: float = 0

self._vad: vad.VAD | None = vad_instance
self._vad_stream: vad.VADStream | None = None

self._tasks: list[asyncio.Task] = []

# Speaker result event
Expand Down Expand Up @@ -644,13 +685,23 @@ def add_message(message: dict[str, Any]) -> None:
await self._client.connect()
logger.debug("Connected to Speechmatics STT service")

# Open external VAD stream (if provided) before tasks start pushing frames
if self._vad is not None:
self._vad_stream = self._vad.stream()

# Audio and messaging tasks
audio_task = asyncio.create_task(self._process_audio())
message_task = asyncio.create_task(self._process_messages())

# Tasks
self._tasks = [audio_task, message_task]

# Optional VAD task: calls `client.finalize()` on end of speech
vad_task: asyncio.Task | None = None
if self._vad_stream is not None:
vad_task = asyncio.create_task(self._process_vad(self._vad_stream))
self._tasks.append(vad_task)

# Wait for tasks to complete
try:
done, pending = await asyncio.wait(self._tasks, return_when=asyncio.FIRST_COMPLETED)
Expand All @@ -666,6 +717,14 @@ def add_message(message: dict[str, Any]) -> None:
except asyncio.CancelledError:
pass

# Close the VAD stream so its task drains and exits
if self._vad_stream is not None:
await self._vad_stream.aclose()
self._vad_stream = None

if vad_task is not None:
await utils.aio.cancel_and_wait(vad_task)

# Disconnect flushes final messages from the STT engine
await self._client.disconnect()

Expand Down Expand Up @@ -695,6 +754,9 @@ async def _process_audio(self) -> None:
if isinstance(data, self._FlushSentinel):
frames = audio_bstream.flush()
else:
# Forward the original frame to the VAD before resampling/repacking
if self._vad_stream is not None:
self._vad_stream.push_frame(data)
frames = audio_bstream.write(data.data.tobytes())

# Send audio frames
Expand All @@ -703,6 +765,20 @@ async def _process_audio(self) -> None:
self._speech_duration += frame.duration
await self._client.send_audio(frame.data.tobytes())

# No more input — let the VAD flush any pending event
if self._vad_stream is not None:
self._vad_stream.end_input()

except asyncio.CancelledError:
pass

async def _process_vad(self, vad_stream: vad.VADStream) -> None:
"""Call `client.finalize()` whenever the external VAD reports end of speech."""
try:
async for ev in vad_stream:
if ev.type == vad.VADEventType.END_OF_SPEECH:
if self._client and self._client._is_connected:
self._client.finalize()
except asyncio.CancelledError:
pass

Expand Down Expand Up @@ -853,6 +929,11 @@ async def aclose(self) -> None:
except asyncio.CancelledError:
pass

# Close the VAD stream if it's still open
if self._vad_stream is not None:
await self._vad_stream.aclose()
self._vad_stream = None

# Close the client
if self._client and self._client._is_connected:
await self._client.disconnect()
Expand Down
Loading