Skip to content

Commit d1ca6fb

Browse files
committed
refactor: merge blackhole and drain task into single dict
Two separate dicts (_video_blackholes, _video_drain_tasks) tracked the same lifecycle, risking desync. Store them as a tuple in one dict.
1 parent feb2d55 commit d1ca6fb

File tree

2 files changed

+5
-13
lines changed

2 files changed

+5
-13
lines changed

getstream/video/rtc/pc.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,7 @@ def __init__(
142142

143143
self.track_map = {} # track_id -> (MediaRelay, original_track)
144144
self.video_frame_trackers = {} # track_id -> VideoFrameTracker
145-
self._video_blackholes: dict[str, MediaBlackhole] = {}
146-
self._video_drain_tasks: dict[str, asyncio.Task] = {}
145+
self._video_blackholes: dict[str, tuple[MediaBlackhole, asyncio.Task]] = {}
147146
self._background_tasks: set[asyncio.Task] = set()
148147

149148
@self.on("track")
@@ -190,11 +189,8 @@ def _emit_pcm(pcm: PcmData):
190189
drain_proxy = relay.subscribe(tracked_track)
191190
blackhole = MediaBlackhole()
192191
blackhole.addTrack(drain_proxy)
193-
self._video_blackholes[track.id] = blackhole
194-
self._video_drain_tasks[track.id] = asyncio.create_task(
195-
blackhole.start()
196-
)
197-
192+
drain_task = asyncio.create_task(blackhole.start())
193+
self._video_blackholes[track.id] = (blackhole, drain_task)
198194
self.emit("track_added", proxy, user)
199195

200196
@self.on("icegatheringstatechange")
@@ -209,8 +205,7 @@ def add_track_subscriber(
209205
"""Add a new subscriber to an existing track's MediaRelay."""
210206
track_data = self.track_map.get(track_id)
211207

212-
self._video_drain_tasks.pop(track_id, None)
213-
blackhole = self._video_blackholes.pop(track_id, None)
208+
blackhole, _ = self._video_blackholes.pop(track_id, (None, None))
214209

215210
if blackhole:
216211
task = asyncio.create_task(blackhole.stop())

tests/rtc/test_subscriber_drain.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def subscriber_pc():
1717
pc.track_map = {}
1818
pc.video_frame_trackers = {}
1919
pc._video_blackholes = {}
20-
pc._video_drain_tasks = {}
2120
pc._background_tasks = set()
2221
pc._listeners = {}
2322
return pc
@@ -33,14 +32,12 @@ async def test_blackhole_stopped_when_subscriber_added(self, subscriber_pc):
3332

3433
blackhole = Mock()
3534
blackhole.stop = AsyncMock()
36-
subscriber_pc._video_blackholes[track_id] = blackhole
37-
subscriber_pc._video_drain_tasks[track_id] = Mock()
35+
subscriber_pc._video_blackholes[track_id] = (blackhole, Mock())
3836

3937
subscriber_pc.add_track_subscriber(track_id)
4038

4139
blackhole.stop.assert_called_once()
4240
assert track_id not in subscriber_pc._video_blackholes
43-
assert track_id not in subscriber_pc._video_drain_tasks
4441

4542
def test_no_error_when_no_drain_exists(self, subscriber_pc):
4643
track_id = "user123:video:0"

0 commit comments

Comments
 (0)