Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions getstream/video/rtc/pc.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ def __init__(

self.track_map = {} # track_id -> (MediaRelay, original_track)
self.video_frame_trackers = {} # track_id -> VideoFrameTracker
self._video_blackholes: dict[str, tuple[MediaBlackhole, asyncio.Task]] = {}
self._video_drains: dict[
str, tuple[MediaBlackhole, asyncio.Task, MediaStreamTrack]
] = {}
self._background_tasks: set[asyncio.Task] = set()

@self.on("track")
Expand All @@ -168,6 +170,15 @@ async def on_track(track: aiortc.mediastreams.MediaStreamTrack):
tracked_track = VideoFrameTracker(track)
self.video_frame_trackers[track.id] = tracked_track

# Drain unconsumed video frames to prevent unbounded queue growth
# in RTCRtpReceiver (aiortc issue #554)
if self._drain_video_frames:
drain_proxy = relay.subscribe(tracked_track)
blackhole = MediaBlackhole()
blackhole.addTrack(drain_proxy)
drain_task = asyncio.create_task(blackhole.start())
self._video_drains[track.id] = (blackhole, drain_task, drain_proxy)

self.track_map[track.id] = (relay, tracked_track)

if track.kind == "audio":
Expand All @@ -183,14 +194,6 @@ def _emit_pcm(pcm: PcmData):

proxy = relay.subscribe(tracked_track)

# Drain unconsumed video frames to prevent unbounded queue growth
# in RTCRtpReceiver (aiortc issue #554)
if track.kind == "video" and self._drain_video_frames:
drain_proxy = relay.subscribe(tracked_track)
blackhole = MediaBlackhole()
blackhole.addTrack(drain_proxy)
drain_task = asyncio.create_task(blackhole.start())
self._video_blackholes[track.id] = (blackhole, drain_task)
self.emit("track_added", proxy, user)

@self.on("icegatheringstatechange")
Expand All @@ -205,11 +208,13 @@ def add_track_subscriber(
"""Add a new subscriber to an existing track's MediaRelay."""
track_data = self.track_map.get(track_id)

blackhole, drain_task = self._video_blackholes.pop(track_id, (None, None))
video_drain = self._video_drains.pop(track_id, None)

if blackhole and drain_task:
if video_drain is not None:
blackhole, drain_task, drain_proxy = video_drain
task = asyncio.create_task(blackhole.stop())
drain_task.cancel() # safety net if start() becomes long-lived in future aiortc
drain_proxy.stop()
drain_task.cancel()
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)

Expand All @@ -226,6 +231,7 @@ def handle_track_ended(self, track: aiortc.mediastreams.MediaStreamTrack) -> Non
del self.track_map[track.id]
if track.id in self.video_frame_trackers:
del self.video_frame_trackers[track.id]
self._video_drains.pop(track.id, None)

def get_video_frame_tracker(self) -> Optional[Any]:
"""Get a video frame tracker for stats collection.
Expand Down
8 changes: 5 additions & 3 deletions tests/rtc/test_subscriber_drain.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def subscriber_pc():
pc._drain_video_frames = True
pc.track_map = {}
pc.video_frame_trackers = {}
pc._video_blackholes = {}
pc._video_drains = {}
pc._background_tasks = set()
pc._listeners = {}
return pc
Expand All @@ -32,12 +32,14 @@ async def test_blackhole_stopped_when_subscriber_added(self, subscriber_pc):

blackhole = Mock()
blackhole.stop = AsyncMock()
subscriber_pc._video_blackholes[track_id] = (blackhole, Mock())
drain_proxy = Mock()
subscriber_pc._video_drains[track_id] = (blackhole, Mock(), drain_proxy)

subscriber_pc.add_track_subscriber(track_id)

blackhole.stop.assert_called_once()
assert track_id not in subscriber_pc._video_blackholes
drain_proxy.stop.assert_called_once()
assert track_id not in subscriber_pc._video_drains

def test_no_error_when_no_drain_exists(self, subscriber_pc):
track_id = "user123:video:0"
Expand Down