diff --git a/backend/tests/unit/test_paywall_reconnect_gate.py b/backend/tests/unit/test_paywall_reconnect_gate.py index 3cc8ae2a34..2ad7886d72 100644 --- a/backend/tests/unit/test_paywall_reconnect_gate.py +++ b/backend/tests/unit/test_paywall_reconnect_gate.py @@ -19,7 +19,7 @@ def _read_source(path): - with open(path) as f: + with open(path, encoding='utf-8') as f: return f.read() @@ -231,19 +231,25 @@ def test_is_trial_paywalled_checks_desktop_tokens(self): assert 'macos' in src, "desktop tokens must include 'macos'" assert 'desktop' in src, "desktop tokens must include 'desktop'" - def test_is_trial_paywalled_respects_kill_switch(self): + def test_is_trial_paywalled_filters_before_expiry_lookup(self): src = _read_source(SUBSCRIPTION_SRC_PATH) fn_start = src.find('def is_trial_paywalled(') assert fn_start != -1 fn_body = src[fn_start : src.find('\ndef ', fn_start + 1)] - assert '_TRIAL_PAYWALL_ENABLED' in fn_body, "is_trial_paywalled must respect kill switch" + filter_pos = fn_body.find('platform.lower() not in _TRIAL_PAYWALL_DESKTOP_TOKENS') + expiry_pos = fn_body.find('_is_trial_expired_cached(uid)') + assert filter_pos != -1, "is_trial_paywalled must filter non-desktop platforms" + assert expiry_pos != -1, "is_trial_paywalled must call the cached expiry lookup" + assert filter_pos < expiry_pos, "platform filtering must happen before the expiry lookup" - def test_is_trial_paywalled_respects_test_uid_gating(self): + def test_is_trial_paywalled_delegates_to_cached_expiry(self): src = _read_source(SUBSCRIPTION_SRC_PATH) fn_start = src.find('def is_trial_paywalled(') assert fn_start != -1 fn_body = src[fn_start : src.find('\ndef ', fn_start + 1)] - assert '_TRIAL_PAYWALL_TEST_UIDS' in fn_body, "is_trial_paywalled must respect test UID gating" + assert ( + 'return _is_trial_expired_cached(uid)' in fn_body + ), "desktop paywall decisions must use the cached expiry lookup" def test_is_trial_paywalled_uses_lower_for_case_insensitivity(self): src = _read_source(SUBSCRIPTION_SRC_PATH) @@ -323,8 +329,6 @@ def _stub(name): self._mock_expired = MagicMock(return_value=True) self._orig_expired = sub._is_trial_expired_cached sub._is_trial_expired_cached = self._mock_expired - sub._TRIAL_PAYWALL_ENABLED = True - sub._TRIAL_PAYWALL_TEST_UIDS = set() yield @@ -361,20 +365,17 @@ def test_mixed_case_desktop(self): assert self._sub.is_trial_paywalled('uid1', 'Desktop') is True assert self._sub.is_trial_paywalled('uid1', 'MACOS') is True - def test_kill_switch_disabled(self): - self._sub._TRIAL_PAYWALL_ENABLED = False + def test_desktop_cache_false_returns_false(self): + self._mock_expired.return_value = False assert self._sub.is_trial_paywalled('uid1', 'desktop') is False - self._sub._TRIAL_PAYWALL_ENABLED = True - def test_test_uid_gating_allows_listed(self): - self._sub._TRIAL_PAYWALL_TEST_UIDS = {'uid1', 'uid2'} + def test_desktop_uid_delegates_to_expiry_cache(self): assert self._sub.is_trial_paywalled('uid1', 'desktop') is True - self._sub._TRIAL_PAYWALL_TEST_UIDS = set() + self._mock_expired.assert_called_with('uid1') - def test_test_uid_gating_blocks_unlisted(self): - self._sub._TRIAL_PAYWALL_TEST_UIDS = {'uid1', 'uid2'} - assert self._sub.is_trial_paywalled('uid99', 'desktop') is False - self._sub._TRIAL_PAYWALL_TEST_UIDS = set() + def test_different_desktop_uid_uses_same_expiry_path(self): + assert self._sub.is_trial_paywalled('uid99', 'desktop') is True + self._mock_expired.assert_called_with('uid99') def test_not_expired_returns_false(self): self._mock_expired.return_value = False @@ -446,8 +447,6 @@ def _stub(name): self._sub = sub self._byok = byok - sub._TRIAL_PAYWALL_ENABLED = True - sub._TRIAL_PAYWALL_TEST_UIDS = set() yield