diff --git a/getstream/video/rtc/pc.py b/getstream/video/rtc/pc.py index ab7302eb..feab89a7 100644 --- a/getstream/video/rtc/pc.py +++ b/getstream/video/rtc/pc.py @@ -142,7 +142,9 @@ def __init__( self.track_map = {} # track_id -> (MediaRelay, original_track) self.video_frame_trackers = {} # track_id -> VideoFrameTracker - self._video_blackholes: dict[str, tuple[MediaBlackhole, asyncio.Task]] = {} + self._video_drains: dict[ + str, tuple[MediaBlackhole, asyncio.Task, MediaStreamTrack] + ] = {} self._background_tasks: set[asyncio.Task] = set() @self.on("track") @@ -168,6 +170,15 @@ async def on_track(track: aiortc.mediastreams.MediaStreamTrack): tracked_track = VideoFrameTracker(track) self.video_frame_trackers[track.id] = tracked_track + # Drain unconsumed video frames to prevent unbounded queue growth + # in RTCRtpReceiver (aiortc issue #554) + if self._drain_video_frames: + drain_proxy = relay.subscribe(tracked_track) + blackhole = MediaBlackhole() + blackhole.addTrack(drain_proxy) + drain_task = asyncio.create_task(blackhole.start()) + self._video_drains[track.id] = (blackhole, drain_task, drain_proxy) + self.track_map[track.id] = (relay, tracked_track) if track.kind == "audio": @@ -183,14 +194,6 @@ def _emit_pcm(pcm: PcmData): proxy = relay.subscribe(tracked_track) - # Drain unconsumed video frames to prevent unbounded queue growth - # in RTCRtpReceiver (aiortc issue #554) - if track.kind == "video" and self._drain_video_frames: - drain_proxy = relay.subscribe(tracked_track) - blackhole = MediaBlackhole() - blackhole.addTrack(drain_proxy) - drain_task = asyncio.create_task(blackhole.start()) - self._video_blackholes[track.id] = (blackhole, drain_task) self.emit("track_added", proxy, user) @self.on("icegatheringstatechange") @@ -205,11 +208,13 @@ def add_track_subscriber( """Add a new subscriber to an existing track's MediaRelay.""" track_data = self.track_map.get(track_id) - blackhole, drain_task = self._video_blackholes.pop(track_id, (None, None)) + video_drain = self._video_drains.pop(track_id, None) - if blackhole and drain_task: + if video_drain is not None: + blackhole, drain_task, drain_proxy = video_drain task = asyncio.create_task(blackhole.stop()) - drain_task.cancel() # safety net if start() becomes long-lived in future aiortc + drain_proxy.stop() + drain_task.cancel() self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) @@ -226,6 +231,7 @@ def handle_track_ended(self, track: aiortc.mediastreams.MediaStreamTrack) -> Non del self.track_map[track.id] if track.id in self.video_frame_trackers: del self.video_frame_trackers[track.id] + self._video_drains.pop(track.id, None) def get_video_frame_tracker(self) -> Optional[Any]: """Get a video frame tracker for stats collection. diff --git a/tests/rtc/test_subscriber_drain.py b/tests/rtc/test_subscriber_drain.py index 08b86f51..1d00b688 100644 --- a/tests/rtc/test_subscriber_drain.py +++ b/tests/rtc/test_subscriber_drain.py @@ -16,7 +16,7 @@ def subscriber_pc(): pc._drain_video_frames = True pc.track_map = {} pc.video_frame_trackers = {} - pc._video_blackholes = {} + pc._video_drains = {} pc._background_tasks = set() pc._listeners = {} return pc @@ -32,12 +32,14 @@ async def test_blackhole_stopped_when_subscriber_added(self, subscriber_pc): blackhole = Mock() blackhole.stop = AsyncMock() - subscriber_pc._video_blackholes[track_id] = (blackhole, Mock()) + drain_proxy = Mock() + subscriber_pc._video_drains[track_id] = (blackhole, Mock(), drain_proxy) subscriber_pc.add_track_subscriber(track_id) blackhole.stop.assert_called_once() - assert track_id not in subscriber_pc._video_blackholes + drain_proxy.stop.assert_called_once() + assert track_id not in subscriber_pc._video_drains def test_no_error_when_no_drain_exists(self, subscriber_pc): track_id = "user123:video:0"