Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions backend/tests/unit/test_pusher_ghost_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@
TRANSCRIBE_SRC = Path(__file__).resolve().parents[2] / 'routers' / 'transcribe.py'


def _read_source(path: Path) -> str:
return path.read_text(encoding='utf-8')


def _parse_constant(name: str) -> float:
"""Extract a module-level constant from pusher.py without importing it."""
tree = ast.parse(PUSHER_SRC.read_text())
tree = ast.parse(_read_source(PUSHER_SRC))
for node in ast.iter_child_nodes(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
Expand Down Expand Up @@ -511,7 +515,7 @@ async def long_bg():
@pytest.mark.asyncio
async def test_task_names_assigned(self):
"""Verify tasks get names for production debugging."""
src = PUSHER_SRC.read_text()
src = _read_source(PUSHER_SRC)
assert 'name=f"ws:{' in src, "Tasks must have name= with uid for production debugging"

@pytest.mark.asyncio
Expand Down Expand Up @@ -565,30 +569,30 @@ class TestTranscribeSupervisor:

def test_transcribe_has_finite_tasks_set(self):
"""transcribe.py must define finite_tasks containing only intentionally finite tasks."""
src = TRANSCRIBE_SRC.read_text()
src = _read_source(TRANSCRIBE_SRC)
assert 'finite_tasks' in src, "transcribe.py must define a finite_tasks set"
assert 'pending_conversations_task' in src, "pending_conversations_task must be referenced"
assert 'speaker_id_task' in src, "speaker_id_task must be referenced"

def test_transcribe_lifetime_task_triggers_teardown(self):
"""Lifetime task handling via supervise_tasks utility with finite_tasks set."""
src = TRANSCRIBE_SRC.read_text()
src = _read_source(TRANSCRIBE_SRC)
assert 'finite_task' in src, "Transcribe must define finite tasks for supervisor"
assert 'supervise_tasks' in src, "Transcribe must use supervise_tasks utility"

def test_transcribe_uses_supervisor_utility(self):
src = TRANSCRIBE_SRC.read_text()
src = _read_source(TRANSCRIBE_SRC)
assert 'supervise_tasks' in src, "Transcribe must use supervise_tasks from async_tasks"
assert 'drain_tasks' in src, "Transcribe must use drain_tasks from async_tasks"

def test_transcribe_has_receive_timeout(self):
src = TRANSCRIBE_SRC.read_text()
src = _read_source(TRANSCRIBE_SRC)
assert 'WS_RECEIVE_TIMEOUT' in src

def test_transcribe_gauge_in_try_finally(self):
"""BACKEND_LISTEN_ACTIVE_WS_CONNECTIONS.inc() must be in try body,
.dec() in finally — verified via AST on _stream_handler."""
tree = ast.parse(TRANSCRIBE_SRC.read_text())
tree = ast.parse(_read_source(TRANSCRIBE_SRC))
handler = None
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == '_stream_handler':
Expand Down Expand Up @@ -619,7 +623,7 @@ def test_transcribe_gauge_in_try_finally(self):

def test_transcribe_supervisor_before_drain(self):
"""supervise_tasks() must appear before the main bg drain_tasks(bg_main_tasks) in transcribe.py."""
src = TRANSCRIBE_SRC.read_text()
src = _read_source(TRANSCRIBE_SRC)
supervise_pos = src.find('exit_result = await supervise_tasks(')
drain_pos = src.find('await drain_tasks(bg_main_tasks')
assert supervise_pos != -1, "'supervise_tasks(' call not found in transcribe.py"
Expand All @@ -628,7 +632,7 @@ def test_transcribe_supervisor_before_drain(self):

def test_transcribe_no_gauge_before_try(self):
"""Gauge inc must NOT appear before the main try block to prevent leak on early return."""
src = TRANSCRIBE_SRC.read_text()
src = _read_source(TRANSCRIBE_SRC)
lines = src.split('\n')
in_stream_handler = False
try_line = None
Expand All @@ -651,7 +655,7 @@ def test_transcribe_no_gauge_before_try(self):

def _parse_handler_ast():
"""Parse _websocket_util_trigger and return key AST info about the try/finally structure."""
tree = ast.parse(PUSHER_SRC.read_text())
tree = ast.parse(_read_source(PUSHER_SRC))
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == '_websocket_util_trigger':
return node
Expand Down Expand Up @@ -683,27 +687,27 @@ class TestStructuralIntegrity:
"""Verify the pusher source has the expected patterns."""

def test_source_has_receive_timeout(self):
src = PUSHER_SRC.read_text()
src = _read_source(PUSHER_SRC)
assert 'asyncio.wait_for(websocket.receive_bytes()' in src

def test_source_has_drain_timeout(self):
src = PUSHER_SRC.read_text()
src = _read_source(PUSHER_SRC)
assert 'BG_DRAIN_TIMEOUT' in src

def test_source_uses_supervisor_utility(self):
"""The supervisor uses supervise_tasks() from async_tasks to detect bg crashes."""
src = PUSHER_SRC.read_text()
src = _read_source(PUSHER_SRC)
assert 'supervise_tasks(' in src
assert 'drain_tasks(' in src

def test_source_does_not_gather_all_five_tasks(self):
"""The old pattern gathered all 5 tasks — verify it's gone."""
src = PUSHER_SRC.read_text()
src = _read_source(PUSHER_SRC)
assert 'receive_task,\n' not in src or 'await asyncio.gather(\n receive_task,' not in src

def test_source_has_speaker_shutdown_drain(self):
"""Verify the speaker sample queue skips age check on shutdown."""
src = PUSHER_SRC.read_text()
src = _read_source(PUSHER_SRC)
assert 'is_shutdown' in src


Expand Down Expand Up @@ -742,7 +746,7 @@ def test_gauge_inc_in_try_dec_in_finally(self):

def test_supervisor_before_bg_drain(self):
"""supervise_tasks() must appear before drain_tasks() — supervisor-then-drain ordering."""
src = PUSHER_SRC.read_text()
src = _read_source(PUSHER_SRC)
supervise_pos = src.find('supervise_tasks(')
drain_pos = src.find('drain_tasks(')

Expand All @@ -756,7 +760,7 @@ def test_supervisor_before_bg_drain(self):
def test_receive_task_not_in_gather_with_bg_tasks(self):
"""receive_task must NOT appear in the same gather() as bg_main_tasks.
This was the root cause of ghost connections."""
src = PUSHER_SRC.read_text()
src = _read_source(PUSHER_SRC)
lines = src.split('\n')

for i, line in enumerate(lines):
Expand All @@ -768,7 +772,7 @@ def test_receive_task_not_in_gather_with_bg_tasks(self):

def test_drain_tasks_used_for_bg_cleanup(self):
"""drain_tasks() must be used with BG_DRAIN_TIMEOUT for bg task cleanup."""
src = PUSHER_SRC.read_text()
src = _read_source(PUSHER_SRC)
assert 'drain_tasks(' in src, "drain_tasks must be used for bg task cleanup"
assert 'BG_DRAIN_TIMEOUT' in src, "BG_DRAIN_TIMEOUT must be passed to drain_tasks"

Expand All @@ -788,7 +792,7 @@ def test_finally_drains_remaining_tasks(self):

def test_bg_main_tasks_has_four_tasks(self):
"""bg_main_tasks list literal should have exactly 4 tasks (not 5 — receive is separate)."""
src = PUSHER_SRC.read_text()
src = _read_source(PUSHER_SRC)
lines = src.split('\n')

in_bg_list = False
Expand All @@ -809,7 +813,7 @@ def test_bg_main_tasks_has_four_tasks(self):
def test_is_shutdown_guards_speaker_sample_age_check(self):
"""In process_speaker_sample_queue, is_shutdown must be checked
in the same conditional as SPEAKER_SAMPLE_MIN_AGE."""
src = PUSHER_SRC.read_text()
src = _read_source(PUSHER_SRC)
lines = src.split('\n')

for line in lines:
Expand All @@ -819,5 +823,5 @@ def test_is_shutdown_guards_speaker_sample_age_check(self):

def test_drain_tasks_handles_timeout_logging(self):
"""drain_tasks utility handles timeout logging — verify it's used in pusher."""
src = PUSHER_SRC.read_text()
src = _read_source(PUSHER_SRC)
assert 'drain_tasks(' in src, "drain_tasks must be used in pusher for orderly cleanup"
Loading