Skip to content

Commit 13563e8

Browse files
authored
optimize unnecessary copy (#573)
1 parent 57ae0ec commit 13563e8

4 files changed

Lines changed: 55 additions & 18 deletions

File tree

livekit-rtc/livekit/rtc/_utils.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from collections import deque
1818
import ctypes
1919
import random
20-
from typing import Callable, Generator, Generic, List, TypeVar
20+
from typing import Callable, Generator, Generic, List, TypeVar, Union
2121

2222
logger = logging.getLogger("livekit")
2323

@@ -40,8 +40,35 @@ def task_done_logger(task: asyncio.Task) -> None:
4040
return
4141

4242

43-
def get_address(mv: memoryview) -> int:
44-
return ctypes.addressof(ctypes.c_char.from_buffer(mv))
43+
def _buffer_supported_or_raise(
44+
data: Union[bytes, bytearray, memoryview],
45+
) -> None:
46+
"""Validate a buffer for FFI use.
47+
48+
Raises clear errors for non-contiguous or sliced memoryviews.
49+
"""
50+
if isinstance(data, memoryview):
51+
if not data.contiguous:
52+
raise ValueError("memoryview must be contiguous")
53+
if data.nbytes != len(data.obj): # type: ignore[arg-type]
54+
raise ValueError("sliced memoryviews are not supported")
55+
elif not isinstance(data, (bytes, bytearray)):
56+
raise TypeError(f"expected bytes, bytearray, or memoryview, got {type(data)}")
57+
58+
59+
def get_address(data) -> int:
60+
if isinstance(data, memoryview):
61+
_buffer_supported_or_raise(data)
62+
if not data.readonly:
63+
return ctypes.addressof(ctypes.c_char.from_buffer(data))
64+
data = data.obj
65+
if isinstance(data, bytearray):
66+
return ctypes.addressof(ctypes.c_char.from_buffer(data))
67+
if isinstance(data, bytes):
68+
addr = ctypes.cast(ctypes.c_char_p(data), ctypes.c_void_p).value
69+
assert addr is not None
70+
return addr
71+
raise TypeError(f"expected bytes, bytearray, or memoryview, got {type(data)}")
4572

4673

4774
T = TypeVar("T")

livekit-rtc/livekit/rtc/apm.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,15 @@ def process_stream(self, data: AudioFrame) -> None:
4848
Important:
4949
Audio frames must be exactly 10 ms in duration.
5050
"""
51-
bdata = data.data.cast("b")
51+
if isinstance(data._data, bytes) or (
52+
isinstance(data._data, memoryview) and data._data.readonly
53+
):
54+
data._data = bytearray(data._data)
5255

5356
req = proto_ffi.FfiRequest()
5457
req.apm_process_stream.apm_handle = self._ffi_handle.handle
55-
req.apm_process_stream.data_ptr = get_address(memoryview(bdata))
56-
req.apm_process_stream.size = len(bdata)
58+
req.apm_process_stream.data_ptr = get_address(data._data)
59+
req.apm_process_stream.size = len(data._data)
5760
req.apm_process_stream.sample_rate = data.sample_rate
5861
req.apm_process_stream.num_channels = data.num_channels
5962

@@ -73,12 +76,15 @@ def process_reverse_stream(self, data: AudioFrame) -> None:
7376
Important:
7477
Audio frames must be exactly 10 ms in duration.
7578
"""
76-
bdata = data.data.cast("b")
79+
if isinstance(data._data, bytes) or (
80+
isinstance(data._data, memoryview) and data._data.readonly
81+
):
82+
data._data = bytearray(data._data)
7783

7884
req = proto_ffi.FfiRequest()
7985
req.apm_process_reverse_stream.apm_handle = self._ffi_handle.handle
80-
req.apm_process_reverse_stream.data_ptr = get_address(memoryview(bdata))
81-
req.apm_process_reverse_stream.size = len(bdata)
86+
req.apm_process_reverse_stream.data_ptr = get_address(data._data)
87+
req.apm_process_reverse_stream.size = len(data._data)
8288
req.apm_process_reverse_stream.sample_rate = data.sample_rate
8389
req.apm_process_reverse_stream.num_channels = data.num_channels
8490

livekit-rtc/livekit/rtc/audio_frame.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import ctypes
1616
from ._ffi_client import FfiHandle
1717
from ._proto import audio_frame_pb2 as proto_audio
18-
from ._utils import get_address
18+
from ._utils import _buffer_supported_or_raise, get_address
1919
from typing import Any, Union
2020

2121

@@ -49,19 +49,21 @@ def __init__(
4949
Raises:
5050
ValueError: If the length of `data` is smaller than the required size.
5151
"""
52-
data = memoryview(data).cast("B")
52+
_buffer_supported_or_raise(data)
5353

54-
if len(data) < num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16):
54+
min_size = num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16)
55+
data_len = len(data)
56+
57+
if data_len < min_size:
5558
raise ValueError(
5659
"data length must be >= num_channels * samples_per_channel * sizeof(int16)"
5760
)
5861

59-
if len(data) % ctypes.sizeof(ctypes.c_int16) != 0:
62+
if data_len % ctypes.sizeof(ctypes.c_int16) != 0:
6063
# can happen if data is bigger than needed
6164
raise ValueError("data length must be a multiple of sizeof(int16)")
6265

63-
n = len(data) // ctypes.sizeof(ctypes.c_int16)
64-
self._data = (ctypes.c_int16 * n).from_buffer_copy(data)
66+
self._data = data
6567

6668
self._sample_rate = sample_rate
6769
self._num_channels = num_channels
@@ -97,7 +99,7 @@ def _from_owned_info(owned_info: proto_audio.OwnedAudioFrameBuffer) -> "AudioFra
9799

98100
def _proto_info(self) -> proto_audio.AudioFrameBufferInfo:
99101
audio_info = proto_audio.AudioFrameBufferInfo()
100-
audio_info.data_ptr = get_address(memoryview(self._data))
102+
audio_info.data_ptr = get_address(self._data)
101103
audio_info.sample_rate = self.sample_rate
102104
audio_info.num_channels = self.num_channels
103105
audio_info.samples_per_channel = self.samples_per_channel

livekit-rtc/livekit/rtc/video_frame.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ._proto import ffi_pb2 as proto
1919
from typing import List, Optional
2020
from ._ffi_client import FfiClient, FfiHandle
21-
from ._utils import get_address
21+
from ._utils import _buffer_supported_or_raise, get_address
2222

2323
from typing import Any
2424

@@ -48,10 +48,12 @@ def __init__(
4848
(e.g., RGBA, BGRA, RGB24, etc.).
4949
data (Union[bytes, bytearray, memoryview]): The raw pixel data for the video frame.
5050
"""
51+
_buffer_supported_or_raise(data)
52+
5153
self._width = width
5254
self._height = height
5355
self._type = type
54-
self._data = bytearray(data)
56+
self._data = data
5557

5658
@property
5759
def width(self) -> int:

0 commit comments

Comments
 (0)