diff --git a/backend/tests/unit/test_high_priority_usage_tracking.py b/backend/tests/unit/test_high_priority_usage_tracking.py index 8fa10aaa6a..5c1c2fc644 100644 --- a/backend/tests/unit/test_high_priority_usage_tracking.py +++ b/backend/tests/unit/test_high_priority_usage_tracking.py @@ -11,10 +11,16 @@ import os import sys import types +import json +import importlib +import importlib.util +from datetime import timezone from pathlib import Path from unittest.mock import MagicMock, AsyncMock import asyncio +import pytest + os.environ.setdefault( "ENCRYPTION_SECRET", "omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv", @@ -28,6 +34,77 @@ def _stub_module(name: str) -> types.ModuleType: return sys.modules[name] +def _optional_stub_module(name: str) -> tuple[types.ModuleType, bool]: + if name in sys.modules: + return sys.modules[name], False + + try: + spec = importlib.util.find_spec(name) + except (ImportError, ValueError): + spec = None + + if spec is not None: + return importlib.import_module(name), False + + mod = types.ModuleType(name) + sys.modules[name] = mod + return mod, True + + +# --- Stub minimal LangChain modules needed by usage_tracker --- +langchain_core_mod, langchain_core_stubbed = _optional_stub_module("langchain_core") +if langchain_core_stubbed and not hasattr(langchain_core_mod, '__path__'): + langchain_core_mod.__path__ = [] + +callbacks_mod, callbacks_stubbed = _optional_stub_module("langchain_core.callbacks") +outputs_mod, outputs_stubbed = _optional_stub_module("langchain_core.outputs") +output_parsers_mod, output_parsers_stubbed = _optional_stub_module("langchain_core.output_parsers") +prompts_mod, prompts_stubbed = _optional_stub_module("langchain_core.prompts") + + +class BaseCallbackHandler: + pass + + +class LLMResult: + pass + + +class PydanticOutputParser: + def __init__(self, pydantic_object): + self.pydantic_object = pydantic_object + + def get_format_instructions(self): + return "" + + def parse(self, text): + return self.pydantic_object(**json.loads(text)) + + +class ChatPromptTemplate: + @classmethod + def from_messages(cls, _messages): + return cls() + + def __or__(self, other): + return other + + +if callbacks_stubbed: + callbacks_mod.BaseCallbackHandler = BaseCallbackHandler +if outputs_stubbed: + outputs_mod.LLMResult = LLMResult +if output_parsers_stubbed: + output_parsers_mod.PydanticOutputParser = PydanticOutputParser +if prompts_stubbed: + prompts_mod.ChatPromptTemplate = ChatPromptTemplate + +pytz_mod, pytz_stubbed = _optional_stub_module("pytz") +if pytz_stubbed: + pytz_mod.UTC = timezone.utc + pytz_mod.timezone = MagicMock(return_value=timezone.utc) + + # --- Stub database package and submodules --- # Use _stub_module which only creates if not already loaded database_mod = _stub_module("database") @@ -73,6 +150,7 @@ def _stub_module(name: str) -> types.ModuleType: sys.modules["database.knowledge_graph"].upsert_knowledge_edge = MagicMock(return_value={'id': 'e1'}) sys.modules["database.knowledge_graph"].delete_knowledge_graph = MagicMock() sys.modules["database.users"].get_user_profile = MagicMock(return_value={'time_zone': 'UTC'}) +sys.modules["database.users"].get_user_language_preference = MagicMock(return_value='en') sys.modules["database.users"].get_people_by_ids = MagicMock(return_value=[]) sys.modules["database.action_items"].get_action_items = MagicMock(return_value=[]) sys.modules["database.daily_summaries"].create_daily_summary = MagicMock(return_value="summary-1") @@ -123,13 +201,45 @@ def _stub_module(name: str) -> types.ModuleType: clients_mod.parser = mock_parser +def _get_llm_stub(name, **_kwargs): + if name in {'goals', 'knowledge_graph', 'daily_summary_simple'}: + return mock_llm_mini + if name in {'goals_advice', 'notifications'}: + return mock_llm_medium + if name == 'daily_summary': + return mock_llm_medium_experiment + raise ValueError(f"Unexpected get_llm feature in usage tracking test: {name}") + + +clients_mod.get_llm = _get_llm_stub + + +def test_get_llm_stub_rejects_unknown_features(): + with pytest.raises(ValueError, match="Unexpected get_llm feature"): + _get_llm_stub("new-feature") + + +def test_optional_stub_module_keeps_loaded_modules(): + module_name = "test_loaded_optional_langchain_module" + loaded_module = types.ModuleType(module_name) + loaded_module.Sentinel = object + sys.modules[module_name] = loaded_module + try: + module, stubbed = _optional_stub_module(module_name) + assert module is loaded_module + assert stubbed is False + assert module.Sentinel is object + finally: + sys.modules.pop(module_name, None) + + # ── Source-level tests: verify track_usage wraps every LLM call ── def _read_source(relative_path: str) -> str: """Read a source file relative to the backend directory.""" backend_dir = Path(__file__).resolve().parent.parent.parent - return (backend_dir / relative_path).read_text() + return (backend_dir / relative_path).read_text(encoding="utf-8") class TestGoalsTracking: