Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 57 additions & 14 deletions backend/tests/unit/test_desktop_transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@
sys.modules.setdefault('firebase_admin.messaging', _fb.messaging)
sys.modules.setdefault('firebase_admin.auth', _fb.auth)

_deepgram = ModuleType('deepgram')
_deepgram.DeepgramClient = MagicMock
_deepgram.DeepgramClientOptions = MagicMock
sys.modules.setdefault('deepgram', _deepgram)

_speaker_embedding = ModuleType('utils.stt.speaker_embedding')
_speaker_embedding.SPEAKER_MATCH_THRESHOLD = 0.45
_speaker_embedding.compare_embeddings = MagicMock(return_value=0.0)
_speaker_embedding.extract_embedding_from_bytes = MagicMock()
sys.modules['utils.stt.speaker_embedding'] = _speaker_embedding

import google.cloud.storage as _gcs

_gcs.Client = MagicMock
Expand All @@ -103,6 +114,12 @@
os.environ.setdefault('DEEPGRAM_API_KEY', 'fake-for-test')
os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv')


@pytest.fixture(autouse=True)
def _ensure_tmp_dir():
os.makedirs('/tmp', exist_ok=True)


# Stub transitive imports for utils.chat (avoid pulling in all of utils.llm etc.)
# Do NOT stub utils.other.endpoints — it contains the @timeit decorator that must
# be a real function (not MagicMock) or it corrupts decorated function signatures.
Expand All @@ -113,7 +130,6 @@
'utils.llm.chat',
'utils.llm.goals',
'utils.llm.usage_tracker',
'utils.conversations',
'utils.conversations.process_conversation',
'utils.notifications',
'utils.other.storage',
Expand All @@ -131,6 +147,15 @@
]:
sys.modules.setdefault(_ufull, MagicMock())

_utils_conversations_pkg = ModuleType('utils.conversations')
_utils_conversations_pkg.__path__ = []
_utils_conversations_pkg.__package__ = 'utils.conversations'
_utils_conversations_factory = ModuleType('utils.conversations.factory')
_utils_conversations_factory.deserialize_conversation = MagicMock(side_effect=lambda conversation: conversation)
sys.modules['utils.conversations'] = _utils_conversations_pkg
sys.modules['utils.conversations.factory'] = _utils_conversations_factory
setattr(_utils_conversations_pkg, 'factory', _utils_conversations_factory)

# Force-import real models.chat (has no project deps, needed for FastAPI response_model)
import importlib.util as _ilu

Expand Down Expand Up @@ -301,7 +326,7 @@ class TestTranscribePcmBytes:
"""Verify transcribe_pcm_bytes passes language/model and propagates errors."""

@patch('utils.chat.postprocess_words')
@patch('utils.chat.deepgram_prerecorded_from_bytes')
@patch('utils.chat.prerecorded_from_bytes')
@patch('utils.chat.get_deepgram_model_for_language')
def test_language_model_forwarded(self, mock_get_model, mock_dg, mock_postprocess):
"""stt_language and stt_model should be passed to deepgram_prerecorded_from_bytes."""
Expand All @@ -322,7 +347,7 @@ def test_language_model_forwarded(self, mock_get_model, mock_dg, mock_postproces
assert call_kwargs['encoding'] == 'linear16'
assert text == 'Hola'

@patch('utils.chat.deepgram_prerecorded_from_bytes')
@patch('utils.chat.prerecorded_from_bytes')
@patch('utils.chat.get_deepgram_model_for_language')
def test_runtime_error_propagates(self, mock_get_model, mock_dg):
"""RuntimeError from Deepgram should propagate (not be caught)."""
Expand All @@ -334,7 +359,7 @@ def test_runtime_error_propagates(self, mock_get_model, mock_dg):
with pytest.raises(RuntimeError, match='Deepgram failed'):
transcribe_pcm_bytes(b'\x00' * 100, 'test-uid')

@patch('utils.chat.deepgram_prerecorded_from_bytes')
@patch('utils.chat.prerecorded_from_bytes')
@patch('utils.chat.get_deepgram_model_for_language')
def test_empty_words_returns_none(self, mock_get_model, mock_dg):
"""Empty word list should return (None, language)."""
Expand All @@ -348,7 +373,7 @@ def test_empty_words_returns_none(self, mock_get_model, mock_dg):
assert lang == 'en'

@patch('utils.chat.postprocess_words')
@patch('utils.chat.deepgram_prerecorded_from_bytes')
@patch('utils.chat.prerecorded_from_bytes')
@patch('utils.chat.get_deepgram_model_for_language')
def test_multi_language_returns_detected_language(self, mock_get_model, mock_dg, mock_postprocess):
"""Multi-language mode should return the Deepgram-detected language, not hardcoded 'en'."""
Expand All @@ -370,7 +395,7 @@ def test_multi_language_returns_detected_language(self, mock_get_model, mock_dg,
assert call_kwargs['return_language'] is True

@patch('utils.chat.postprocess_words')
@patch('utils.chat.deepgram_prerecorded_from_bytes')
@patch('utils.chat.prerecorded_from_bytes')
@patch('utils.chat.get_deepgram_model_for_language')
def test_chinese_language_uses_nova3(self, mock_get_model, mock_dg, mock_postprocess):
"""Chinese should use nova-3 model."""
Expand All @@ -389,7 +414,7 @@ def test_chinese_language_uses_nova3(self, mock_get_model, mock_dg, mock_postpro
assert call_kwargs['language'] == 'zh'

@patch('utils.chat.postprocess_words')
@patch('utils.chat.deepgram_prerecorded_from_bytes')
@patch('utils.chat.prerecorded_from_bytes')
@patch('utils.chat.get_deepgram_model_for_language')
def test_whitespace_only_transcript_returns_none(self, mock_get_model, mock_dg, mock_postprocess):
"""Whitespace-only transcript after postprocessing should return (None, language)."""
Expand All @@ -405,7 +430,7 @@ def test_whitespace_only_transcript_returns_none(self, mock_get_model, mock_dg,
assert text is None
assert lang == 'en'

@patch('utils.chat.deepgram_prerecorded_from_bytes')
@patch('utils.chat.prerecorded_from_bytes')
@patch('utils.chat.get_deepgram_model_for_language')
def test_postprocess_empty_returns_none(self, mock_get_model, mock_dg):
"""postprocess_words returning empty list should return (None, language)."""
Expand All @@ -430,14 +455,14 @@ class TestDeepgramPrerecordedFromBytesEdgeCases:

@patch('utils.stt.pre_recorded._deepgram_client')
def test_retry_raises_after_max_attempts(self, mock_client):
"""After 3 failed attempts, should raise RuntimeError."""
"""After the configured retry is exhausted, should raise RuntimeError."""
mock_client.listen.rest.v.return_value.transcribe_file.side_effect = Exception('connection timeout')

with pytest.raises(RuntimeError, match='Deepgram transcription failed after 3 attempts'):
with pytest.raises(RuntimeError, match='Deepgram transcription failed after 2 attempts'):
deepgram_prerecorded_from_bytes(b'\x00' * 100, encoding='linear16')

# Should have been called 3 times (attempts 0, 1, 2)
assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 3
# Should have been called twice (initial attempt + one retry)
assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 2

@patch('utils.stt.pre_recorded._deepgram_client')
def test_return_language_empty_words_returns_detected_lang(self, mock_client):
Expand All @@ -462,10 +487,10 @@ def test_no_channels_raises_and_retries(self, mock_client):
mock_response.to_dict.return_value = {'results': {'channels': []}}
mock_client.listen.rest.v.return_value.transcribe_file.return_value = mock_response

with pytest.raises(RuntimeError, match='Deepgram transcription failed after 3 attempts'):
with pytest.raises(RuntimeError, match='Deepgram transcription failed after 2 attempts'):
deepgram_prerecorded_from_bytes(b'\x00' * 100)

assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 3
assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 2


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -505,6 +530,24 @@ def _stub_router_deps():
]
for mod in extra_models + extra_database + extra_utils:
sys.modules.setdefault(mod, MagicMock())
opuslib_stub = ModuleType('opuslib')
opuslib_stub.Decoder = MagicMock()
sys.modules['opuslib'] = opuslib_stub
pydub_stub = ModuleType('pydub')
pydub_stub.AudioSegment = MagicMock()
sys.modules['pydub'] = pydub_stub
limiter_stub = ModuleType('utils.voice_duration_limiter')
limiter_stub.compute_pcm_duration_ms = lambda byte_count, sample_rate, channels: int(
byte_count / (sample_rate * channels * 2) * 1000
)
limiter_stub.read_wav_duration_ms = MagicMock(return_value=1000)
limiter_stub.try_consume_budget = MagicMock(return_value=(True, 0, 7200000))
limiter_stub.check_budget = MagicMock(return_value=(True, 0, 7200000))
limiter_stub.record_actual_duration = MagicMock()
sys.modules['utils.voice_duration_limiter'] = limiter_stub
subscription_stub = sys.modules.setdefault('utils.subscription', MagicMock())
subscription_stub.enforce_chat_quota = MagicMock()
subscription_stub.is_trial_paywalled = MagicMock(return_value=False)
# Ensure redis_db.check_rate_limit returns (True, 99, 0)
rdb = sys.modules.get('database.redis_db')
if rdb:
Expand Down
Loading