5 changes: 5 additions & 0 deletions .github/next-release/changeset-5c01f156.md
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

read frames from pre-connect audio buffer and remove await from room_io start (#2156)
66 changes: 66 additions & 0 deletions examples/voice_agents/pre_connect_audio_buffer.py
@@ -0,0 +1,66 @@
import logging

from dotenv import load_dotenv

from livekit.agents import Agent, AgentSession, JobContext, JobProcess, WorkerOptions, cli
from livekit.agents.voice.room_io import RoomInputOptions, RoomIO
from livekit.plugins import deepgram, openai, silero
from livekit.plugins.turn_detector.multilingual import MultilingualModel

logger = logging.getLogger("pre-connect-audio-agent")

load_dotenv()


# This example demonstrates the pre-connect audio buffer used by the instant-connect feature.
# It captures what the user says while the connection is still being established, so they
# don't have to wait for the connection before speaking. The flow has three steps:
# 1. RoomIO is set up with pre_connect_audio=True
# 2. When connecting to the room, the client sends any audio it captured before the connection was ready
# 3. That pre-connect audio is combined with the live audio once the connection is established
# (A shorter variant that passes these options straight to AgentSession.start is sketched after this example.)


class MyAgent(Agent):
def __init__(self) -> None:
super().__init__(
instructions="Your name is Kelly. You would interact with users via voice."
"with that in mind keep your responses concise and to the point."
"You are curious and friendly, and have a sense of humor.",
)


def prewarm(proc: JobProcess):
proc.userdata["vad"] = silero.VAD.load()


async def entrypoint(ctx: JobContext):
session = AgentSession(
vad=ctx.proc.userdata["vad"],
# any combination of STT, LLM, TTS, or realtime API can be used
llm=openai.LLM(model="gpt-4o-mini"),
stt=deepgram.STT(model="nova-3", language="multi"),
tts=openai.TTS(voice="ash"),
# use LiveKit's turn detection model
turn_detection=MultilingualModel(),
)
room_io = RoomIO(
agent_session=session,
room=ctx.room,
input_options=RoomInputOptions(pre_connect_audio=True, pre_connect_audio_timeout=5.0),
)

    # start room_io first to register the byte stream handler and listen for the audio track
    # publication, then connect to the room so the client is notified to send its pre-connect audio buffer
await room_io.start()
await ctx.connect()

    # put any time-consuming model/knowledge loading here;
    # buffering of the user's audio starts as soon as room_io is started

await ctx.wait_for_participant()
await session.start(agent=MyAgent())
logger.info("agent started")


if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm))
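For comparison, here is a sketch of a more compact wiring that skips the explicit RoomIO and hands the room plus RoomInputOptions directly to AgentSession.start, relying on the room_input_options path visible in the agent_session.py diff below. This is illustrative only, not part of the PR, and the exact call ordering and option defaults are assumptions:

import logging

from dotenv import load_dotenv

from livekit.agents import Agent, AgentSession, JobContext, WorkerOptions, cli
from livekit.agents.voice.room_io import RoomInputOptions
from livekit.plugins import deepgram, openai, silero

logger = logging.getLogger("pre-connect-audio-minimal")

load_dotenv()


async def entrypoint(ctx: JobContext):
    # same pipeline as the example above; any STT/LLM/TTS combination works
    session = AgentSession(
        vad=silero.VAD.load(),
        llm=openai.LLM(model="gpt-4o-mini"),
        stt=deepgram.STT(model="nova-3", language="multi"),
        tts=openai.TTS(voice="ash"),
    )

    await ctx.connect()
    # AgentSession.start builds the RoomIO internally from these options
    await session.start(
        agent=Agent(instructions="You are a helpful and concise voice assistant."),
        room=ctx.room,
        room_input_options=RoomInputOptions(pre_connect_audio=True, pre_connect_audio_timeout=5.0),
    )
    logger.info("agent started")


if __name__ == "__main__":
    cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))

The trade-off is that you give up control over when room_io starts relative to ctx.connect(), which is exactly what the explicit example above exercises.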
4 changes: 2 additions & 2 deletions livekit-agents/livekit/agents/voice/agent_session.py
@@ -321,8 +321,8 @@ async def start(
await chat_cli.start()

elif is_given(room) and not self._room_io:
-            room_input_options = copy.deepcopy(room_input_options)
-            room_output_options = copy.deepcopy(room_output_options)
+            room_input_options = copy.copy(room_input_options)
+            room_output_options = copy.copy(room_output_options)

if (
self.input.audio is not None
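The switch from copy.deepcopy to copy.copy means the option objects themselves are duplicated while their nested members stay shared. Presumably this avoids deep-copying members such as noise-cancellation handles or callbacks that should not (or cannot) be cloned; that motivation is an inference, not something stated in the diff. A small self-contained illustration of the difference, using hypothetical field names:

from __future__ import annotations

import copy
from dataclasses import dataclass


@dataclass
class FakeHandle:
    # stands in for a live resource, e.g. a noise-cancellation handle
    name: str


@dataclass
class FakeInputOptions:
    pre_connect_audio: bool = True
    handle: FakeHandle | None = None


opts = FakeInputOptions(handle=FakeHandle("nc"))

shallow = copy.copy(opts)   # new options object, same nested handle
deep = copy.deepcopy(opts)  # clones the nested handle as well

assert shallow.handle is opts.handle
assert deep.handle is not opts.handle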
68 changes: 64 additions & 4 deletions livekit-agents/livekit/agents/voice/room_io/_input.py
@@ -2,7 +2,7 @@

import asyncio
from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterable
from typing import Generic, TypeVar, Union

from typing_extensions import override
@@ -12,6 +12,7 @@

from ...log import logger
from ..io import AudioInput, VideoInput
from ._pre_connect_audio import _WaitPreConnectAudio

T = TypeVar("T", bound=Union[rtc.AudioFrame, rtc.VideoFrame])

@@ -94,9 +95,9 @@ def set_participant(self, participant: rtc.Participant | str | None) -> None:
return

self._participant_identity = participant_identity
+        self._close_stream()

if participant_identity is None:
-            self._close_stream()
return

participant = (
@@ -127,12 +128,13 @@ async def _forward_task(
old_task: asyncio.Task | None,
stream: rtc.VideoStream | rtc.AudioStream,
track_source: rtc.TrackSource.ValueType,
+        participant: rtc.RemoteParticipant,
) -> None:
if old_task:
await utils.aio.cancel_and_wait(old_task)

extra = {
"participant": self._participant_identity,
"participant": participant.identity,
"source": rtc.TrackSource.Name(track_source),
}
logger.debug("start reading stream", extra=extra)
@@ -172,7 +174,7 @@ def _on_track_available(
self._stream = self._create_stream(track)
self._publication = publication
self._forward_atask = asyncio.create_task(
-            self._forward_task(self._forward_atask, self._stream, publication.source)
+            self._forward_task(self._forward_atask, self._stream, publication.source, participant)
)
return True

@@ -202,13 +204,15 @@ def __init__(
sample_rate: int,
num_channels: int,
noise_cancellation: rtc.NoiseCancellationOptions | None,
pre_connect_audio_cb: _WaitPreConnectAudio | None,
) -> None:
_ParticipantInputStream.__init__(
self, room=room, track_source=rtc.TrackSource.SOURCE_MICROPHONE
)
self._sample_rate = sample_rate
self._num_channels = num_channels
self._noise_cancellation = noise_cancellation
self._pre_connect_audio_cb = pre_connect_audio_cb

@override
def _create_stream(self, track: rtc.Track) -> rtc.AudioStream:
@@ -219,6 +223,62 @@ def _create_stream(self, track: rtc.Track) -> rtc.AudioStream:
noise_cancellation=self._noise_cancellation,
)

@override
async def _forward_task(
self,
old_task: asyncio.Task | None,
stream: rtc.AudioStream,
track_source: rtc.TrackSource.ValueType,
participant: rtc.RemoteParticipant,
) -> None:
if self._pre_connect_audio_cb:
try:
duration = 0
frames = await self._pre_connect_audio_cb(participant)
for frame in self._resample_frames(frames):
if self._attached:
await self._data_ch.send(frame)
duration += frame.duration
if frames:
logger.debug(
"pre-connect audio buffer pushed",
extra={"duration": duration, "participant": participant.identity},
)

except asyncio.TimeoutError:
logger.warning(
"timeout waiting for pre-connect audio buffer",
extra={"participant": participant.identity},
)

except Exception as e:
logger.error(
"error reading pre-connect audio buffer",
extra={"error": e, "participant": participant.identity},
)

await super()._forward_task(old_task, stream, track_source, participant)

def _resample_frames(self, frames: Iterable[rtc.AudioFrame]) -> Iterable[rtc.AudioFrame]:
resampler: rtc.AudioResampler | None = None
for frame in frames:
if (
not resampler
and self._sample_rate is not None
and frame.sample_rate != self._sample_rate
):
resampler = rtc.AudioResampler(
input_rate=frame.sample_rate, output_rate=self._sample_rate
)

if resampler:
yield from resampler.push(frame)
else:
yield frame

if resampler:
yield from resampler.flush()


class _ParticipantVideoInputStream(_ParticipantInputStream[rtc.VideoFrame], VideoInput):
def __init__(self, room: rtc.Room) -> None:
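The new _forward_task above drains the pre-connect buffer first: it awaits the injected callback with the participant, resamples the returned frames to the configured sample rate, and pushes them into the same data channel before live frames start flowing. The concrete _WaitPreConnectAudio type is defined in _pre_connect_audio.py, which is not part of this diff; judging only from how it is called here, a hypothetical test stub might look like this (one second of silent 16 kHz mono audio):

from livekit import rtc


async def fake_pre_connect_audio(participant: rtc.RemoteParticipant) -> list[rtc.AudioFrame]:
    # return 1 second of silent 16 kHz mono audio in 10 ms frames as the "buffered" speech
    return [
        rtc.AudioFrame.create(sample_rate=16000, num_channels=1, samples_per_channel=160)
        for _ in range(100)
    ]

Such a stub would presumably be passed as pre_connect_audio_cb when constructing _ParticipantAudioInputStream, in place of the handler that RoomIO normally registers.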
60 changes: 43 additions & 17 deletions livekit-agents/livekit/agents/voice/room_io/_output.py
@@ -32,34 +32,60 @@ def __init__(
self._lock = asyncio.Lock()
self._audio_source = rtc.AudioSource(sample_rate, num_channels, queue_size_ms)
self._publish_options = track_publish_options
self._publication: rtc.LocalTrackPublication | None = None
self._publish_task: asyncio.Task[rtc.LocalTrackPublication] | None = None
self._published: bool = False

self._republish_task: asyncio.Task | None = None # used to republish track on reconnection
self._flush_task: asyncio.Task | None = None
self._interrupted_event = asyncio.Event()

self._pushed_duration: float = 0.0
self._interrupted: bool = False

async def _publish_track(self) -> None:
async with self._lock:
track = rtc.LocalAudioTrack.create_audio_track("roomio_audio", self._audio_source)
self._publication = await self._room.local_participant.publish_track(
track, self._publish_options
)
await self._publication.wait_for_subscription()

async def start(self) -> None:
await self._publish_track()
if self._room.isconnected():
self._on_reconnected()
self._room.on("connection_state_changed", self._on_connection_state_changed)
self._room.on("reconnected", self._on_reconnected)

def _on_connection_state_changed(self, state: rtc.ConnectionState.ValueType) -> None:
if not self._publish_task and state == rtc.ConnectionState.CONN_CONNECTED:
self._on_reconnected()

def _on_reconnected(self) -> None:
async def _publish_track() -> rtc.LocalTrackPublication:
async with self._lock:
track = rtc.LocalAudioTrack.create_audio_track("roomio_audio", self._audio_source)
publication = await self._room.local_participant.publish_track(
track, self._publish_options
)
return publication

def _on_reconnected() -> None:
if self._republish_task:
self._republish_task.cancel()
self._republish_task = asyncio.create_task(self._publish_track())
if self._publish_task:
logger.warning("cancelling publish_track")
self._publish_task.cancel()
self._publish_task = asyncio.create_task(_publish_track())

self._room.on("reconnected", _on_reconnected)
async def aclose(self) -> None:
if self._publish_task:
self._publish_task.cancel()
self._room.off("connection_state_changed", self._on_connection_state_changed)
self._room.off("reconnected", self._on_reconnected)

async def capture_frame(self, frame: rtc.AudioFrame) -> None:
if not self._published:
if not self._publish_task:
raise RuntimeError("capture_frame called before room connected")

log_task = asyncio.get_event_loop().call_later(
5, logger.warning, "audio track publishing takes longer than expected..."
)
try:
publication = await self._publish_task
finally:
log_task.cancel()

await publication.wait_for_subscription()
self._published = True

await super().capture_frame(frame)

if self._flush_task and not self._flush_task.done():
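Taken together, the output changes defer track publication: start() now only registers connection-state handlers, _on_reconnected launches (or restarts) a publish task, and capture_frame awaits that task before the first frame, logging a warning if publishing takes longer than about five seconds and waiting for subscription before it starts pushing audio. Below is a standalone sketch of this lazy-publish pattern with made-up names; it does not use any LiveKit APIs:

import asyncio
import logging

logger = logging.getLogger("lazy-publish-sketch")


class LazyPublisher:
    def __init__(self) -> None:
        self._publish_task: asyncio.Task[str] | None = None
        self._published = False

    async def _publish(self) -> str:
        # stands in for publish_track() plus the network round trip
        await asyncio.sleep(0.1)
        return "track-sid"

    def on_connected(self) -> None:
        # (re)start publishing whenever the connection is (re)established
        if self._publish_task:
            self._publish_task.cancel()
        self._publish_task = asyncio.create_task(self._publish())

    async def write(self, frame: bytes) -> None:
        if not self._published:
            if not self._publish_task:
                raise RuntimeError("write called before connection")
            slow_warning = asyncio.get_running_loop().call_later(
                5, logger.warning, "publishing takes longer than expected..."
            )
            try:
                await self._publish_task
            finally:
                slow_warning.cancel()
            self._published = True
        # ...push the frame to the published track here...


async def main() -> None:
    publisher = LazyPublisher()
    publisher.on_connected()  # in the real code this is driven by room connection events
    await publisher.write(b"\x00" * 320)


asyncio.run(main())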