Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
355b47f
add pre-connect audio buffer
longcw Apr 29, 2025
b732d68
read buffer as a list
longcw Apr 29, 2025
ca15aa0
clean logs
longcw Apr 29, 2025
c161ebd
add wait_for_data for PreConnectAudioData
longcw Apr 29, 2025
b77db38
Merge remote-tracking branch 'origin/main' into longc/pre-connect-audio
longcw Apr 29, 2025
4b0073f
support multi participant
longcw Apr 29, 2025
ad3e6b7
update PreConnectAudioData
longcw Apr 29, 2025
1ea9d6b
move PreConnectAudioHandler to room io
longcw Apr 30, 2025
13171f9
update comments
longcw Apr 30, 2025
f066d26
update comments
longcw Apr 30, 2025
92fa347
clean up timeout
longcw Apr 30, 2025
ca5537e
check PRE_CONNECT_AUDIO_ATTRIBUTE == true
longcw Apr 30, 2025
7da1d04
add warning for PreConnectAudioHandler
longcw May 1, 2025
f1a7df1
Merge remote-tracking branch 'origin/main' into longc/pre-connect-aud…
longcw May 1, 2025
420f153
Merge remote-tracking branch 'origin/main' into longc/pre-connect-aud…
longcw May 1, 2025
5dc0b23
update logs and timeout
longcw May 1, 2025
439f377
Merge remote-tracking branch 'origin/main' into longc/pre-connect-aud…
longcw May 5, 2025
c964dd3
update audio buffer timeout
longcw May 5, 2025
181fed8
pass publication to forward task
longcw May 5, 2025
b3125e0
Merge remote-tracking branch 'origin/main' into longc/pre-connect-aud…
longcw May 7, 2025
fb6aefd
use track id as the buffer key
longcw May 7, 2025
ed6c8c1
check AudioTrackFeature TF_PRECONNECT_BUFFER
longcw May 8, 2025
8e733f9
Merge remote-tracking branch 'origin/main' into longc/pre-connect-aud…
longcw May 8, 2025
cc04cb1
ruff
longcw May 8, 2025
2b73403
upgrade livekit sdk
longcw May 8, 2025
ba90cf5
Merge remote-tracking branch 'origin/main' into longc/pre-connect-aud…
longcw May 13, 2025
73fe674
set pre_connect_audio default True
longcw May 13, 2025
9086ba1
remove pre connect audio buffer example
longcw May 13, 2025
56e278e
log warning for connection order only the buffer used
longcw May 13, 2025
cc093f5
set default timeout for pre-connect buffer to 3s
longcw May 14, 2025
c4308f1
Merge remote-tracking branch 'origin/main' into longc/pre-connect-aud…
longcw May 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions examples/voice_agents/basic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ async def entrypoint(ctx: JobContext):
ctx.log_context_fields = {
"room": ctx.room.name,
}
await ctx.connect()

session = AgentSession(
vad=ctx.proc.userdata["vad"],
Expand All @@ -98,9 +97,6 @@ async def log_usage():
# shutdown callbacks are triggered when the session is over
ctx.add_shutdown_callback(log_usage)

# wait for a participant to join the room
await ctx.wait_for_participant()

await session.start(
agent=MyAgent(),
room=ctx.room,
Expand All @@ -111,6 +107,9 @@ async def log_usage():
room_output_options=RoomOutputOptions(transcription_enabled=True),
)

# join the room when agent is ready
await ctx.connect()


if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm))
87 changes: 82 additions & 5 deletions livekit-agents/livekit/agents/voice/room_io/_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@

import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Iterable
from typing import Generic, TypeVar, Union

from typing_extensions import override

import livekit.rtc as rtc
from livekit.rtc._proto.track_pb2 import AudioTrackFeature

from ...log import logger
from ...utils import aio, log_exceptions
from ..io import AudioInput, VideoInput
from ._pre_connect_audio import PreConnectAudioHandler

T = TypeVar("T", bound=Union[rtc.AudioFrame, rtc.VideoFrame])

Expand Down Expand Up @@ -126,14 +128,15 @@ async def _forward_task(
self,
old_task: asyncio.Task | None,
stream: rtc.VideoStream | rtc.AudioStream,
track_source: rtc.TrackSource.ValueType,
publication: rtc.RemoteTrackPublication,
participant: rtc.RemoteParticipant,
) -> None:
if old_task:
await aio.cancel_and_wait(old_task)

extra = {
"participant": self._participant_identity,
"source": rtc.TrackSource.Name(track_source),
"participant": participant.identity,
"source": rtc.TrackSource.Name(publication.source),
}
logger.debug("start reading stream", extra=extra)
async for event in stream:
Expand Down Expand Up @@ -172,7 +175,7 @@ def _on_track_available(
self._stream = self._create_stream(track)
self._publication = publication
self._forward_atask = asyncio.create_task(
self._forward_task(self._forward_atask, self._stream, publication.source)
self._forward_task(self._forward_atask, self._stream, publication, participant)
)
return True

Expand Down Expand Up @@ -202,13 +205,15 @@ def __init__(
sample_rate: int,
num_channels: int,
noise_cancellation: rtc.NoiseCancellationOptions | None,
pre_connect_audio_handler: PreConnectAudioHandler | None,
) -> None:
_ParticipantInputStream.__init__(
self, room=room, track_source=rtc.TrackSource.SOURCE_MICROPHONE
)
self._sample_rate = sample_rate
self._num_channels = num_channels
self._noise_cancellation = noise_cancellation
self._pre_connect_audio_handler = pre_connect_audio_handler

@override
def _create_stream(self, track: rtc.Track) -> rtc.AudioStream:
Expand All @@ -219,6 +224,78 @@ def _create_stream(self, track: rtc.Track) -> rtc.AudioStream:
noise_cancellation=self._noise_cancellation,
)

@override
async def _forward_task(
self,
old_task: asyncio.Task | None,
stream: rtc.AudioStream,
publication: rtc.RemoteTrackPublication,
participant: rtc.RemoteParticipant,
) -> None:
if (
self._pre_connect_audio_handler
and publication.track
and AudioTrackFeature.TF_PRECONNECT_BUFFER in publication.audio_features
):
try:
duration = 0
frames = await self._pre_connect_audio_handler.wait_for_data(publication.track.sid)
for frame in self._resample_frames(frames):
if self._attached:
await self._data_ch.send(frame)
duration += frame.duration
if frames:
logger.debug(
"pre-connect audio buffer pushed",
extra={
"duration": duration,
"track_id": publication.track.sid,
"participant": participant.identity,
},
)

except asyncio.TimeoutError:
logger.warning(
"timeout waiting for pre-connect audio buffer",
extra={
"duration": duration,
"track_id": publication.track.sid,
"participant": participant.identity,
},
)

except Exception as e:
logger.error(
"error reading pre-connect audio buffer",
extra={
"error": e,
"track_id": publication.track.sid,
"participant": participant.identity,
},
)

await super()._forward_task(old_task, stream, publication, participant)

def _resample_frames(self, frames: Iterable[rtc.AudioFrame]) -> Iterable[rtc.AudioFrame]:
resampler: rtc.AudioResampler | None = None
for frame in frames:
if (
not resampler
and self._sample_rate is not None
and frame.sample_rate != self._sample_rate
):
resampler = rtc.AudioResampler(
input_rate=frame.sample_rate, output_rate=self._sample_rate
)

if resampler:
yield from resampler.push(frame)
else:
yield frame

if resampler:
yield from resampler.flush()


class _ParticipantVideoInputStream(_ParticipantInputStream[rtc.VideoFrame], VideoInput):
def __init__(self, room: rtc.Room) -> None:
Expand Down
129 changes: 129 additions & 0 deletions livekit-agents/livekit/agents/voice/room_io/_pre_connect_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import asyncio
import contextlib
import time
from dataclasses import dataclass, field

from livekit import rtc

from ..agent import logger, utils

PRE_CONNECT_AUDIO_BUFFER_STREAM = "lk.agent.pre-connect-audio-buffer"


@dataclass
class _PreConnectAudioBuffer:
timestamp: float
frames: list[rtc.AudioFrame] = field(default_factory=list)


class PreConnectAudioHandler:
def __init__(self, room: rtc.Room, *, timeout: float, max_delta_s: float = 1.0):
self._room = room
self._timeout = timeout
self._max_delta_s = max_delta_s

# track id -> buffer
self._buffers: dict[str, asyncio.Future[_PreConnectAudioBuffer]] = {}
self._tasks: set[asyncio.Task] = set()

self._registered_after_connect = False

def register(self):
def _handler(reader: rtc.ByteStreamReader, participant_id: str):
task = asyncio.create_task(self._read_audio_task(reader, participant_id))
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)

def _on_timeout():
logger.warning(
"pre-connect audio received but not completed in time",
extra={"participant": participant_id},
)
if not task.done():
task.cancel()

timeout_handle = asyncio.get_event_loop().call_later(self._timeout, _on_timeout)
task.add_done_callback(lambda _: timeout_handle.cancel())

try:
if self._room.isconnected():
self._registered_after_connect = True
self._room.register_byte_stream_handler(PRE_CONNECT_AUDIO_BUFFER_STREAM, _handler)
except ValueError:
logger.warning(
f"pre-connect audio handler for {PRE_CONNECT_AUDIO_BUFFER_STREAM} "
"already registered, ignoring"
)

async def aclose(self):
self._room.unregister_byte_stream_handler(PRE_CONNECT_AUDIO_BUFFER_STREAM)
await utils.aio.cancel_and_wait(*self._tasks)

async def wait_for_data(self, track_id: str) -> list[rtc.AudioFrame]:
# the handler is enabled by default, log a warning only if the buffer is actually used
if self._registered_after_connect:
logger.warning(
"pre-connect audio handler registered after room connection, "
"start RoomIO before ctx.connect() to ensure seamless audio buffer.",
extra={"track_id": track_id},
)

self._buffers.setdefault(track_id, asyncio.Future())
fut = self._buffers[track_id]

try:
if fut.done():
buf = fut.result()
if (delta := time.time() - buf.timestamp) > self._max_delta_s:
logger.warning(
"pre-connect audio buffer is too old",
extra={"track_id": track_id, "delta_time": delta},
)
return []
return buf.frames

buf = await asyncio.wait_for(fut, self._timeout)
return buf.frames
finally:
self._buffers.pop(track_id)

@utils.log_exceptions(logger=logger)
async def _read_audio_task(self, reader: rtc.ByteStreamReader, participant_id: str):
if not (track_id := reader.info.attributes.get("trackId")):
logger.warning(
"pre-connect audio received but no trackId", extra={"participant": participant_id}
)
return

if (fut := self._buffers.get(track_id)) and fut.done():
# reset the buffer if it's already set
self._buffers.pop(track_id)
self._buffers.setdefault(track_id, asyncio.Future())
fut = self._buffers[track_id]

buf = _PreConnectAudioBuffer(timestamp=time.time())
try:
sample_rate = int(reader.info.attributes["sampleRate"])
num_channels = int(reader.info.attributes["channels"])

duration = 0
audio_stream = utils.audio.AudioByteStream(sample_rate, num_channels)
async for chunk in reader:
for frame in audio_stream.push(chunk):
buf.frames.append(frame)
duration += frame.duration

for frame in audio_stream.flush():
buf.frames.append(frame)
duration += frame.duration

logger.debug(
"pre-connect audio received",
extra={"duration": duration, "track_id": track_id, "participant": participant_id},
)

with contextlib.suppress(asyncio.InvalidStateError):
fut.set_result(buf)
except Exception as e:
with contextlib.suppress(asyncio.InvalidStateError):
fut.set_exception(e)
18 changes: 18 additions & 0 deletions livekit-agents/livekit/agents/voice/room_io/room_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..events import AgentStateChangedEvent, UserInputTranscribedEvent
from ..io import AudioInput, AudioOutput, TextOutput, VideoInput
from ..transcription import TranscriptSynchronizer
from ._pre_connect_audio import PreConnectAudioHandler

if TYPE_CHECKING:
from ..agent_session import AgentSession
Expand Down Expand Up @@ -70,6 +71,10 @@ class RoomInputOptions:
participant_identity: NotGivenOr[str] = NOT_GIVEN
"""The participant to link to. If not provided, link to the first participant.
Can be overridden by the `participant` argument of RoomIO constructor or `set_participant`."""
pre_connect_audio: bool = True
"""Pre-connect audio enabled or not."""
pre_connect_audio_timeout: float = 3.0
"""The pre-connect audio will be ignored if it doesn't arrive within this time."""


@dataclass
Expand Down Expand Up @@ -125,8 +130,17 @@ def __init__(
self._tasks: set[asyncio.Task] = set()
self._update_state_task: asyncio.Task | None = None

self._pre_connect_audio_handler: PreConnectAudioHandler | None = None

async def start(self) -> None:
# -- create inputs --
if self._input_options.pre_connect_audio:
self._pre_connect_audio_handler = PreConnectAudioHandler(
room=self._room,
timeout=self._input_options.pre_connect_audio_timeout,
)
self._pre_connect_audio_handler.register()

if self._input_options.text_enabled:
try:
self._room.register_text_stream_handler(TOPIC_CHAT, self._on_user_text_input)
Expand All @@ -144,6 +158,7 @@ async def start(self) -> None:
sample_rate=self._input_options.audio_sample_rate,
num_channels=self._input_options.audio_num_channels,
noise_cancellation=self._input_options.noise_cancellation,
pre_connect_audio_handler=self._pre_connect_audio_handler,
)

# -- create outputs --
Expand Down Expand Up @@ -209,6 +224,9 @@ async def aclose(self) -> None:
if self._init_atask:
await utils.aio.cancel_and_wait(self._init_atask)

if self._pre_connect_audio_handler:
await self._pre_connect_audio_handler.aclose()

if self._audio_input:
await self._audio_input.aclose()
if self._video_input:
Expand Down
2 changes: 1 addition & 1 deletion livekit-agents/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
]
dependencies = [
"click~=8.1",
"livekit>=1.0.6,<2",
"livekit>=1.0.7,<2",
"livekit-api>=1.0.2,<2",
"livekit-protocol~=1.0",
"protobuf>=3",
Expand Down
Loading