-
Notifications
You must be signed in to change notification settings - Fork 3.1k
(speechmatics + inference): add VAD #5750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5b882e2
621b84e
2ce48d6
06fd834
f1ea135
03621a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,7 @@ | |
|
|
||
| from livekit import rtc | ||
|
|
||
| from .. import stt, utils | ||
| from .. import stt, utils, vad | ||
| from .._exceptions import ( | ||
| APIConnectionError, | ||
| APIStatusError, | ||
|
|
@@ -204,6 +204,31 @@ def _parse_model_string(model: str) -> tuple[str, NotGivenOr[LanguageCode]]: | |
| return model, language | ||
|
|
||
|
|
||
def _resolve_vad_for_model(
    model: NotGivenOr[STTModels | str],
    vad_instance: vad.VAD | None,
) -> vad.VAD | None:
    """Pick the VAD instance (if any) to pair with ``model``.

    A model is treated as Speechmatics when it is a given string starting
    with ``"speechmatics/"``.

    Args:
        model: Model identifier, or ``NOT_GIVEN``.
        vad_instance: Caller-supplied VAD, or ``None`` when none was supplied.

    Returns:
        ``None`` for any non-Speechmatics model (a warning is logged when a
        VAD was supplied, since it is dropped). For a Speechmatics model, the
        supplied VAD, or a freshly loaded Silero VAD when none was supplied.

    Raises:
        ImportError: A Speechmatics model was requested without a VAD and
            ``livekit.plugins.silero`` cannot be imported.
    """
    needs_client_vad = (
        is_given(model) and isinstance(model, str) and model.startswith("speechmatics/")
    )

    if not needs_client_vad:
        # Any supplied VAD is discarded for these models; tell the caller why.
        if vad_instance is not None:
            logger.warning(
                "`vad` will be ignored: model %r handles endpointing server-side.",
                model,
            )
        return None

    if vad_instance is not None:
        return vad_instance

    # Imported lazily so the silero plugin is only needed on this path.
    try:
        from livekit.plugins.silero import VAD as SileroVAD
    except ImportError as exc:
        raise ImportError(
            "livekit-plugins-silero is required: model "
            f"{model!r} does not handle endpointing server-side."
        ) from exc
    return SileroVAD.load()
|
|
||
|
|
||
| def _normalize_fallback( | ||
| fallback: list[FallbackModelType] | FallbackModelType, | ||
| ) -> list[FallbackModel]: | ||
|
|
@@ -368,6 +393,7 @@ def __init__( | |
| extra_kwargs: NotGivenOr[SpeechmaticsOptions] = NOT_GIVEN, | ||
| fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, | ||
| conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, | ||
| vad: NotGivenOr[vad.VAD | None] = NOT_GIVEN, | ||
| ) -> None: ... | ||
|
|
||
| @overload | ||
|
|
@@ -410,6 +436,7 @@ def __init__( | |
| ] = NOT_GIVEN, | ||
| fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, | ||
| conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, | ||
| vad: NotGivenOr[vad.VAD | None] = NOT_GIVEN, | ||
| ) -> None: | ||
| """Livekit Cloud Inference STT | ||
|
|
||
|
|
@@ -426,6 +453,9 @@ def __init__( | |
| fallback (FallbackModelType, optional): Fallback models - either a list of model names, | ||
| a list of FallbackModel instances. | ||
| conn_options (APIConnectOptions, optional): Connection options for request attempts. | ||
| vad (VAD, optional): External Voice Activity Detector. When provided, each audio | ||
| frame is forwarded to the VAD and `session.finalize` is sent to the inference | ||
| gateway on end of speech. Only applicable to Speechmatics models. | ||
| """ | ||
| # Infer diarization capability from provider-specific extra_kwargs | ||
| # keys (see _DIARIZATION_EXTRA_KEYS). xAI uses "diarize" (same as | ||
|
|
@@ -434,6 +464,15 @@ def __init__( | |
| dict(extra_kwargs) if is_given(extra_kwargs) else None | ||
| ) | ||
|
|
||
| # Parse language from model string if provided: "provider/model:language" | ||
| if is_given(model) and isinstance(model, str): | ||
| parsed_model, parsed_language = _parse_model_string(model) | ||
| model = parsed_model | ||
| if is_given(parsed_language) and not is_given(language): | ||
| language = parsed_language | ||
|
|
||
| vad = _resolve_vad_for_model(model, vad if is_given(vad) else None) | ||
|
|
||
| super().__init__( | ||
| capabilities=stt.STTCapabilities( | ||
| streaming=True, | ||
|
|
@@ -444,13 +483,6 @@ def __init__( | |
| ), | ||
| ) | ||
|
|
||
| # Parse language from model string if provided: "provider/model:language" | ||
| if is_given(model) and isinstance(model, str): | ||
| parsed_model, parsed_language = _parse_model_string(model) | ||
| model = parsed_model | ||
| if is_given(parsed_language) and not is_given(language): | ||
| language = parsed_language | ||
|
|
||
| lk_base_url = base_url if is_given(base_url) else get_default_inference_url() | ||
|
|
||
| lk_api_key = ( | ||
|
|
@@ -490,6 +522,7 @@ def __init__( | |
| ) | ||
|
|
||
| self._session = http_session | ||
| self._vad = vad | ||
| self._streams = weakref.WeakSet[SpeechStream]() | ||
|
|
||
| @classmethod | ||
|
|
@@ -537,7 +570,12 @@ def stream( | |
| ) -> SpeechStream: | ||
| """Create a streaming transcription session.""" | ||
| options = self._sanitize_options(language=language) | ||
| stream = SpeechStream(stt=self, opts=options, conn_options=conn_options) | ||
| stream = SpeechStream( | ||
| stt=self, | ||
| opts=options, | ||
| conn_options=conn_options, | ||
| vad_instance=self._vad, | ||
| ) | ||
| self._streams.add(stream) | ||
| return stream | ||
|
|
||
|
|
@@ -550,7 +588,15 @@ def update_options( | |
| ) -> None: | ||
| """Update STT configuration options.""" | ||
| if is_given(model): | ||
| # Mirror __init__: strip ":language" suffix and apply if not overridden. | ||
| if isinstance(model, str): | ||
| parsed_model, parsed_language = _parse_model_string(model) | ||
| model = parsed_model | ||
| if is_given(parsed_language) and not is_given(language): | ||
| language = parsed_language | ||
|
|
||
| self._opts.model = model | ||
| self._vad = _resolve_vad_for_model(model, self._vad) | ||
| if is_given(language): | ||
| self._opts.language = LanguageCode(language) | ||
| if is_given(extra): | ||
|
|
@@ -583,6 +629,7 @@ def __init__( | |
| stt: STT, | ||
| opts: STTOptions, | ||
| conn_options: APIConnectOptions, | ||
| vad_instance: vad.VAD | None = None, | ||
| ) -> None: | ||
| super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate) | ||
| self._stt: STT = stt | ||
|
|
@@ -592,6 +639,7 @@ def __init__( | |
| self._speaking = False | ||
| self._speech_duration: float = 0 | ||
| self._ws: aiohttp.ClientWebSocketResponse | None = None | ||
| self._vad: vad.VAD | None = vad_instance | ||
|
|
||
| def update_options( | ||
| self, | ||
|
|
@@ -639,6 +687,7 @@ async def _run(self) -> None: | |
| """Main loop for streaming transcription.""" | ||
| closing_ws = False | ||
| http_session = self._stt._ensure_session() | ||
| vad_stream: vad.VADStream | None = self._vad.stream() if self._vad is not None else None | ||
|
|
||
| @utils.log_exceptions(logger=logger) | ||
| async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: | ||
|
|
@@ -653,6 +702,8 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: | |
| async for ev in self._input_ch: | ||
| frames: list[rtc.AudioFrame] = [] | ||
| if isinstance(ev, rtc.AudioFrame): | ||
| if vad_stream is not None: | ||
| vad_stream.push_frame(ev) | ||
| frames.extend(audio_bstream.push(ev.data)) | ||
| elif isinstance(ev, self._FlushSentinel): | ||
| frames.extend(audio_bstream.flush()) | ||
|
|
@@ -667,12 +718,28 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: | |
| } | ||
| await ws.send_str(json.dumps(audio_msg)) | ||
|
|
||
| if vad_stream is not None: | ||
| vad_stream.end_input() | ||
|
|
||
| closing_ws = True | ||
| finalize_msg = { | ||
| "type": "session.finalize", | ||
| } | ||
| await ws.send_str(json.dumps(finalize_msg)) | ||
|
|
||
| @utils.log_exceptions(logger=logger) | ||
| async def vad_task(ws: aiohttp.ClientWebSocketResponse, stream: vad.VADStream) -> None: | ||
| async for ev in stream: | ||
| if ev.type != vad.VADEventType.END_OF_SPEECH: | ||
| continue | ||
| if ws.closed: | ||
| return | ||
| try: | ||
| await ws.send_str(json.dumps({"type": "session.finalize"})) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. One question: what will happen if the VAD fires an end-of-speech event on noise? Will the STT return an empty transcript or a random one?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I believe the STT will return an empty transcript (""). |
||
| except Exception: | ||
| logger.debug("failed to send session.finalize from VAD, ws may be closing") | ||
| return | ||
|
|
||
| @utils.log_exceptions(logger=logger) | ||
| async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None: | ||
| nonlocal closing_ws | ||
|
|
@@ -722,6 +789,8 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None: | |
| asyncio.create_task(send_task(ws)), | ||
| asyncio.create_task(recv_task(ws)), | ||
| ] | ||
| if vad_stream is not None: | ||
| tasks.append(asyncio.create_task(vad_task(ws, vad_stream))) | ||
| try: | ||
| await asyncio.gather(*tasks) | ||
| finally: | ||
|
|
@@ -730,6 +799,8 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None: | |
| self._ws = None | ||
| if ws is not None: | ||
| await ws.close() | ||
| if vad_stream is not None: | ||
| await vad_stream.aclose() | ||
|
|
||
| async def _connect_ws( | ||
| self, http_session: aiohttp.ClientSession | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the case where AgentSession has VAD wouldn't this mean we have 2 VAD instances?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Using just one would require the user to store the VAD and pass the same instance to both the session and the STT.
Maybe it would be helpful for the user to have separate settings for the STT-level and session-level VAD, but as of right now the session VAD can't be connected to the STT.