diff --git a/backend/tests/unit/test_daily_summary_race_condition.py b/backend/tests/unit/test_daily_summary_race_condition.py index 421799dedd..98b849f3c5 100644 --- a/backend/tests/unit/test_daily_summary_race_condition.py +++ b/backend/tests/unit/test_daily_summary_race_condition.py @@ -11,6 +11,7 @@ import sys import types import threading +from datetime import datetime, timedelta, timezone, tzinfo from unittest.mock import MagicMock, patch os.environ.setdefault( @@ -19,12 +20,125 @@ ) +_STUB_MODULE_NAMES = set() + + def _stub_module(name: str) -> types.ModuleType: mod = types.ModuleType(name) sys.modules[name] = mod + _STUB_MODULE_NAMES.add(name) return mod +def _remove_stub_module(name: str) -> None: + mod = sys.modules.pop(name, None) + if "." not in name or mod is None: + return + parent_name, attr_name = name.rsplit(".", 1) + parent = sys.modules.get(parent_name) + if getattr(parent, attr_name, None) is mod: + delattr(parent, attr_name) + + +def _remove_empty_stub_package(name: str) -> None: + mod = sys.modules.get(name) + if mod is None or getattr(mod, "__file__", None): + return + if getattr(mod, "__path__", None) == []: + _remove_stub_module(name) + + +def _clear_stale_package_tree(name: str) -> None: + mod = sys.modules.get(name) + if mod is not None and getattr(mod, "__file__", None): + return + if mod is None or getattr(mod, "__path__", None) == []: + prefix = f"{name}." + for module_name in list(sys.modules): + if module_name == name or module_name.startswith(prefix): + sys.modules.pop(module_name, None) + + +class _PytzFixedTimezone(tzinfo): + def __init__(self, offset: timedelta, name: str): + self._offset = offset + self._zone = timezone(offset, name) + + def utcoffset(self, dt): + return self._offset + + def dst(self, dt): + return timedelta(0) + + def tzname(self, dt): + return self._zone.tzname(dt) + + def fromutc(self, value): + return (value + self._offset).replace(tzinfo=self) + + def localize(self, value): + return value.replace(tzinfo=self) + + +class _PytzEasternTimezone(tzinfo): + _standard_offset = timedelta(hours=-5) + _daylight_offset = timedelta(hours=-4) + + @staticmethod + def _first_sunday_on_or_after(year: int, month: int, day: int) -> datetime: + value = datetime(year, month, day) + return value + timedelta(days=(6 - value.weekday()) % 7) + + @classmethod + def _dst_local_bounds(cls, year: int) -> tuple[datetime, datetime]: + start = cls._first_sunday_on_or_after(year, 3, 8).replace(hour=2) + end = cls._first_sunday_on_or_after(year, 11, 1).replace(hour=2) + return start, end + + @classmethod + def _is_dst_local(cls, value: datetime) -> bool: + start, end = cls._dst_local_bounds(value.year) + return start <= value.replace(tzinfo=None) < end + + @classmethod + def _is_dst_utc(cls, value: datetime) -> bool: + start_local, end_local = cls._dst_local_bounds(value.year) + start_utc = start_local - cls._standard_offset + end_utc = end_local - cls._daylight_offset + return start_utc <= value.replace(tzinfo=None) < end_utc + + def utcoffset(self, dt): + if dt is None: + return self._standard_offset + return self._daylight_offset if self._is_dst_local(dt) else self._standard_offset + + def dst(self, dt): + return self.utcoffset(dt) - self._standard_offset + + def tzname(self, dt): + return "EDT" if self.dst(dt) else "EST" + + def fromutc(self, value): + offset = self._daylight_offset if self._is_dst_utc(value) else self._standard_offset + return (value + offset).replace(tzinfo=self) + + def localize(self, value): + return value.replace(tzinfo=self) + + +_UTC_TZ = _PytzFixedTimezone(timedelta(0), "UTC") +_NY_TZ = _PytzEasternTimezone() +_PYTZ_ZONES = {"UTC": _UTC_TZ, "America/New_York": _NY_TZ} + +pytz_stub = types.ModuleType("pytz") +pytz_stub.utc = _UTC_TZ +pytz_stub.all_timezones = list(_PYTZ_ZONES) +pytz_stub.timezone = lambda name: _PYTZ_ZONES.get(name, _UTC_TZ) +if "pytz" not in sys.modules: + sys.modules["pytz"] = pytz_stub + _STUB_MODULE_NAMES.add("pytz") + + # Stub database package and submodules to avoid Firestore init. if "database" not in sys.modules: database_mod = _stub_module("database") @@ -70,6 +184,10 @@ def try_acquire_daily_summary_lock(uid: str, date: str, ttl: int = 60 * 60 * 2) client_mod.document_id_from_seed = MagicMock(return_value="doc-id") # Stub utils modules that pull in heavy dependencies. +_clear_stale_package_tree("utils") +for package_name in ["utils", "utils.other"]: + _remove_empty_stub_package(package_name) + for name in [ "utils.llm.external_integrations", "utils.notifications", @@ -121,8 +239,18 @@ def try_acquire_daily_summary_lock(uid: str, date: str, ttl: int = 60 * 60 * 2) sub_mod = _stub_module("utils.subscription") sub_mod.is_trial_paywalled = MagicMock(return_value=False) -# Now we can safely import -from utils.other.notifications import _send_summary_notification +# Now we can safely import the real module while keeping handles to its stubbed collaborators. +import utils.other.notifications as notifications_module + +_send_summary_notification = notifications_module._send_summary_notification +_CONVERSATIONS_DB = notifications_module.conversations_db +_DAILY_SUMMARIES_DB = notifications_module.daily_summaries_db +_GENERATE_COMPREHENSIVE_DAILY_SUMMARY = notifications_module.generate_comprehensive_daily_summary +_SEND_NOTIFICATION = notifications_module.send_notification + +for stub_name in sorted(_STUB_MODULE_NAMES, key=lambda item: item.count("."), reverse=True): + _remove_stub_module(stub_name) +_remove_stub_module("utils.other.notifications") class TestTryAcquireDailySummaryLock: @@ -209,12 +337,12 @@ def test_redis_error_propagates_no_silent_swallow(self): class TestSendSummaryNotificationLockIntegration: """Verify _send_summary_notification respects the lock.""" - @patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=False) + @patch.object(notifications_module, 'try_acquire_daily_summary_lock', return_value=False) def test_skips_when_lock_not_acquired(self, mock_lock): - convos_db = sys.modules["database.conversations"] + convos_db = _CONVERSATIONS_DB convos_db.get_conversations = MagicMock() - gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary - send_mock = sys.modules["utils.notifications"].send_notification + gen_mock = _GENERATE_COMPREHENSIVE_DAILY_SUMMARY + send_mock = _SEND_NOTIFICATION convos_db.get_conversations.reset_mock() gen_mock.reset_mock() @@ -228,19 +356,19 @@ def test_skips_when_lock_not_acquired(self, mock_lock): gen_mock.assert_not_called() send_mock.assert_not_called() - @patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True) + @patch.object(notifications_module, 'try_acquire_daily_summary_lock', return_value=True) def test_proceeds_when_lock_acquired(self, mock_lock): - convos_db = sys.modules["database.conversations"] + convos_db = _CONVERSATIONS_DB convos_db.get_conversations = MagicMock(return_value=[{'id': 'c1'}]) - gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary + gen_mock = _GENERATE_COMPREHENSIVE_DAILY_SUMMARY gen_mock.return_value = {'day_emoji': '!', 'headline': 'Test', 'overview': 'Summary'} - daily_db = sys.modules["database.daily_summaries"] + daily_db = _DAILY_SUMMARIES_DB daily_db.get_daily_summary_by_date = MagicMock(return_value=None) daily_db.create_daily_summary = MagicMock(return_value='summary-123') - send_mock = sys.modules["utils.notifications"].send_notification + send_mock = _SEND_NOTIFICATION send_mock.reset_mock() user_data = ('uid1', ['token1'], 'America/New_York') @@ -251,22 +379,22 @@ def test_proceeds_when_lock_acquired(self, mock_lock): gen_mock.assert_called_once() send_mock.assert_called_once() - @patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True) + @patch.object(notifications_module, 'try_acquire_daily_summary_lock', return_value=True) def test_skips_when_summary_already_exists(self, mock_lock): """#4608: if a summary already exists for the date (lock lost on a later tick), skip before spending LLM tokens or sending — do not create a duplicate doc.""" - convos_db = sys.modules["database.conversations"] + convos_db = _CONVERSATIONS_DB convos_db.get_conversations = MagicMock() convos_db.get_conversations.reset_mock() - gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary + gen_mock = _GENERATE_COMPREHENSIVE_DAILY_SUMMARY gen_mock.reset_mock() - daily_db = sys.modules["database.daily_summaries"] + daily_db = _DAILY_SUMMARIES_DB daily_db.get_daily_summary_by_date = MagicMock(return_value={'id': 'existing-1'}) daily_db.create_daily_summary = MagicMock() - send_mock = sys.modules["utils.notifications"].send_notification + send_mock = _SEND_NOTIFICATION send_mock.reset_mock() user_data = ('uid1', ['token1'], 'America/New_York') @@ -280,19 +408,19 @@ def test_skips_when_summary_already_exists(self, mock_lock): daily_db.create_daily_summary.assert_not_called() send_mock.assert_not_called() - @patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True) + @patch.object(notifications_module, 'try_acquire_daily_summary_lock', return_value=True) def test_summary_lookup_error_propagates_no_duplicate(self, mock_lock): """#4608: a transient Firestore error during the by-date lookup must propagate (skip this tick, retry next) rather than being swallowed into a duplicate-creating path.""" - daily_db = sys.modules["database.daily_summaries"] + daily_db = _DAILY_SUMMARIES_DB daily_db.get_daily_summary_by_date = MagicMock(side_effect=Exception("Firestore unavailable")) daily_db.create_daily_summary = MagicMock() - convos_db = sys.modules["database.conversations"] + convos_db = _CONVERSATIONS_DB convos_db.get_conversations = MagicMock() convos_db.get_conversations.reset_mock() - gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary + gen_mock = _GENERATE_COMPREHENSIVE_DAILY_SUMMARY gen_mock.reset_mock() user_data = ('uid1', ['token1'], 'America/New_York') @@ -307,17 +435,17 @@ def test_summary_lookup_error_propagates_no_duplicate(self, mock_lock): gen_mock.assert_not_called() convos_db.get_conversations.assert_not_called() - @patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True) + @patch.object(notifications_module, 'try_acquire_daily_summary_lock', return_value=True) def test_no_conversations_skips_llm(self, mock_lock): - convos_db = sys.modules["database.conversations"] + convos_db = _CONVERSATIONS_DB convos_db.get_conversations = MagicMock(return_value=[]) - daily_db = sys.modules["database.daily_summaries"] + daily_db = _DAILY_SUMMARIES_DB daily_db.get_daily_summary_by_date = MagicMock(return_value=None) - gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary + gen_mock = _GENERATE_COMPREHENSIVE_DAILY_SUMMARY gen_mock.reset_mock() - send_mock = sys.modules["utils.notifications"].send_notification + send_mock = _SEND_NOTIFICATION send_mock.reset_mock() user_data = ('uid1', ['token1'], 'America/New_York') @@ -328,14 +456,14 @@ def test_no_conversations_skips_llm(self, mock_lock): gen_mock.assert_not_called() send_mock.assert_not_called() - @patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=False) + @patch.object(notifications_module, 'try_acquire_daily_summary_lock', return_value=False) def test_utc_fallback_still_acquires_lock(self, mock_lock): """User data without timezone falls back to UTC; lock must still be called.""" - convos_db = sys.modules["database.conversations"] + convos_db = _CONVERSATIONS_DB convos_db.get_conversations = MagicMock() convos_db.get_conversations.reset_mock() - gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary + gen_mock = _GENERATE_COMPREHENSIVE_DAILY_SUMMARY gen_mock.reset_mock() # No timezone element in tuple — triggers UTC fallback