Skip to content

Commit 01f50fc

Browse files
committed
fix: stop video drain when real subscriber arrives
Enable drain by default and clean up blackhole when add_track_subscriber is called, preventing stale drains from competing with actual consumers.
1 parent 9bb07b2 commit 01f50fc

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

getstream/video/rtc/pc.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def __init__(
131131
self,
132132
connection,
133133
configuration: aiortc.RTCConfiguration,
134-
drain_video_frames: bool = False,
134+
drain_video_frames: bool = True,
135135
) -> None:
136136
logger.info(
137137
f"creating subscriber peer connection with configuration: {configuration}"
@@ -208,6 +208,12 @@ def add_track_subscriber(
208208
"""Add a new subscriber to an existing track's MediaRelay."""
209209
track_data = self.track_map.get(track_id)
210210

211+
self._video_drain_tasks.pop(track_id, None)
212+
blackhole = self._video_blackholes.pop(track_id, None)
213+
214+
if blackhole:
215+
asyncio.ensure_future(blackhole.stop())
216+
211217
if track_data:
212218
relay, original_track = track_data
213219
return relay.subscribe(original_track, buffered=False)

tests/rtc/test_subscriber_drain.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Tests for SubscriberPeerConnection video drain behavior."""
2+
3+
from unittest.mock import AsyncMock, Mock
4+
5+
import pytest
6+
from aiortc.contrib.media import MediaRelay
7+
8+
from getstream.video.rtc.pc import SubscriberPeerConnection
9+
10+
11+
@pytest.fixture
12+
def subscriber_pc():
13+
"""Create a SubscriberPeerConnection bypassing heavy parent inits."""
14+
pc = SubscriberPeerConnection.__new__(SubscriberPeerConnection)
15+
pc.connection = Mock()
16+
pc._drain_video_frames = True
17+
pc.track_map = {}
18+
pc.video_frame_trackers = {}
19+
pc._video_blackholes = {}
20+
pc._video_drain_tasks = {}
21+
pc._listeners = {}
22+
return pc
23+
24+
25+
class TestAddTrackSubscriberStopsDrain:
26+
def test_blackhole_stopped_when_subscriber_added(self, subscriber_pc):
27+
track_id = "user123:video:0"
28+
relay = MediaRelay()
29+
original_track = Mock()
30+
subscriber_pc.track_map[track_id] = (relay, original_track)
31+
32+
blackhole = Mock()
33+
blackhole.stop = AsyncMock()
34+
subscriber_pc._video_blackholes[track_id] = blackhole
35+
subscriber_pc._video_drain_tasks[track_id] = Mock()
36+
37+
subscriber_pc.add_track_subscriber(track_id)
38+
39+
blackhole.stop.assert_called_once()
40+
assert track_id not in subscriber_pc._video_blackholes
41+
assert track_id not in subscriber_pc._video_drain_tasks
42+
43+
def test_no_error_when_no_drain_exists(self, subscriber_pc):
44+
track_id = "user123:video:0"
45+
relay = MediaRelay()
46+
original_track = Mock()
47+
subscriber_pc.track_map[track_id] = (relay, original_track)
48+
49+
result = subscriber_pc.add_track_subscriber(track_id)
50+
assert result is not None

0 commit comments

Comments
 (0)