Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 111 additions & 1 deletion backend/tests/unit/test_high_priority_usage_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
import os
import sys
import types
import json
import importlib
import importlib.util
from datetime import timezone
from pathlib import Path
from unittest.mock import MagicMock, AsyncMock
import asyncio

import pytest

os.environ.setdefault(
"ENCRYPTION_SECRET",
"omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv",
Expand All @@ -28,6 +34,77 @@ def _stub_module(name: str) -> types.ModuleType:
return sys.modules[name]


def _optional_stub_module(name: str) -> tuple[types.ModuleType, bool]:
if name in sys.modules:
return sys.modules[name], False

try:
spec = importlib.util.find_spec(name)
except (ImportError, ValueError):
spec = None

if spec is not None:
return importlib.import_module(name), False

mod = types.ModuleType(name)
sys.modules[name] = mod
return mod, True


# --- Stub minimal LangChain modules needed by usage_tracker ---
langchain_core_mod, langchain_core_stubbed = _optional_stub_module("langchain_core")
if langchain_core_stubbed and not hasattr(langchain_core_mod, '__path__'):
langchain_core_mod.__path__ = []

callbacks_mod, callbacks_stubbed = _optional_stub_module("langchain_core.callbacks")
outputs_mod, outputs_stubbed = _optional_stub_module("langchain_core.outputs")
output_parsers_mod, output_parsers_stubbed = _optional_stub_module("langchain_core.output_parsers")
prompts_mod, prompts_stubbed = _optional_stub_module("langchain_core.prompts")


class BaseCallbackHandler:
pass


class LLMResult:
pass


class PydanticOutputParser:
def __init__(self, pydantic_object):
self.pydantic_object = pydantic_object

def get_format_instructions(self):
return ""

def parse(self, text):
return self.pydantic_object(**json.loads(text))


class ChatPromptTemplate:
@classmethod
def from_messages(cls, _messages):
return cls()

def __or__(self, other):
return other


if callbacks_stubbed:
callbacks_mod.BaseCallbackHandler = BaseCallbackHandler
if outputs_stubbed:
outputs_mod.LLMResult = LLMResult
if output_parsers_stubbed:
output_parsers_mod.PydanticOutputParser = PydanticOutputParser
if prompts_stubbed:
prompts_mod.ChatPromptTemplate = ChatPromptTemplate

pytz_mod, pytz_stubbed = _optional_stub_module("pytz")
if pytz_stubbed:
pytz_mod.UTC = timezone.utc
pytz_mod.timezone = MagicMock(return_value=timezone.utc)


# --- Stub database package and submodules ---
# Use _stub_module which only creates if not already loaded
database_mod = _stub_module("database")
Expand Down Expand Up @@ -73,6 +150,7 @@ def _stub_module(name: str) -> types.ModuleType:
sys.modules["database.knowledge_graph"].upsert_knowledge_edge = MagicMock(return_value={'id': 'e1'})
sys.modules["database.knowledge_graph"].delete_knowledge_graph = MagicMock()
sys.modules["database.users"].get_user_profile = MagicMock(return_value={'time_zone': 'UTC'})
sys.modules["database.users"].get_user_language_preference = MagicMock(return_value='en')
sys.modules["database.users"].get_people_by_ids = MagicMock(return_value=[])
sys.modules["database.action_items"].get_action_items = MagicMock(return_value=[])
sys.modules["database.daily_summaries"].create_daily_summary = MagicMock(return_value="summary-1")
Expand Down Expand Up @@ -123,13 +201,45 @@ def _stub_module(name: str) -> types.ModuleType:
clients_mod.parser = mock_parser


def _get_llm_stub(name, **_kwargs):
if name in {'goals', 'knowledge_graph', 'daily_summary_simple'}:
return mock_llm_mini
if name in {'goals_advice', 'notifications'}:
return mock_llm_medium
if name == 'daily_summary':
return mock_llm_medium_experiment
raise ValueError(f"Unexpected get_llm feature in usage tracking test: {name}")


clients_mod.get_llm = _get_llm_stub
Comment on lines +204 to +214

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Silent fallback in _get_llm_stub masks uncovered get_llm call sites

The final return mock_llm_mini fallback means any new get_llm(name) call whose name is not enumerated in the stub (e.g., after a future feature is added) silently gets mock_llm_mini instead of failing loudly. The context-capture tests then pass even though the wrong mock is in use, so tracking regressions for new features would go undetected. Raising ValueError for unknown names (or at minimum warnings.warn) would make the mismatch visible immediately.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in f483ea6b2: _get_llm_stub now raises ValueError for unknown feature names instead of silently falling back to mock_llm_mini, and there is a regression test for that failure path. Local validation: 21 passed.



def test_get_llm_stub_rejects_unknown_features():
with pytest.raises(ValueError, match="Unexpected get_llm feature"):
_get_llm_stub("new-feature")


def test_optional_stub_module_keeps_loaded_modules():
module_name = "test_loaded_optional_langchain_module"
loaded_module = types.ModuleType(module_name)
loaded_module.Sentinel = object
sys.modules[module_name] = loaded_module
try:
module, stubbed = _optional_stub_module(module_name)
assert module is loaded_module
assert stubbed is False
assert module.Sentinel is object
finally:
sys.modules.pop(module_name, None)


# ── Source-level tests: verify track_usage wraps every LLM call ──


def _read_source(relative_path: str) -> str:
"""Read a source file relative to the backend directory."""
backend_dir = Path(__file__).resolve().parent.parent.parent
return (backend_dir / relative_path).read_text()
return (backend_dir / relative_path).read_text(encoding="utf-8")


class TestGoalsTracking:
Expand Down
Loading