Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 95 additions & 35 deletions backend/tests/unit/test_sync_transcription_prefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
"""

import os
import re
import sys
import threading
from types import ModuleType
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -75,6 +77,57 @@
sys.modules.setdefault('firebase_admin.messaging', _fb.messaging)
sys.modules.setdefault('firebase_admin.auth', _fb.auth)

# Stub import-time SDK/native audio dependencies that are not needed for these
# unit tests and can be unavailable on Windows developer machines.
# `setdefault` returns whatever is already registered under 'deepgram' when an
# entry exists, so the two assignments below land on that module; otherwise
# they land on the fresh, empty stub module created here.
_deepgram = sys.modules.setdefault('deepgram', ModuleType('deepgram'))
# The MagicMock *class* (not an instance) is assigned, so calling
# `DeepgramClient(...)` produces a new MagicMock instance.
_deepgram.DeepgramClient = MagicMock
_deepgram.DeepgramClientOptions = MagicMock

# The stub is only registered when nothing is registered under 'opuslib' yet;
# an already-registered module is left untouched.
_opuslib = ModuleType('opuslib')
_opuslib.Decoder = MagicMock
sys.modules.setdefault('opuslib', _opuslib)

# Same pattern as opuslib: register the stub only if 'pydub' is not registered.
_pydub = ModuleType('pydub')
_pydub.AudioSegment = MagicMock
sys.modules.setdefault('pydub', _pydub)

# Direct assignment (not setdefault): this replaces any entry already
# registered under this dotted name with the stub module.
_process_conversation = ModuleType('utils.conversations.process_conversation')
_process_conversation.process_conversation = MagicMock()
sys.modules['utils.conversations.process_conversation'] = _process_conversation

# Unconditionally registered as well; the stubbed `vad_is_empty` returns False
# for every call.
_vad = ModuleType('utils.stt.vad')
_vad.vad_is_empty = MagicMock(return_value=False)
sys.modules['utils.stt.vad'] = _vad


def _detect_speaker_from_text(text: str):
match = re.search(r'\b(?:my name is|i am)\s+([a-z][a-z-]*)', text, re.IGNORECASE)
return match.group(1).capitalize() if match else None


# Register a stub module under 'utils.speaker_identification' whose only
# attribute is the regex-based `_detect_speaker_from_text` defined in this file.
# Direct assignment replaces any entry already registered under that name.
_speaker_identification = ModuleType('utils.speaker_identification')
_speaker_identification.detect_speaker_from_text = _detect_speaker_from_text
sys.modules['utils.speaker_identification'] = _speaker_identification


def _compare_embeddings(embedding1: np.ndarray, embedding2: np.ndarray) -> float:
embedding1 = np.atleast_2d(embedding1)
embedding2 = np.atleast_2d(embedding2)
if embedding1.shape[1] != embedding2.shape[1]:
return 2.0
norm_product = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
if norm_product == 0:
return 2.0
similarity = float(np.dot(embedding1.reshape(-1), embedding2.reshape(-1)) / norm_product)
return 1.0 - similarity
Comment on lines +113 to +122

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Stub assumes 2D embedding shape

The _compare_embeddings stub accesses embedding1.shape[1], which requires a 2D array. If the production compare_embeddings is ever called with a 1D embedding (512,) (e.g., from a new test or changed fixture), this raises an IndexError rather than the expected distance value. Every existing test passes (1, 512) arrays, but the boundary is invisible — a caller using np.array([0.1] * 512) instead of np.array([[0.1] * 512]) would trigger it. The guard could use np.atleast_2d or check embedding1.ndim first.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 4333bb8ba: the stub now normalizes both inputs with np.atleast_2d(...) before checking dimensions, and I added a regression test covering 1D vectors plus mismatched dimensions. Local validation: 62 passed.



# Register a stub module under 'utils.stt.speaker_embedding'. Direct assignment
# replaces any entry already registered under that name.
_speaker_embedding = ModuleType('utils.stt.speaker_embedding')
# Embedding extraction is a bare MagicMock here.
_speaker_embedding.extract_embedding_from_bytes = MagicMock()
# Comparison uses the numpy-based `_compare_embeddings` defined in this file.
_speaker_embedding.compare_embeddings = _compare_embeddings
# NOTE(review): presumably mirrors the production threshold -- confirm against
# the real utils.stt.speaker_embedding module.
_speaker_embedding.SPEAKER_MATCH_THRESHOLD = 0.45
sys.modules['utils.stt.speaker_embedding'] = _speaker_embedding

# Stub google.cloud.storage.Client to avoid GCS credentials
import google.cloud.storage as _gcs

Expand All @@ -87,6 +140,12 @@
os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv')


def test_compare_embeddings_accepts_1d_vectors():
    """The embedding-comparison stub must accept plain 1D vectors.

    Identical 1D vectors give a distance of approximately 0.0, and vectors of
    different lengths give the sentinel distance 2.0.
    """
    unit_vector = np.array([1.0, 0.0])
    same_vector = np.array([1.0, 0.0])
    shorter_vector = np.array([1.0])

    identical_distance = _compare_embeddings(unit_vector, same_vector)
    mismatched_distance = _compare_embeddings(unit_vector, shorter_vector)

    assert identical_distance == pytest.approx(0.0)
    assert mismatched_distance == 2.0


# ---------------------------------------------------------------------------
# deepgram_prerecorded: keywords parameter
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -276,6 +335,7 @@ def test_empty_transcript_with_keywords_and_return_language(self, mock_client):
# ---------------------------------------------------------------------------


@patch('routers.sync.submit_with_context', MagicMock())
class TestProcessSegmentPreferences:
"""Verify process_segment applies user transcription preferences."""

Expand All @@ -288,7 +348,7 @@ def _make_mock_words(self):
@patch('routers.sync.process_conversation')
@patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None)
@patch('routers.sync.get_timestamp_from_path', return_value=1700000000)
@patch('routers.sync.deepgram_prerecorded')
@patch('routers.sync.prerecorded')
@patch('routers.sync.delete_syncing_temporal_file')
@patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav')
def test_vocabulary_passed_to_deepgram(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process):
Expand Down Expand Up @@ -320,13 +380,13 @@ def test_vocabulary_passed_to_deepgram(self, mock_url, mock_delete, mock_dg, moc
@patch('routers.sync.process_conversation')
@patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None)
@patch('routers.sync.get_timestamp_from_path', return_value=1700000000)
@patch('routers.sync.deepgram_prerecorded')
@patch('routers.sync.prerecorded')
@patch('routers.sync.delete_syncing_temporal_file')
@patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav')
def test_single_language_mode_selects_model(
def test_single_language_mode_passes_user_language(
self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process
):
"""Single language mode with a language should select the right model."""
"""Single language mode with a language should pass the user language."""
from routers.sync import process_segment

mock_dg.return_value = (self._make_mock_words(), 'en')
Expand All @@ -342,16 +402,16 @@ def test_single_language_mode_selects_model(

_, kwargs = mock_dg.call_args
assert kwargs['language'] == 'en'
assert kwargs['model'] == 'nova-3'
assert kwargs['return_language'] is True

@patch('routers.sync.process_conversation')
@patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None)
@patch('routers.sync.get_timestamp_from_path', return_value=1700000000)
@patch('routers.sync.deepgram_prerecorded')
@patch('routers.sync.prerecorded')
@patch('routers.sync.delete_syncing_temporal_file')
@patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav')
def test_chinese_selects_nova3(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process):
"""Chinese language should select nova-3 model."""
def test_chinese_passes_user_language(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process):
"""Chinese language should be passed through in single-language mode."""
from routers.sync import process_segment

mock_dg.return_value = (self._make_mock_words(), 'zh')
Expand All @@ -367,16 +427,16 @@ def test_chinese_selects_nova3(self, mock_url, mock_delete, mock_dg, mock_ts, mo

_, kwargs = mock_dg.call_args
assert kwargs['language'] == 'zh'
assert kwargs['model'] == 'nova-3'
assert kwargs['return_language'] is True

@patch('routers.sync.process_conversation')
@patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None)
@patch('routers.sync.get_timestamp_from_path', return_value=1700000000)
@patch('routers.sync.deepgram_prerecorded')
@patch('routers.sync.prerecorded')
@patch('routers.sync.delete_syncing_temporal_file')
@patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav')
def test_no_prefs_uses_defaults(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process):
"""Without preferences, should use multi/nova-3 defaults."""
"""Without preferences, should use multi-language defaults."""
from routers.sync import process_segment

mock_dg.return_value = (self._make_mock_words(), 'en')
Expand All @@ -390,14 +450,14 @@ def test_no_prefs_uses_defaults(self, mock_url, mock_delete, mock_dg, mock_ts, m

_, kwargs = mock_dg.call_args
assert kwargs['language'] == 'multi'
assert kwargs['model'] == 'nova-3'
assert kwargs['return_language'] is True
# Vocabulary should still include "Omi" even without prefs
assert 'Omi' in kwargs['keywords']

@patch('routers.sync.process_conversation')
@patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None)
@patch('routers.sync.get_timestamp_from_path', return_value=1700000000)
@patch('routers.sync.deepgram_prerecorded')
@patch('routers.sync.prerecorded')
@patch('routers.sync.delete_syncing_temporal_file')
@patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav')
def test_vocabulary_capped_at_100(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process):
Expand All @@ -424,13 +484,13 @@ def test_vocabulary_capped_at_100(self, mock_url, mock_delete, mock_dg, mock_ts,
@patch('routers.sync.process_conversation')
@patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None)
@patch('routers.sync.get_timestamp_from_path', return_value=1700000000)
@patch('routers.sync.deepgram_prerecorded')
@patch('routers.sync.prerecorded')
@patch('routers.sync.delete_syncing_temporal_file')
@patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav')
def test_single_language_empty_language_falls_back(
self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process
):
"""single_language_mode=True with empty language should fall back to multi/nova-3."""
"""single_language_mode=True with empty language should fall back to multi."""
from routers.sync import process_segment

mock_dg.return_value = (self._make_mock_words(), 'en')
Expand All @@ -446,12 +506,12 @@ def test_single_language_empty_language_falls_back(

_, kwargs = mock_dg.call_args
assert kwargs['language'] == 'multi'
assert kwargs['model'] == 'nova-3'
assert kwargs['return_language'] is True

@patch('routers.sync.process_conversation')
@patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None)
@patch('routers.sync.get_timestamp_from_path', return_value=1700000000)
@patch('routers.sync.deepgram_prerecorded')
@patch('routers.sync.prerecorded')
@patch('routers.sync.delete_syncing_temporal_file')
@patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav')
def test_multi_language_mode_default(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process):
Expand All @@ -471,12 +531,12 @@ def test_multi_language_mode_default(self, mock_url, mock_delete, mock_dg, mock_

_, kwargs = mock_dg.call_args
assert kwargs['language'] == 'multi'
assert kwargs['model'] == 'nova-3'
assert kwargs['return_language'] is True

@patch('routers.sync.process_conversation')
@patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None)
@patch('routers.sync.get_timestamp_from_path', return_value=1700000000)
@patch('routers.sync.deepgram_prerecorded')
@patch('routers.sync.prerecorded')
@patch('routers.sync.delete_syncing_temporal_file')
@patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav')
def test_single_language_trusts_user_language(
Expand Down Expand Up @@ -514,7 +574,7 @@ class TestSyncEndpointPrefsWiring:
@staticmethod
def _read_sync_source():
sync_path = os.path.join(os.path.dirname(__file__), '..', '..', 'routers', 'sync.py')
with open(sync_path) as f:
with open(sync_path, encoding='utf-8') as f:
return f.read()

def test_endpoint_fetches_transcription_prefs(self):
Expand Down Expand Up @@ -640,8 +700,6 @@ def test_unsupported_language_falls_back_to_multi(self):
import struct
import wave

import numpy as np


def _make_wav_bytes(duration_sec: float = 2.0, sample_rate: int = 16000) -> bytes:
"""Generate silent WAV bytes of the given duration for testing."""
Expand Down Expand Up @@ -695,9 +753,9 @@ def test_loads_people_embeddings(self, mock_users_db):

mock_users_db.get_user_speaker_embedding.return_value = None
mock_users_db.get_people.return_value = [
{'id': 'p1', 'name': 'Alice', 'speaker_embedding': [0.2] * 512},
{'id': 'p1', 'name': 'Alice', 'speaker_embedding': [0.2] * 512, 'speech_samples': ['sample-1']},
{'id': 'p2', 'name': 'Bob'}, # no embedding
{'id': 'p3', 'name': 'Carol', 'speaker_embedding': [0.3] * 512},
{'id': 'p3', 'name': 'Carol', 'speaker_embedding': [0.3] * 512, 'speech_samples': ['sample-3']},
]

cache = build_person_embeddings_cache('uid1')
Expand Down Expand Up @@ -1101,6 +1159,7 @@ def test_equal_best_clip_stable_order(self, mock_extract):
assert segments[1].person_id == 'p2'


@patch('routers.sync.submit_with_context', MagicMock())
class TestProcessSegmentSpeakerIdIntegration:
"""Verify process_segment wires speaker identification correctly."""

Expand All @@ -1114,7 +1173,7 @@ def _mock_words():
@patch('routers.sync.process_conversation')
@patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None)
@patch('routers.sync.get_timestamp_from_path', return_value=1700000000)
@patch('routers.sync.deepgram_prerecorded')
@patch('routers.sync.prerecorded')
@patch('routers.sync.delete_syncing_temporal_file')
@patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav')
@patch('routers.sync.identify_speakers_for_segments')
Expand Down Expand Up @@ -1149,7 +1208,7 @@ def test_speaker_id_called_when_cache_provided(
@patch('routers.sync.process_conversation')
@patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None)
@patch('routers.sync.get_timestamp_from_path', return_value=1700000000)
@patch('routers.sync.deepgram_prerecorded')
@patch('routers.sync.prerecorded')
@patch('routers.sync.delete_syncing_temporal_file')
@patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav')
@patch('routers.sync.identify_speakers_for_segments')
Expand Down Expand Up @@ -1187,7 +1246,7 @@ class TestSyncEndpointSpeakerIdWiring:
@staticmethod
def _read_sync_source():
sync_path = os.path.join(os.path.dirname(__file__), '..', '..', 'routers', 'sync.py')
with open(sync_path) as f:
with open(sync_path, encoding='utf-8') as f:
return f.read()

def test_endpoint_builds_embeddings_cache(self):
Expand All @@ -1213,29 +1272,30 @@ def test_endpoint_passes_cache_to_thread(self):
class TestDownloadAudioBytes:
"""Verify _download_audio_bytes handles success and failure."""

@patch('routers.sync.requests')
def test_download_success(self, mock_requests):
@patch('routers.sync.httpx')
def test_download_success(self, mock_httpx):
from routers.sync import _download_audio_bytes

mock_resp = MagicMock()
mock_resp.content = b'wav-bytes'
mock_resp.raise_for_status.return_value = None
mock_requests.get.return_value = mock_resp
mock_httpx.get.return_value = mock_resp

result = _download_audio_bytes('http://example.com/audio.wav')
assert result == b'wav-bytes'
mock_requests.get.assert_called_once_with('http://example.com/audio.wav', timeout=60)
mock_httpx.get.assert_called_once_with('http://example.com/audio.wav', timeout=60.0)

@patch('routers.sync.requests')
def test_download_failure_returns_none(self, mock_requests):
@patch('routers.sync.httpx')
def test_download_failure_returns_none(self, mock_httpx):
from routers.sync import _download_audio_bytes

mock_requests.get.side_effect = Exception("Connection refused")
mock_httpx.get.side_effect = Exception("Connection refused")

result = _download_audio_bytes('http://example.com/audio.wav')
assert result is None


@patch('routers.sync.submit_with_context', MagicMock())
class TestSpeakerIdExceptionHandling:
"""Verify process_segment swallows speaker ID exceptions gracefully."""

Expand All @@ -1249,7 +1309,7 @@ def _mock_words():
@patch('routers.sync.process_conversation')
@patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None)
@patch('routers.sync.get_timestamp_from_path', return_value=1700000000)
@patch('routers.sync.deepgram_prerecorded')
@patch('routers.sync.prerecorded')
@patch('routers.sync.delete_syncing_temporal_file')
@patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav')
@patch('routers.sync.identify_speakers_for_segments', side_effect=RuntimeError("embedding API down"))
Expand Down
Loading