Skip to content

Commit f3972a6

Browse files
authored
Merge pull request #233 from GetStream/fix/drain-proxy-cleanup
Stop drain_proxy on subscriber arrival to prevent relay queue leak
2 parents 4e2e147 + d7f50c6 commit f3972a6

2 files changed

Lines changed: 23 additions & 15 deletions

File tree

getstream/video/rtc/pc.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,9 @@ def __init__(
142142

143143
self.track_map = {} # track_id -> (MediaRelay, original_track)
144144
self.video_frame_trackers = {} # track_id -> VideoFrameTracker
145-
self._video_blackholes: dict[str, tuple[MediaBlackhole, asyncio.Task]] = {}
145+
self._video_drains: dict[
146+
str, tuple[MediaBlackhole, asyncio.Task, MediaStreamTrack]
147+
] = {}
146148
self._background_tasks: set[asyncio.Task] = set()
147149

148150
@self.on("track")
@@ -168,6 +170,15 @@ async def on_track(track: aiortc.mediastreams.MediaStreamTrack):
168170
tracked_track = VideoFrameTracker(track)
169171
self.video_frame_trackers[track.id] = tracked_track
170172

173+
# Drain unconsumed video frames to prevent unbounded queue growth
174+
# in RTCRtpReceiver (aiortc issue #554)
175+
if self._drain_video_frames:
176+
drain_proxy = relay.subscribe(tracked_track)
177+
blackhole = MediaBlackhole()
178+
blackhole.addTrack(drain_proxy)
179+
drain_task = asyncio.create_task(blackhole.start())
180+
self._video_drains[track.id] = (blackhole, drain_task, drain_proxy)
181+
171182
self.track_map[track.id] = (relay, tracked_track)
172183

173184
if track.kind == "audio":
@@ -183,14 +194,6 @@ def _emit_pcm(pcm: PcmData):
183194

184195
proxy = relay.subscribe(tracked_track)
185196

186-
# Drain unconsumed video frames to prevent unbounded queue growth
187-
# in RTCRtpReceiver (aiortc issue #554)
188-
if track.kind == "video" and self._drain_video_frames:
189-
drain_proxy = relay.subscribe(tracked_track)
190-
blackhole = MediaBlackhole()
191-
blackhole.addTrack(drain_proxy)
192-
drain_task = asyncio.create_task(blackhole.start())
193-
self._video_blackholes[track.id] = (blackhole, drain_task)
194197
self.emit("track_added", proxy, user)
195198

196199
@self.on("icegatheringstatechange")
@@ -205,11 +208,13 @@ def add_track_subscriber(
205208
"""Add a new subscriber to an existing track's MediaRelay."""
206209
track_data = self.track_map.get(track_id)
207210

208-
blackhole, drain_task = self._video_blackholes.pop(track_id, (None, None))
211+
video_drain = self._video_drains.pop(track_id, None)
209212

210-
if blackhole and drain_task:
213+
if video_drain is not None:
214+
blackhole, drain_task, drain_proxy = video_drain
211215
task = asyncio.create_task(blackhole.stop())
212-
drain_task.cancel() # safety net if start() becomes long-lived in future aiortc
216+
drain_proxy.stop()
217+
drain_task.cancel()
213218
self._background_tasks.add(task)
214219
task.add_done_callback(self._background_tasks.discard)
215220

@@ -226,6 +231,7 @@ def handle_track_ended(self, track: aiortc.mediastreams.MediaStreamTrack) -> Non
226231
del self.track_map[track.id]
227232
if track.id in self.video_frame_trackers:
228233
del self.video_frame_trackers[track.id]
234+
self._video_drains.pop(track.id, None)
229235

230236
def get_video_frame_tracker(self) -> Optional[Any]:
231237
"""Get a video frame tracker for stats collection.

tests/rtc/test_subscriber_drain.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def subscriber_pc():
1616
pc._drain_video_frames = True
1717
pc.track_map = {}
1818
pc.video_frame_trackers = {}
19-
pc._video_blackholes = {}
19+
pc._video_drains = {}
2020
pc._background_tasks = set()
2121
pc._listeners = {}
2222
return pc
@@ -32,12 +32,14 @@ async def test_blackhole_stopped_when_subscriber_added(self, subscriber_pc):
3232

3333
blackhole = Mock()
3434
blackhole.stop = AsyncMock()
35-
subscriber_pc._video_blackholes[track_id] = (blackhole, Mock())
35+
drain_proxy = Mock()
36+
subscriber_pc._video_drains[track_id] = (blackhole, Mock(), drain_proxy)
3637

3738
subscriber_pc.add_track_subscriber(track_id)
3839

3940
blackhole.stop.assert_called_once()
40-
assert track_id not in subscriber_pc._video_blackholes
41+
drain_proxy.stop.assert_called_once()
42+
assert track_id not in subscriber_pc._video_drains
4143

4244
def test_no_error_when_no_drain_exists(self, subscriber_pc):
4345
track_id = "user123:video:0"

0 commit comments

Comments
 (0)