Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 159 additions & 14 deletions backend/tests/unit/test_lock_bypass_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import os
import pytest
import sys
from datetime import datetime, timedelta, timezone, tzinfo
from types import ModuleType
from zoneinfo import ZoneInfo

os.environ.setdefault('OPENAI_API_KEY', 'sk-test-not-real')
os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv')
Expand All @@ -27,8 +29,80 @@ def __getattr__(self, name):
return mock


class _ToolWrapper:
"""Tiny LangChain tool stand-in for tests that call `.invoke(...)`."""

def __init__(self, fn):
self.fn = fn
self.name = fn.__name__

def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)

def invoke(self, args=None, config=None):
if args is not None and not isinstance(args, dict):
if config is not None:
return self.fn(args, config=config)
return self.fn(args)

kwargs = dict(args or {})
if config is not None:
kwargs['config'] = config
return self.fn(**kwargs)
Comment on lines +42 to +51

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _ToolWrapper.invoke silently breaks for non-dict inputs

dict(args or {}) works when args is a dict, but raises ValueError if args is a plain string (e.g. tool.invoke("my_input")). The real LangChain StructuredTool.invoke accepts Union[str, dict] as input, so any future test that calls a stubbed @tool function with a string argument will get a confusing error pointing inside _ToolWrapper rather than at the failing test. Adding an isinstance(args, dict) branch would make the stub more robust.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 8301e85d5. _ToolWrapper.invoke now branches on non-dict input and forwards it as a positional argument, matching the stubbed LangChain tool use case instead of trying dict(args). I also added a regression check that wrapped.invoke('hello') == 'hello'.

Revalidated on the Windows backend venv:

  • python -m pytest tests\unit\test_lock_bypass_fixes.py -q -> 57 passed
  • python -m black --line-length 120 --skip-string-normalization tests\unit\test_lock_bypass_fixes.py --check
  • python -m py_compile tests\unit\test_lock_bypass_fixes.py



class _PytzZoneInfo(tzinfo):
"""Minimal pytz timezone stand-in with `.localize(...)` for summary tests."""

def __init__(self, key):
try:
self._zone = ZoneInfo(key)
except Exception:
if key == 'UTC':
self._zone = timezone.utc
elif key == 'Asia/Kolkata':
self._zone = timezone(timedelta(hours=5, minutes=30), key)
else:
raise

def localize(self, value):
if value.tzinfo is not None:
return value.astimezone(self)
return value.replace(tzinfo=self)

def _delegate_value(self, value):
if value is not None and value.tzinfo is self:
return value.replace(tzinfo=self._zone)
return value

def utcoffset(self, value):
return self._zone.utcoffset(self._delegate_value(value))

def dst(self, value):
return self._zone.dst(self._delegate_value(value))

def tzname(self, value):
return self._zone.tzname(self._delegate_value(value))

def fromutc(self, value):
localized = value.replace(tzinfo=timezone.utc).astimezone(self._zone)
return localized.replace(tzinfo=self)


def _tool(func=None, *args, **kwargs):
def decorator(fn):
return _ToolWrapper(fn)

if callable(func):
return decorator(func)
return decorator


_stubs = [
'anthropic',
'av',
'database._client',
'database.cache',
'database.redis_db',
'database.conversations',
'database.memories',
Expand All @@ -46,13 +120,44 @@ def __getattr__(self, name):
'database.daily_summaries',
'database.fair_use',
'database.auth',
'database.llm_usage',
'database.phone_calls',
'deepgram',
'deepgram.clients',
'deepgram.clients.live',
'deepgram.clients.live.v1',
'firebase_admin',
'firebase_admin.messaging',
'firebase_admin.auth',
'google.cloud.firestore',
'google.cloud.firestore_v1',
'google.cloud.firestore_v1.FieldFilter',
'langchain_core',
'langchain_core.callbacks',
'langchain_core.language_models',
'langchain_core.output_parsers',
'langchain_core.outputs',
'langchain_core.prompts',
'langchain_core.runnables',
'langchain_core.tools',
'langchain_google_genai',
'langchain_openai',
'openai',
'PIL',
'PIL.Image',
'pinecone',
'pycountry',
'pytz',
'scipy',
'scipy.spatial',
'scipy.spatial.distance',
'tiktoken',
'twilio',
'twilio.jwt',
'twilio.jwt.access_token',
'twilio.jwt.access_token.grants',
'twilio.request_validator',
'twilio.rest',
'typesense',
'opuslib',
'pydub',
Expand All @@ -67,13 +172,24 @@ def __getattr__(self, name):
'utils.conversations.process_conversation',
'utils.notifications',
'utils.apps',
'utils.llm.clients',
'utils.llm.memories',
'utils.llm.chat',
'utils.llm.usage_tracker',
'websockets',
]
for mod_name in _stubs:
if mod_name not in sys.modules:
sys.modules[mod_name] = _AutoMockModule(mod_name)

# Concrete attributes used by imported modules during lightweight tests.
sys.modules['langchain_core.callbacks'].BaseCallbackHandler = object
sys.modules['langchain_core.outputs'].LLMResult = object
sys.modules['langchain_core.runnables'].RunnableConfig = dict
sys.modules['langchain_core.tools'].tool = _tool
sys.modules['pytz'].timezone = _PytzZoneInfo
sys.modules['pytz'].utc = timezone.utc

# Override specific attributes that need concrete values
sys.modules['firebase_admin.auth'].InvalidIdTokenError = type('InvalidIdTokenError', (Exception,), {})
sys.modules['firebase_admin.auth'].ExpiredIdTokenError = type('ExpiredIdTokenError', (Exception,), {})
Expand All @@ -82,6 +198,28 @@ def __getattr__(self, name):
sys.modules['firebase_admin.auth'].UserNotFoundError = type('UserNotFoundError', (Exception,), {})


class TestLightweightStubHelpers:
"""Keep lightweight dependency stubs aligned with the real interfaces tests rely on."""

def test_tool_wrapper_invoke_accepts_string_input(self):
def echo(value):
return value

wrapped = _tool(echo)

assert wrapped.invoke('hello') == 'hello'

def test_pytz_stub_supports_localize_and_datetime_now(self):
import pytz

user_tz = pytz.timezone('UTC')
localized = user_tz.localize(datetime(2026, 6, 10, 12, 0, 0))

assert localized.tzinfo is user_tz
assert localized.astimezone(pytz.utc).hour == 12
assert datetime.now(user_tz).tzinfo is user_tz


def _make_conversation(locked=False, conversation_id='conv-1'):
"""Create a minimal conversation dict for DB-layer return values."""
return {
Expand Down Expand Up @@ -808,16 +946,18 @@ def test_scheduled_summary_excludes_locked(self):
unlocked_conv = _make_conversation(locked=False, conversation_id='conv-2')
conversations_db.get_conversations = MagicMock(return_value=[locked_conv, unlocked_conv])

with patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True):
with patch(
'utils.other.notifications.generate_comprehensive_daily_summary',
return_value={'headline': 'Test', 'day_emoji': '📅', 'overview': 'ok'},
) as mock_gen:
daily_summaries_db.create_daily_summary = MagicMock(return_value='summary-1')
with patch('utils.other.notifications.send_notification'):
from utils.other.notifications import _send_summary_notification
with patch('utils.other.notifications.is_trial_paywalled', return_value=False):
with patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True):
with patch(
'utils.other.notifications.generate_comprehensive_daily_summary',
return_value={'headline': 'Test', 'day_emoji': '📅', 'overview': 'ok'},
) as mock_gen:
daily_summaries_db.create_daily_summary = MagicMock(return_value='summary-1')
daily_summaries_db.get_daily_summary_by_date = MagicMock(return_value=None)
with patch('utils.other.notifications.send_notification'):
from utils.other.notifications import _send_summary_notification

_send_summary_notification(('test-uid', 'token', 'UTC'))
_send_summary_notification(('test-uid', 'token', 'UTC'))

# generate_comprehensive_daily_summary must be called only with unlocked conversations
mock_gen.assert_called_once()
Expand All @@ -828,14 +968,17 @@ def test_scheduled_summary_excludes_locked(self):
def test_scheduled_summary_skips_when_all_locked(self):
"""_send_summary_notification returns early when all conversations are locked."""
import database.conversations as conversations_db
import database.daily_summaries as daily_summaries_db

conversations_db.get_conversations = MagicMock(return_value=[_make_conversation(locked=True)])
daily_summaries_db.get_daily_summary_by_date = MagicMock(return_value=None)

with patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True):
with patch('utils.other.notifications.generate_comprehensive_daily_summary') as mock_gen:
from utils.other.notifications import _send_summary_notification
with patch('utils.other.notifications.is_trial_paywalled', return_value=False):
with patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True):
with patch('utils.other.notifications.generate_comprehensive_daily_summary') as mock_gen:
from utils.other.notifications import _send_summary_notification

_send_summary_notification(('test-uid', 'token', 'UTC'))
_send_summary_notification(('test-uid', 'token', 'UTC'))

# Should not call LLM when no unlocked conversations remain
mock_gen.assert_not_called()
Expand Down Expand Up @@ -1326,8 +1469,10 @@ def test_suggest_goal_filters_locked_memories(self):
mock_track.__exit__ = MagicMock(return_value=False)

with patch('utils.llm.goals.track_usage', return_value=mock_track):
with patch('utils.llm.goals.llm_mini') as mock_llm:
with patch('utils.llm.goals.get_llm') as mock_get_llm:
mock_llm = MagicMock()
mock_llm.invoke.return_value = mock_llm_response
mock_get_llm.return_value = mock_llm

from utils.llm.goals import suggest_goal

Expand Down
Loading