diff --git a/backend/tests/unit/test_desktop_migration.py b/backend/tests/unit/test_desktop_migration.py index 25d9a05dc7..a8f8e4b4ed 100644 --- a/backend/tests/unit/test_desktop_migration.py +++ b/backend/tests/unit/test_desktop_migration.py @@ -45,9 +45,35 @@ def _stub_package(name): return mod +def _remove_module_tree(prefix): + for name in list(sys.modules): + if name == prefix or name.startswith(prefix + "."): + sys.modules.pop(name, None) + + +def _ensure_package_path(name, path): + mod = sys.modules.get(name) + if not isinstance(mod, types.ModuleType): + mod = types.ModuleType(name) + sys.modules[name] = mod + mod.__path__ = [str(path)] + return mod + + # --------------------------------------------------------------------------- # Stub heavy dependencies before any production imports # --------------------------------------------------------------------------- +for module_prefix in [ + "database", + "models", + "utils", + "routers.chat_sessions", + "routers.focus_sessions", + "routers.advice", + "routers.staged_tasks", +]: + _remove_module_tree(module_prefix) + for mod_name in [ "firebase_admin", "firebase_admin.firestore", @@ -83,17 +109,14 @@ def _stub_package(name): sys.modules["google.cloud.firestore_v1"].FieldFilter = field_filter_stub.FieldFilter sys.modules["google.cloud.firestore_v1"].transactional = lambda f: f +redis_stub = sys.modules["database.redis_db"] +redis_stub.try_acquire_user_platform_write_lock = MagicMock(return_value=True) + # Add backend dir to sys.path sys.path.insert(0, str(BACKEND_DIR)) # Stub database package and _client -if "database" not in sys.modules: - db_pkg = _stub_package("database") - db_pkg.__path__ = [str(BACKEND_DIR / "database")] -else: - db_mod = sys.modules["database"] - if not hasattr(db_mod, '__path__'): - db_mod.__path__ = [str(BACKEND_DIR / "database")] +_ensure_package_path("database", BACKEND_DIR / "database") client_stub = _stub_module("database._client") mock_db = MagicMock() @@ -148,6 +171,10 @@ def _stub_package(name): from routers.advice import CreateAdviceRequest # noqa: E402 from routers.staged_tasks import BatchUpdateScoresRequest, BatchScoreEntry # noqa: E402 +_ensure_package_path("models", BACKEND_DIR / "models") +_ensure_package_path("utils", BACKEND_DIR / "utils") +_ensure_package_path("utils.other", BACKEND_DIR / "utils" / "other") + # Cannot import routers.users directly — it pulls in database.conversations → utils.other.hume # which has heavy deps. Mirror the models here and verify parity via AST test below. @@ -1358,7 +1385,7 @@ def test_notification_settings_fields_match_source(self): """Inline UpdateNotificationSettingsRequest matches routers/users.py definition.""" import ast - source = (BACKEND_DIR / 'routers' / 'users.py').read_text() + source = (BACKEND_DIR / 'routers' / 'users.py').read_text(encoding='utf-8') tree = ast.parse(source) for node in ast.walk(tree): if isinstance(node, ast.ClassDef) and node.name == 'UpdateNotificationSettingsRequest': @@ -1377,7 +1404,7 @@ def test_llm_usage_fields_match_source(self): """Inline RecordLlmUsageBucketRequest matches routers/users.py definition.""" import ast - source = (BACKEND_DIR / 'routers' / 'users.py').read_text() + source = (BACKEND_DIR / 'routers' / 'users.py').read_text(encoding='utf-8') tree = ast.parse(source) for node in ast.walk(tree): if isinstance(node, ast.ClassDef) and node.name == 'RecordLlmUsageBucketRequest': @@ -1694,10 +1721,14 @@ def test_returns_title(self, mock_update): session_id='s1', messages=[TitleMessageInput(text='hi', sender='human'), TitleMessageInput(text='hello', sender='ai')], ) - with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(llm_mini=mock_llm)}): + mock_get_llm = MagicMock(return_value=mock_llm) + llm_clients_stub = types.ModuleType('utils.llm.clients') + llm_clients_stub.get_llm = mock_get_llm + with patch.dict('sys.modules', {'utils.llm.clients': llm_clients_stub}): result = generate_session_title(request, uid='u1') assert result == {'title': 'Project Discussion'} + mock_get_llm.assert_called_once_with('session_titles') mock_update.assert_called_once_with('u1', 's1', title='Project Discussion') @patch('database.chat.update_chat_session') @@ -1713,10 +1744,14 @@ def test_empty_response_defaults_to_new_chat(self, mock_update): session_id='s1', messages=[TitleMessageInput(text='hi', sender='human')], ) - with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(llm_mini=mock_llm)}): + mock_get_llm = MagicMock(return_value=mock_llm) + llm_clients_stub = types.ModuleType('utils.llm.clients') + llm_clients_stub.get_llm = mock_get_llm + with patch.dict('sys.modules', {'utils.llm.clients': llm_clients_stub}): result = generate_session_title(request, uid='u1') assert result == {'title': 'New Chat'} + mock_get_llm.assert_called_once_with('session_titles') class TestChatMessageCount: