diff --git a/backend/tests/unit/test_dg_start_guard.py b/backend/tests/unit/test_dg_start_guard.py index d993abef85..202594a013 100644 --- a/backend/tests/unit/test_dg_start_guard.py +++ b/backend/tests/unit/test_dg_start_guard.py @@ -7,7 +7,7 @@ import os import sys from types import ModuleType -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -49,6 +49,12 @@ # MagicMock auto-generates attributes on access, and overwriting would pollute # shared pytest state for test_streaming_deepgram_backoff.py's close/error handler tests. +_speaker_embedding = ModuleType('utils.stt.speaker_embedding') +_speaker_embedding.SPEAKER_MATCH_THRESHOLD = 0.45 +_speaker_embedding.async_extract_embedding_from_bytes = AsyncMock(return_value=None) +_speaker_embedding.compare_embeddings = MagicMock(return_value=0.0) +sys.modules.setdefault('utils.stt.speaker_embedding', _speaker_embedding) + # Now import the real streaming module from utils.stt.streaming import connect_to_deepgram diff --git a/backend/tests/unit/test_parakeet_diarization.py b/backend/tests/unit/test_parakeet_diarization.py index bd34a9814b..0611ff09df 100644 --- a/backend/tests/unit/test_parakeet_diarization.py +++ b/backend/tests/unit/test_parakeet_diarization.py @@ -7,14 +7,56 @@ import asyncio import os +import sys +from types import ModuleType +from unittest.mock import AsyncMock, MagicMock import numpy as np os.environ.setdefault('HOSTED_SPEAKER_EMBEDDING_API_URL', 'http://fake') # enables _diarize os.environ.setdefault('DEEPGRAM_API_KEY', 'x') +_owned_modules = set() +for _mod_name in [ + 'deepgram', + 'deepgram.clients', + 'deepgram.clients.live', + 'deepgram.clients.live.v1', + 'websockets', + 'websockets.exceptions', +]: + if _mod_name not in sys.modules: + sys.modules[_mod_name] = MagicMock() + _owned_modules.add(_mod_name) + +if 'deepgram' in _owned_modules: + sys.modules['deepgram'].DeepgramClient = MagicMock + sys.modules['deepgram'].DeepgramClientOptions = MagicMock + sys.modules['deepgram'].LiveTranscriptionEvents = MagicMock() +if 'deepgram.clients.live.v1' in _owned_modules: + sys.modules['deepgram.clients.live.v1'].LiveOptions = MagicMock + +_speaker_embedding = ModuleType('utils.stt.speaker_embedding') +_speaker_embedding.SPEAKER_MATCH_THRESHOLD = 0.45 +_speaker_embedding.async_extract_embedding_from_bytes = AsyncMock(return_value=None) + + +def _cosine_distance(a, b): + a = np.asarray(a, dtype=np.float32) + b = np.asarray(b, dtype=np.float32) + denom = np.linalg.norm(a) * np.linalg.norm(b) + if denom == 0: + return 1.0 + return float(1.0 - np.sum(a * b) / denom) + + +_speaker_embedding.compare_embeddings = _cosine_distance +sys.modules.setdefault('utils.stt.speaker_embedding', _speaker_embedding) + import utils.stt.streaming as st # noqa: E402 +st.compare_embeddings = _cosine_distance + def _dir_vec(idx: int, rng) -> np.ndarray: """A unit direction in dim `idx` plus small within-speaker noise -> (1, 256).""" diff --git a/backend/tests/unit/test_parakeet_stream_session.py b/backend/tests/unit/test_parakeet_stream_session.py index 90f6589f9b..68d8db0684 100644 --- a/backend/tests/unit/test_parakeet_stream_session.py +++ b/backend/tests/unit/test_parakeet_stream_session.py @@ -29,6 +29,50 @@ mock_transcribe.INFERENCE_MODE = "nemo" sys.modules['transcribe'] = mock_transcribe +_scipy = types.ModuleType('scipy') +_scipy_spatial = types.ModuleType('scipy.spatial') +_scipy_distance = types.ModuleType('scipy.spatial.distance') + + +def _cosine_cdist(a, b, metric="cosine"): + if metric != "cosine": + raise ValueError(f"unsupported metric: {metric}") + a = np.asarray(a, dtype=np.float32) + b = np.asarray(b, dtype=np.float32) + a_norm = np.linalg.norm(a, axis=1, keepdims=True) + b_norm = np.linalg.norm(b, axis=1, keepdims=True).T + denom = a_norm * b_norm + similarity = np.divide(a @ b.T, denom, out=np.zeros((a.shape[0], b.shape[0]), dtype=np.float32), where=denom != 0) + return 1.0 - similarity + + +_scipy_distance.cdist = _cosine_cdist +_scipy_spatial.distance = _scipy_distance +_scipy.spatial = _scipy_spatial +sys.modules.setdefault('scipy', _scipy) +sys.modules.setdefault('scipy.spatial', _scipy_spatial) +sys.modules.setdefault('scipy.spatial.distance', _scipy_distance) + +_torch = types.ModuleType('torch') +_torch.int16 = np.int16 + + +class _TorchArray: + def __init__(self, value): + self.value = np.asarray(value) + + def float(self): + return self + + def __truediv__(self, value): + return _TorchArray(self.value / value) + + +_torch.frombuffer = lambda buffer, dtype: _TorchArray(np.frombuffer(buffer, dtype=dtype)) +_torch.hub = MagicMock() +_torch.hub.load.side_effect = RuntimeError("torch hub unavailable in unit tests") +sys.modules.setdefault('torch', _torch) + import stream_handler as sh diff --git a/backend/tests/unit/test_streaming_deepgram_backoff.py b/backend/tests/unit/test_streaming_deepgram_backoff.py index 630149ee8c..bf76d5f2e8 100644 --- a/backend/tests/unit/test_streaming_deepgram_backoff.py +++ b/backend/tests/unit/test_streaming_deepgram_backoff.py @@ -8,6 +8,7 @@ import asyncio import sys +from types import ModuleType from unittest.mock import MagicMock, patch, AsyncMock import pytest @@ -39,6 +40,19 @@ sys.modules['deepgram'].LiveTranscriptionEvents = MagicMock() sys.modules['deepgram.clients.live.v1'].LiveOptions = MagicMock +_speaker_embedding = ModuleType('utils.stt.speaker_embedding') +_speaker_embedding.SPEAKER_MATCH_THRESHOLD = 0.45 +_speaker_embedding.async_extract_embedding_from_bytes = AsyncMock(return_value=None) +_speaker_embedding.compare_embeddings = MagicMock(return_value=0.0) +sys.modules.setdefault('utils.stt.speaker_embedding', _speaker_embedding) + +_vad = ModuleType('utils.stt.vad') +_vad._get_ort_session = MagicMock() +_vad.make_fresh_state = MagicMock(return_value=(None, None)) +_vad.run_vad_window = MagicMock(return_value=0.0) +_vad.VAD_WINDOW_SAMPLES = 512 +sys.modules.setdefault('utils.stt.vad', _vad) + from utils.stt.streaming import connect_to_deepgram_with_backoff, process_audio_dg # noqa: E402 from utils.stt.streaming import deepgram_options, deepgram_cloud_options # noqa: E402 from utils.stt.streaming import get_stt_service_for_language, STTService, should_preserve_filler_words # noqa: E402