diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 576b56b1..1123d20f 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -23,5 +23,6 @@ jobs:
       - name: Run tests
         run: |
           python3 ./livekit-rtc/rust-sdks/download_ffi.py --output livekit-rtc/livekit/rtc/resources
-          pip3 install pytest ./livekit-protocol ./livekit-api ./livekit-rtc pydantic numpy
+          pip3 install ./livekit-protocol ./livekit-api ./livekit-rtc
+          pip3 install -r dev-requirements.txt
           pytest . --ignore=livekit-rtc/rust-sdks
diff --git a/dev-requirements.txt b/dev-requirements.txt
index a31ca5fd..23131b51 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -9,3 +9,8 @@
 auditwheel; sys_platform == 'linux'
 cibuildwheel
 pytest
+pytest-asyncio
+
+matplotlib
+pydantic
+numpy
diff --git a/livekit-rtc/livekit/rtc/__init__.py b/livekit-rtc/livekit/rtc/__init__.py
index 1482e846..8b3e7729 100644
--- a/livekit-rtc/livekit/rtc/__init__.py
+++ b/livekit-rtc/livekit/rtc/__init__.py
@@ -76,6 +76,7 @@
 from .video_source import VideoSource
 from .video_stream import VideoFrameEvent, VideoStream
 from .audio_resampler import AudioResampler, AudioResamplerQuality
+from .audio_mixer import AudioMixer
 from .apm import AudioProcessingModule
 from .utils import combine_audio_frames
 from .rpc import RpcError, RpcInvocationData
@@ -148,6 +149,7 @@
     "VideoFrameEvent",
     "VideoSource",
     "VideoStream",
+    "AudioMixer",
     "AudioResampler",
     "AudioResamplerQuality",
     "RpcError",
diff --git a/livekit-rtc/livekit/rtc/audio_mixer.py b/livekit-rtc/livekit/rtc/audio_mixer.py
new file mode 100644
index 00000000..e2f28c6b
--- /dev/null
+++ b/livekit-rtc/livekit/rtc/audio_mixer.py
@@ -0,0 +1,200 @@
+import asyncio
+import numpy as np
+import contextlib
+from dataclasses import dataclass
+from typing import AsyncIterator, Optional
+from .audio_frame import AudioFrame
+from .log import logger
+
+_Stream = AsyncIterator[AudioFrame]
+
+
+@dataclass
+class _Contribution:
+    stream: _Stream
+    data: np.ndarray
+    buffer: np.ndarray
+    had_data: bool
+    exhausted: bool
+
+
+class AudioMixer:
+    def __init__(
+        self,
+        sample_rate: int,
+        num_channels: int,
+        *,
+        blocksize: int = 0,
+        stream_timeout_ms: int = 100,
+        capacity: int = 100,
+    ) -> None:
+        """
+        Initialize the AudioMixer.
+
+        The mixer accepts multiple async audio streams and mixes them into a single output stream.
+        Each output frame is generated with a fixed chunk size determined by the blocksize (in samples).
+        If blocksize is not provided (or 0), it defaults to 100 ms of audio (sample_rate // 10 samples).
+
+        Each input stream is processed in parallel, accumulating audio data until at least one
+        chunk of samples is available. If an input stream does not provide data within the
+        specified timeout, a warning is logged. The mixer can be closed immediately (dropping
+        unconsumed frames) or allowed to flush remaining data using end_input().
+
+        Args:
+            sample_rate (int): The audio sample rate in Hz.
+            num_channels (int): The number of audio channels.
+            blocksize (int, optional): The size of the audio block (in samples) for mixing.
+                If not provided, defaults to sample_rate // 10.
+            stream_timeout_ms (int, optional): The maximum wait time in milliseconds for each
+                stream to provide audio data before timing out. Defaults to 100 ms.
+            capacity (int, optional): The maximum number of mixed frames to store in the
+                output queue. Defaults to 100.
+ """ + self._streams: set[_Stream] = set() + self._buffers: dict[_Stream, np.ndarray] = {} + self._sample_rate: int = sample_rate + self._num_channels: int = num_channels + self._chunk_size: int = blocksize if blocksize > 0 else int(sample_rate // 10) + self._stream_timeout_ms: int = stream_timeout_ms + self._queue: asyncio.Queue[Optional[AudioFrame]] = asyncio.Queue(maxsize=capacity) + # _ending signals that no new streams will be added, + # but we continue processing until all streams are exhausted. + self._ending: bool = False + self._mixer_task: asyncio.Task = asyncio.create_task(self._mixer()) + + def add_stream(self, stream: AsyncIterator[AudioFrame]) -> None: + """ + Add an audio stream to the mixer. + + The stream is added to the internal set of streams and an empty buffer is initialized for it, + if not already present. + + Args: + stream (AsyncIterator[AudioFrame]): An async iterator that produces AudioFrame objects. + """ + if self._ending: + raise RuntimeError("Cannot add stream after mixer has been closed") + + self._streams.add(stream) + if stream not in self._buffers: + self._buffers[stream] = np.empty((0, self._num_channels), dtype=np.int16) + + def remove_stream(self, stream: AsyncIterator[AudioFrame]) -> None: + """ + Remove an audio stream from the mixer. + + This method removes the specified stream and its associated buffer from the mixer. + + Args: + stream (AsyncIterator[AudioFrame]): The audio stream to remove. + """ + self._streams.discard(stream) + self._buffers.pop(stream, None) + + def __aiter__(self) -> "AudioMixer": + return self + + async def __anext__(self) -> AudioFrame: + item = await self._queue.get() + if item is None: + raise StopAsyncIteration + return item + + async def aclose(self) -> None: + """ + Immediately stop mixing and close the mixer. + + This cancels the mixing task, and any unconsumed output in the queue may be dropped. + """ + self._ending = True + self._mixer_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._mixer_task + + def end_input(self) -> None: + """ + Signal that no more streams will be added. + + This method marks the mixer as closed so that it flushes any remaining buffered output before ending. + Note that existing streams will still be processed until exhausted. + """ + self._ending = True + + async def _mixer(self) -> None: + while True: + # If we're in ending mode and there are no more streams, exit. 
+            if self._ending and not self._streams:
+                break
+
+            if not self._streams:
+                await asyncio.sleep(0.01)
+                continue
+
+            tasks = [
+                self._get_contribution(
+                    stream,
+                    self._buffers.get(stream, np.empty((0, self._num_channels), dtype=np.int16)),
+                )
+                for stream in list(self._streams)
+            ]
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+            contributions = []
+            any_data = False
+            removals = []
+            for contrib in results:
+                if not isinstance(contrib, _Contribution):
+                    continue
+
+                contributions.append(contrib.data.astype(np.float32))
+                self._buffers[contrib.stream] = contrib.buffer
+                if contrib.had_data:
+                    any_data = True
+                if contrib.exhausted and contrib.buffer.shape[0] == 0:
+                    removals.append(contrib.stream)
+
+            for stream in removals:
+                self.remove_stream(stream)
+
+            if not any_data:
+                await asyncio.sleep(0.001)
+                continue
+
+            mixed = np.sum(np.stack(contributions, axis=0), axis=0)
+            mixed = np.clip(mixed, -32768, 32767).astype(np.int16)
+            frame = AudioFrame(
+                mixed.tobytes(), self._sample_rate, self._num_channels, self._chunk_size
+            )
+            await self._queue.put(frame)
+
+        await self._queue.put(None)
+
+    async def _get_contribution(
+        self, stream: AsyncIterator[AudioFrame], buf: np.ndarray
+    ) -> _Contribution:
+        had_data = buf.shape[0] > 0
+        exhausted = False
+        while buf.shape[0] < self._chunk_size and not exhausted:
+            try:
+                frame = await asyncio.wait_for(
+                    stream.__anext__(), timeout=self._stream_timeout_ms / 1000
+                )
+            except asyncio.TimeoutError:
+                logger.warning(f"AudioMixer: stream {stream} timed out, skipping this cycle")
+                break
+            except StopAsyncIteration:
+                exhausted = True
+                break
+            new_data = np.frombuffer(frame.data.tobytes(), dtype=np.int16).reshape(
+                -1, self._num_channels
+            )
+            buf = np.concatenate((buf, new_data), axis=0) if buf.size else new_data
+            had_data = True
+        if buf.shape[0] >= self._chunk_size:
+            contrib, buf = buf[: self._chunk_size], buf[self._chunk_size :]
+        else:
+            pad = np.zeros((self._chunk_size - buf.shape[0], self._num_channels), dtype=np.int16)
+            contrib, buf = (
+                np.concatenate((buf, pad), axis=0),
+                np.empty((0, self._num_channels), dtype=np.int16),
+            )
+        return _Contribution(stream, contrib, buf, had_data, exhausted)
diff --git a/livekit-rtc/setup.py b/livekit-rtc/setup.py
index 7cf1508c..9e4fb3f7 100644
--- a/livekit-rtc/setup.py
+++ b/livekit-rtc/setup.py
@@ -58,7 +58,7 @@ def finalize_options(self):
     license="Apache-2.0",
     packages=setuptools.find_namespace_packages(include=["livekit.*"]),
     python_requires=">=3.9.0",
-    install_requires=["protobuf>=4.25.0", "types-protobuf>=3", "aiofiles>=24"],
+    install_requires=["protobuf>=4.25.0", "types-protobuf>=3", "aiofiles>=24", "numpy>=1.26"],
     package_data={
         "livekit.rtc": ["_proto/*.py", "py.typed", "*.pyi", "**/*.pyi"],
         "livekit.rtc.resources": ["*.so", "*.dylib", "*.dll", "LICENSE.md", "*.h"],
diff --git a/livekit-rtc/tests/test_mixer.py b/livekit-rtc/tests/test_mixer.py
new file mode 100644
index 00000000..5ce9df3e
--- /dev/null
+++ b/livekit-rtc/tests/test_mixer.py
@@ -0,0 +1,87 @@
+# type: ignore
+
+from typing import AsyncIterator
+import numpy as np
+import pytest
+import matplotlib.pyplot as plt
+
+from livekit.rtc import AudioMixer
+from livekit.rtc.audio_frame import AudioFrame
+
+SAMPLE_RATE = 48000
+# Use 100ms blocks (i.e. 4800 samples per frame)
+BLOCKSIZE = SAMPLE_RATE // 10
+
+
+async def sine_wave_generator(freq: float, duration: float) -> AsyncIterator[AudioFrame]:
+    total_frames = int((duration * SAMPLE_RATE) // BLOCKSIZE)
+    t_frame = np.arange(BLOCKSIZE) / SAMPLE_RATE
+    for i in range(total_frames):
+        # Shift the time for each frame so that the sine wave is continuous
+        t = t_frame + i * BLOCKSIZE / SAMPLE_RATE
+        # Create a sine wave with amplitude 0.3 (to avoid clipping when summing)
+        signal = 0.3 * np.sin(2 * np.pi * freq * t)
+        # Convert from float [-0.3, 0.3] to int16 values
+        signal_int16 = np.int16(signal * 32767)
+        frame = AudioFrame(
+            signal_int16.tobytes(),
+            SAMPLE_RATE,
+            1,
+            BLOCKSIZE,
+        )
+        yield frame
+
+
+@pytest.mark.asyncio
+async def test_mixer_two_sine_waves():
+    """
+    Test that mixing two sine waves (440 Hz and 880 Hz) produces an output
+    containing both frequency components.
+    """
+    duration = 1.0
+    mixer = AudioMixer(
+        sample_rate=SAMPLE_RATE,
+        num_channels=1,
+        blocksize=BLOCKSIZE,
+        stream_timeout_ms=100,
+        capacity=100,
+    )
+    stream1 = sine_wave_generator(440, duration)
+    stream2 = sine_wave_generator(880, duration)
+    mixer.add_stream(stream1)
+    mixer.add_stream(stream2)
+    mixer.end_input()
+
+    mixed_signals = []
+    async for frame in mixer:
+        data = np.frombuffer(frame.data.tobytes(), dtype=np.int16)
+        mixed_signals.append(data)
+
+    await mixer.aclose()
+
+    if not mixed_signals:
+        pytest.fail("No frames were produced by the mixer.")
+
+    mixed_signal = np.concatenate(mixed_signals)
+
+    plt.figure(figsize=(10, 4))
+    plt.plot(mixed_signal[:1000])  # plot the first 1000 samples
+    plt.title("Mixed Signal")
+    plt.xlabel("Sample")
+    plt.ylabel("Amplitude")
+    plt.show()
+
+    # Use FFT to analyze frequency components.
+    fft = np.fft.rfft(mixed_signal)
+    freqs = np.fft.rfftfreq(len(mixed_signal), 1 / SAMPLE_RATE)
+    magnitude = np.abs(fft)
+
+    # Identify peak frequencies by picking the 5 highest-magnitude bins.
+    peak_indices = np.argsort(magnitude)[-5:]
+    peak_freqs = freqs[peak_indices]
+
+    print("Peak frequencies:", peak_freqs)
+
+    # Assert that the peaks include 440 Hz and 880 Hz (with a tolerance of ±5 Hz)
+    assert any(np.isclose(peak_freqs, 440, atol=5)), f"Expected 440 Hz in peaks, got: {peak_freqs}"
+    assert any(np.isclose(peak_freqs, 880, atol=5)), f"Expected 880 Hz in peaks, got: {peak_freqs}"
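
Usage sketch (not part of the patch): a minimal, self-contained example of how the new AudioMixer is meant to be consumed, built only on the public API this diff adds (AudioMixer, add_stream, end_input, async iteration, aclose). The silence() helper and its frame sizes are illustrative assumptions, not part of the SDK.

    import asyncio

    from livekit import rtc


    async def silence(sample_rate: int = 48000, num_frames: int = 20):
        # Hypothetical input source: 10 ms frames of int16 mono silence.
        samples = sample_rate // 100
        for _ in range(num_frames):
            yield rtc.AudioFrame(b"\x00" * (samples * 2), sample_rate, 1, samples)


    async def main() -> None:
        # With blocksize left at 0, the mixer emits 100 ms frames
        # (sample_rate // 10 samples per channel).
        mixer = rtc.AudioMixer(sample_rate=48000, num_channels=1)
        mixer.add_stream(silence())
        mixer.add_stream(silence())
        mixer.end_input()  # no more streams; flush remaining audio, then finish

        # The mixer is itself an async iterator of mixed AudioFrames.
        async for frame in mixer:
            print(frame.samples_per_channel)

        await mixer.aclose()


    asyncio.run(main())

Note that end_input() lets the two input streams drain completely before the output iterator stops, whereas calling aclose() first would cancel mixing immediately and could drop queued frames.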