|
| 1 | +"""Tests for SubscriberPeerConnection video drain behavior.""" |
| 2 | + |
| 3 | +from unittest.mock import AsyncMock, Mock |
| 4 | + |
| 5 | +import pytest |
| 6 | +from aiortc.contrib.media import MediaRelay |
| 7 | + |
| 8 | +from getstream.video.rtc.pc import SubscriberPeerConnection |
| 9 | + |
| 10 | + |
| 11 | +@pytest.fixture |
| 12 | +def subscriber_pc(): |
| 13 | + """Create a SubscriberPeerConnection bypassing heavy parent inits.""" |
| 14 | + pc = SubscriberPeerConnection.__new__(SubscriberPeerConnection) |
| 15 | + pc.connection = Mock() |
| 16 | + pc._drain_video_frames = True |
| 17 | + pc.track_map = {} |
| 18 | + pc.video_frame_trackers = {} |
| 19 | + pc._video_blackholes = {} |
| 20 | + pc._video_drain_tasks = {} |
| 21 | + pc._listeners = {} |
| 22 | + return pc |
| 23 | + |
| 24 | + |
| 25 | +class TestAddTrackSubscriberStopsDrain: |
| 26 | + def test_blackhole_stopped_when_subscriber_added(self, subscriber_pc): |
| 27 | + track_id = "user123:video:0" |
| 28 | + relay = MediaRelay() |
| 29 | + original_track = Mock() |
| 30 | + subscriber_pc.track_map[track_id] = (relay, original_track) |
| 31 | + |
| 32 | + blackhole = Mock() |
| 33 | + blackhole.stop = AsyncMock() |
| 34 | + subscriber_pc._video_blackholes[track_id] = blackhole |
| 35 | + subscriber_pc._video_drain_tasks[track_id] = Mock() |
| 36 | + |
| 37 | + subscriber_pc.add_track_subscriber(track_id) |
| 38 | + |
| 39 | + blackhole.stop.assert_called_once() |
| 40 | + assert track_id not in subscriber_pc._video_blackholes |
| 41 | + assert track_id not in subscriber_pc._video_drain_tasks |
| 42 | + |
| 43 | + def test_no_error_when_no_drain_exists(self, subscriber_pc): |
| 44 | + track_id = "user123:video:0" |
| 45 | + relay = MediaRelay() |
| 46 | + original_track = Mock() |
| 47 | + subscriber_pc.track_map[track_id] = (relay, original_track) |
| 48 | + |
| 49 | + result = subscriber_pc.add_track_subscriber(track_id) |
| 50 | + assert result is not None |
0 commit comments