Skip to content

Commit 4e2e147

Browse files
authored
Merge pull request #232 from GetStream/fix/stop-drain-on-subscribe
fix: stop video drain when real subscriber arrives
2 parents 97bb477 + ba06494 commit 4e2e147

File tree

4 files changed

+73
-10
lines changed

4 files changed

+73
-10
lines changed

getstream/video/rtc/connection_manager.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,18 @@ def __init__(
5858
create: bool = True,
5959
subscription_config: Optional[SubscriptionConfig] = None,
6060
max_join_retries: int = 3,
61-
drain_video_frames: bool = False,
61+
drain_video_frames: bool = True,
6262
**kwargs: Any,
6363
):
64+
"""
65+
Args:
66+
drain_video_frames: When True, attaches a MediaBlackhole to each
67+
incoming video track so unconsumed frames are drained
68+
automatically. This prevents unbounded queue growth in
69+
RTCRtpReceiver when no subscriber is consuming the track.
70+
The drain is stopped once a real subscriber is added via
71+
add_track_subscriber.
72+
"""
6473
super().__init__()
6574

6675
# Public attributes

getstream/video/rtc/pc.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def __init__(
131131
self,
132132
connection,
133133
configuration: aiortc.RTCConfiguration,
134-
drain_video_frames: bool = False,
134+
drain_video_frames: bool = True,
135135
) -> None:
136136
logger.info(
137137
f"creating subscriber peer connection with configuration: {configuration}"
@@ -142,8 +142,8 @@ def __init__(
142142

143143
self.track_map = {} # track_id -> (MediaRelay, original_track)
144144
self.video_frame_trackers = {} # track_id -> VideoFrameTracker
145-
self._video_blackholes: dict[str, MediaBlackhole] = {}
146-
self._video_drain_tasks: dict[str, asyncio.Task] = {}
145+
self._video_blackholes: dict[str, tuple[MediaBlackhole, asyncio.Task]] = {}
146+
self._background_tasks: set[asyncio.Task] = set()
147147

148148
@self.on("track")
149149
async def on_track(track: aiortc.mediastreams.MediaStreamTrack):
@@ -189,11 +189,8 @@ def _emit_pcm(pcm: PcmData):
189189
drain_proxy = relay.subscribe(tracked_track)
190190
blackhole = MediaBlackhole()
191191
blackhole.addTrack(drain_proxy)
192-
self._video_blackholes[track.id] = blackhole
193-
self._video_drain_tasks[track.id] = asyncio.create_task(
194-
blackhole.start()
195-
)
196-
192+
drain_task = asyncio.create_task(blackhole.start())
193+
self._video_blackholes[track.id] = (blackhole, drain_task)
197194
self.emit("track_added", proxy, user)
198195

199196
@self.on("icegatheringstatechange")
@@ -208,6 +205,14 @@ def add_track_subscriber(
208205
"""Add a new subscriber to an existing track's MediaRelay."""
209206
track_data = self.track_map.get(track_id)
210207

208+
blackhole, drain_task = self._video_blackholes.pop(track_id, (None, None))
209+
210+
if blackhole and drain_task:
211+
task = asyncio.create_task(blackhole.stop())
212+
drain_task.cancel() # safety net if start() becomes long-lived in future aiortc
213+
self._background_tasks.add(task)
214+
task.add_done_callback(self._background_tasks.discard)
215+
211216
if track_data:
212217
relay, original_track = track_data
213218
return relay.subscribe(original_track, buffered=False)

getstream/video/rtc/peer_connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
class PeerConnectionManager:
2929
"""Manages WebRTC peer connections for publishing and subscribing."""
3030

31-
def __init__(self, connection_manager, drain_video_frames: bool = False):
31+
def __init__(self, connection_manager, drain_video_frames: bool = True):
3232
self.connection_manager = connection_manager
3333
self._drain_video_frames = drain_video_frames
3434
self.publisher_pc: Optional[PublisherPeerConnection] = None

tests/rtc/test_subscriber_drain.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Tests for SubscriberPeerConnection video drain behavior."""
2+
3+
from unittest.mock import AsyncMock, Mock
4+
5+
import pytest
6+
from aiortc.contrib.media import MediaRelay
7+
8+
from getstream.video.rtc.pc import SubscriberPeerConnection
9+
10+
11+
@pytest.fixture
12+
def subscriber_pc():
13+
"""Create a SubscriberPeerConnection bypassing heavy parent inits."""
14+
pc = SubscriberPeerConnection.__new__(SubscriberPeerConnection)
15+
pc.connection = Mock()
16+
pc._drain_video_frames = True
17+
pc.track_map = {}
18+
pc.video_frame_trackers = {}
19+
pc._video_blackholes = {}
20+
pc._background_tasks = set()
21+
pc._listeners = {}
22+
return pc
23+
24+
25+
class TestAddTrackSubscriberStopsDrain:
26+
@pytest.mark.asyncio
27+
async def test_blackhole_stopped_when_subscriber_added(self, subscriber_pc):
28+
track_id = "user123:video:0"
29+
relay = MediaRelay()
30+
original_track = Mock()
31+
subscriber_pc.track_map[track_id] = (relay, original_track)
32+
33+
blackhole = Mock()
34+
blackhole.stop = AsyncMock()
35+
subscriber_pc._video_blackholes[track_id] = (blackhole, Mock())
36+
37+
subscriber_pc.add_track_subscriber(track_id)
38+
39+
blackhole.stop.assert_called_once()
40+
assert track_id not in subscriber_pc._video_blackholes
41+
42+
def test_no_error_when_no_drain_exists(self, subscriber_pc):
43+
track_id = "user123:video:0"
44+
relay = MediaRelay()
45+
original_track = Mock()
46+
subscriber_pc.track_map[track_id] = (relay, original_track)
47+
48+
result = subscriber_pc.add_track_subscriber(track_id)
49+
assert result is not None

0 commit comments

Comments
 (0)