From 1652ff9ee770a2f44fcf378e9c1c73996f4878c7 Mon Sep 17 00:00:00 2001 From: hubo36 Date: Wed, 20 May 2026 16:27:34 +0800 Subject: [PATCH] fix(accounts): recover paid pool on manual quota refresh [skip changelog] Manual and import refreshes now probe all quota modes before inferring the account pool, so SuperGrok accounts misclassified as basic can recover. Partial quota responses keep the current pool instead of silently downgrading paid accounts.\n\nCloses #536 --- app/control/account/quota_defaults.py | 9 ++--- app/control/account/refresh.py | 51 ++++++++++++++++++++------- 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/app/control/account/quota_defaults.py b/app/control/account/quota_defaults.py index eb3250c8..60488ca7 100644 --- a/app/control/account/quota_defaults.py +++ b/app/control/account/quota_defaults.py @@ -159,16 +159,17 @@ def normalize_quota_set(pool: str, quota_set: AccountQuotaSet) -> AccountQuotaSe return qs -def infer_pool(windows: dict[int, QuotaWindow]) -> str: +def infer_pool(windows: dict[int, QuotaWindow], *, fallback: str = "basic") -> str: """Infer pool type from live quota windows returned by the rate-limits API. Uses ``auto.total`` (mode_id=0) as the discriminating signal. - Falls back to ``"basic"`` when the value is absent or unrecognised. + Falls back to the current pool when the value is absent or unrecognised so + partial refreshes do not silently downgrade paid accounts. """ auto_win = windows.get(0) if auto_win is None: - return "basic" - return _AUTO_TOTAL_TO_POOL.get(auto_win.total, "basic") + return fallback + return _AUTO_TOTAL_TO_POOL.get(auto_win.total, fallback) __all__ = [ diff --git a/app/control/account/refresh.py b/app/control/account/refresh.py index f95a1baa..eb706462 100644 --- a/app/control/account/refresh.py +++ b/app/control/account/refresh.py @@ -53,6 +53,8 @@ def merge(self, other: "RefreshResult") -> None: 4: "quota_grok_4_3", } +_ALL_MODE_IDS = tuple(int(mode) for mode in ALL_MODES_FULL) + class AccountRefreshService: """Fetches real quota data from the upstream usage API and persists it. @@ -74,7 +76,10 @@ def __init__(self, repository: "AccountRepository") -> None: # ------------------------------------------------------------------ async def _fetch_all_quotas( - self, token: str, pool: str + self, + token: str, + pool: str, + mode_ids: tuple[int, ...] | None = None, ) -> dict[int, QuotaWindow] | None: """Fetch quota windows for every mode supported by *pool*. @@ -86,7 +91,8 @@ async def _fetch_all_quotas( try: from app.dataplane.reverse.protocol.xai_usage import fetch_all_quotas - return await fetch_all_quotas(token, supported_mode_ids(pool)) + requested = mode_ids if mode_ids is not None else supported_mode_ids(pool) + return await fetch_all_quotas(token, requested) except UpstreamError: raise except Exception as exc: @@ -140,7 +146,9 @@ async def refresh_on_import(self, tokens: list[str]) -> RefreshResult: concurrency = get_config("account.refresh.usage_concurrency", 50) results = await run_batch( active, - lambda r: self._refresh_one(r, apply_fallback=True), + lambda r: self._refresh_one( + r, apply_fallback=True, probe_all_modes=True + ), concurrency=concurrency, ) agg = RefreshResult(checked=len(records)) @@ -210,7 +218,11 @@ async def refresh_tokens(self, tokens: list[str]) -> RefreshResult: """Explicit refresh for a list of tokens (admin / manual trigger).""" records = [r for r in await self._repo.get_accounts(tokens) if is_manageable(r)] concurrency = get_config("account.refresh.usage_concurrency", 50) - results = await run_batch(records, self._refresh_one, concurrency=concurrency) + results = await run_batch( + records, + lambda r: self._refresh_one(r, probe_all_modes=True), + concurrency=concurrency, + ) agg = RefreshResult() for r in results: agg.merge(r) @@ -225,6 +237,7 @@ async def _refresh_one( record: AccountRecord, *, apply_fallback: bool = False, + probe_all_modes: bool = False, ) -> RefreshResult: """Fetch all pool-supported modes from the usage API and persist them. @@ -232,12 +245,18 @@ async def _refresh_one( decrement REAL quotas or reset expired DEFAULT windows. apply_fallback=False — used by manual/on-demand paths: if API fails, return failed=1 immediately without touching stored data. + probe_all_modes=True — used by import/manual refresh paths to recover paid + accounts that were previously misclassified as basic. """ if record.is_deleted(): return RefreshResult() try: - windows = await self._fetch_all_quotas(record.token, record.pool) + windows = await self._fetch_all_quotas( + record.token, + record.pool, + _ALL_MODE_IDS if probe_all_modes else None, + ) except UpstreamError as exc: if await self._expire_invalid_credentials(record, exc): return RefreshResult(checked=1, expired=1, failed=0) @@ -250,7 +269,14 @@ async def _refresh_one( # Scheduled/import path: apply conservative fallback. return await self._apply_fallback(record) - # We got at least a response — apply real data per mode. + # We got at least a response — infer the effective pool before + # normalising quotas. Auto-detect/manual refresh can probe all modes + # even when the stored pool is stale, so using record.pool here would + # drop paid-only windows and keep the account stuck in the basic pool. + effective_pool = infer_pool( + windows, fallback=record.pool + ) # type: ignore[arg-type] + qs = record.quota_set() now = now_ms() patches: dict[str, dict] = {} @@ -259,7 +285,9 @@ async def _refresh_one( for mode in ALL_MODES_FULL: mode_id = int(mode) if mode_id in windows: - window = normalize_quota_window(record.pool, mode_id, windows[mode_id]) + window = normalize_quota_window( + effective_pool, mode_id, windows[mode_id] + ) if window is None: continue patches[_MODE_KEYS[mode_id]] = window.to_dict() @@ -278,7 +306,7 @@ async def _refresh_one( source=QuotaSource.ESTIMATED, ).to_dict() elif existing.is_window_expired(now): - default = default_quota_window(record.pool, mode_id) + default = default_quota_window(effective_pool, mode_id) if default is None: continue patches[_MODE_KEYS[mode_id]] = QuotaWindow( @@ -293,15 +321,14 @@ async def _refresh_one( if not patches: return RefreshResult(checked=1, failed=0 if refreshed else 1) - # Infer pool type from live quota data and patch if it changed. - inferred = infer_pool(windows) # type: ignore[arg-type] - pool_patch = inferred if inferred != record.pool else None + # Patch pool type only when live data gives a confident replacement. + pool_patch = effective_pool if effective_pool != record.pool else None if pool_patch: logger.info( "account pool updated from live quota: token={}... previous_pool={} current_pool={}", record.token[:10], record.pool, - inferred, + effective_pool, ) from .commands import AccountPatch