diff --git a/app/control/account/quota_defaults.py b/app/control/account/quota_defaults.py index eb3250c8..60488ca7 100644 --- a/app/control/account/quota_defaults.py +++ b/app/control/account/quota_defaults.py @@ -159,16 +159,17 @@ def normalize_quota_set(pool: str, quota_set: AccountQuotaSet) -> AccountQuotaSe return qs -def infer_pool(windows: dict[int, QuotaWindow]) -> str: +def infer_pool(windows: dict[int, QuotaWindow], *, fallback: str = "basic") -> str: """Infer pool type from live quota windows returned by the rate-limits API. Uses ``auto.total`` (mode_id=0) as the discriminating signal. - Falls back to ``"basic"`` when the value is absent or unrecognised. + Falls back to the current pool when the value is absent or unrecognised so + partial refreshes do not silently downgrade paid accounts. """ auto_win = windows.get(0) if auto_win is None: - return "basic" - return _AUTO_TOTAL_TO_POOL.get(auto_win.total, "basic") + return fallback + return _AUTO_TOTAL_TO_POOL.get(auto_win.total, fallback) __all__ = [ diff --git a/app/control/account/refresh.py b/app/control/account/refresh.py index f95a1baa..eb706462 100644 --- a/app/control/account/refresh.py +++ b/app/control/account/refresh.py @@ -53,6 +53,8 @@ def merge(self, other: "RefreshResult") -> None: 4: "quota_grok_4_3", } +_ALL_MODE_IDS = tuple(int(mode) for mode in ALL_MODES_FULL) + class AccountRefreshService: """Fetches real quota data from the upstream usage API and persists it. @@ -74,7 +76,10 @@ def __init__(self, repository: "AccountRepository") -> None: # ------------------------------------------------------------------ async def _fetch_all_quotas( - self, token: str, pool: str + self, + token: str, + pool: str, + mode_ids: tuple[int, ...] | None = None, ) -> dict[int, QuotaWindow] | None: """Fetch quota windows for every mode supported by *pool*. @@ -86,7 +91,8 @@ async def _fetch_all_quotas( try: from app.dataplane.reverse.protocol.xai_usage import fetch_all_quotas - return await fetch_all_quotas(token, supported_mode_ids(pool)) + requested = mode_ids if mode_ids is not None else supported_mode_ids(pool) + return await fetch_all_quotas(token, requested) except UpstreamError: raise except Exception as exc: @@ -140,7 +146,9 @@ async def refresh_on_import(self, tokens: list[str]) -> RefreshResult: concurrency = get_config("account.refresh.usage_concurrency", 50) results = await run_batch( active, - lambda r: self._refresh_one(r, apply_fallback=True), + lambda r: self._refresh_one( + r, apply_fallback=True, probe_all_modes=True + ), concurrency=concurrency, ) agg = RefreshResult(checked=len(records)) @@ -210,7 +218,11 @@ async def refresh_tokens(self, tokens: list[str]) -> RefreshResult: """Explicit refresh for a list of tokens (admin / manual trigger).""" records = [r for r in await self._repo.get_accounts(tokens) if is_manageable(r)] concurrency = get_config("account.refresh.usage_concurrency", 50) - results = await run_batch(records, self._refresh_one, concurrency=concurrency) + results = await run_batch( + records, + lambda r: self._refresh_one(r, probe_all_modes=True), + concurrency=concurrency, + ) agg = RefreshResult() for r in results: agg.merge(r) @@ -225,6 +237,7 @@ async def _refresh_one( record: AccountRecord, *, apply_fallback: bool = False, + probe_all_modes: bool = False, ) -> RefreshResult: """Fetch all pool-supported modes from the usage API and persist them. @@ -232,12 +245,18 @@ async def _refresh_one( decrement REAL quotas or reset expired DEFAULT windows. apply_fallback=False — used by manual/on-demand paths: if API fails, return failed=1 immediately without touching stored data. + probe_all_modes=True — used by import/manual refresh paths to recover paid + accounts that were previously misclassified as basic. """ if record.is_deleted(): return RefreshResult() try: - windows = await self._fetch_all_quotas(record.token, record.pool) + windows = await self._fetch_all_quotas( + record.token, + record.pool, + _ALL_MODE_IDS if probe_all_modes else None, + ) except UpstreamError as exc: if await self._expire_invalid_credentials(record, exc): return RefreshResult(checked=1, expired=1, failed=0) @@ -250,7 +269,14 @@ async def _refresh_one( # Scheduled/import path: apply conservative fallback. return await self._apply_fallback(record) - # We got at least a response — apply real data per mode. + # We got at least a response — infer the effective pool before + # normalising quotas. Auto-detect/manual refresh can probe all modes + # even when the stored pool is stale, so using record.pool here would + # drop paid-only windows and keep the account stuck in the basic pool. + effective_pool = infer_pool( + windows, fallback=record.pool + ) # type: ignore[arg-type] + qs = record.quota_set() now = now_ms() patches: dict[str, dict] = {} @@ -259,7 +285,9 @@ async def _refresh_one( for mode in ALL_MODES_FULL: mode_id = int(mode) if mode_id in windows: - window = normalize_quota_window(record.pool, mode_id, windows[mode_id]) + window = normalize_quota_window( + effective_pool, mode_id, windows[mode_id] + ) if window is None: continue patches[_MODE_KEYS[mode_id]] = window.to_dict() @@ -278,7 +306,7 @@ async def _refresh_one( source=QuotaSource.ESTIMATED, ).to_dict() elif existing.is_window_expired(now): - default = default_quota_window(record.pool, mode_id) + default = default_quota_window(effective_pool, mode_id) if default is None: continue patches[_MODE_KEYS[mode_id]] = QuotaWindow( @@ -293,15 +321,14 @@ async def _refresh_one( if not patches: return RefreshResult(checked=1, failed=0 if refreshed else 1) - # Infer pool type from live quota data and patch if it changed. - inferred = infer_pool(windows) # type: ignore[arg-type] - pool_patch = inferred if inferred != record.pool else None + # Patch pool type only when live data gives a confident replacement. + pool_patch = effective_pool if effective_pool != record.pool else None if pool_patch: logger.info( "account pool updated from live quota: token={}... previous_pool={} current_pool={}", record.token[:10], record.pool, - inferred, + effective_pool, ) from .commands import AccountPatch