Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions app/control/account/quota_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,17 @@ def normalize_quota_set(pool: str, quota_set: AccountQuotaSet) -> AccountQuotaSe
return qs


def infer_pool(windows: dict[int, QuotaWindow], *, fallback: str = "basic") -> str:
    """Infer pool type from live quota windows returned by the rate-limits API.

    Uses ``auto.total`` (mode_id=0) as the discriminating signal.

    Args:
        windows: Live quota windows keyed by integer mode id.
        fallback: Pool name returned when the ``auto`` window is absent or its
            total is not a recognised pool marker. Defaults to ``"basic"``;
            pass the account's current pool so partial refreshes do not
            silently downgrade paid accounts.

    Returns:
        The pool that ``_AUTO_TOTAL_TO_POOL`` maps the ``auto`` window's total
        to, otherwise *fallback*.
    """
    auto_win = windows.get(0)
    if auto_win is None:
        # No auto window in this response: nothing to base a reclassification on.
        return fallback
    # Unrecognised totals also resolve to the caller-supplied fallback.
    return _AUTO_TOTAL_TO_POOL.get(auto_win.total, fallback)


__all__ = [
Expand Down
51 changes: 39 additions & 12 deletions app/control/account/refresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def merge(self, other: "RefreshResult") -> None:
4: "quota_grok_4_3",
}

_ALL_MODE_IDS = tuple(int(mode) for mode in ALL_MODES_FULL)


class AccountRefreshService:
"""Fetches real quota data from the upstream usage API and persists it.
Expand All @@ -74,7 +76,10 @@ def __init__(self, repository: "AccountRepository") -> None:
# ------------------------------------------------------------------

async def _fetch_all_quotas(
self, token: str, pool: str
self,
token: str,
pool: str,
mode_ids: tuple[int, ...] | None = None,
) -> dict[int, QuotaWindow] | None:
"""Fetch quota windows for every mode supported by *pool*.

Expand All @@ -86,7 +91,8 @@ async def _fetch_all_quotas(
try:
from app.dataplane.reverse.protocol.xai_usage import fetch_all_quotas

return await fetch_all_quotas(token, supported_mode_ids(pool))
requested = mode_ids if mode_ids is not None else supported_mode_ids(pool)
return await fetch_all_quotas(token, requested)
except UpstreamError:
raise
except Exception as exc:
Expand Down Expand Up @@ -140,7 +146,9 @@ async def refresh_on_import(self, tokens: list[str]) -> RefreshResult:
concurrency = get_config("account.refresh.usage_concurrency", 50)
results = await run_batch(
active,
lambda r: self._refresh_one(r, apply_fallback=True),
lambda r: self._refresh_one(
r, apply_fallback=True, probe_all_modes=True
),
concurrency=concurrency,
)
agg = RefreshResult(checked=len(records))
Expand Down Expand Up @@ -210,7 +218,11 @@ async def refresh_tokens(self, tokens: list[str]) -> RefreshResult:
"""Explicit refresh for a list of tokens (admin / manual trigger)."""
records = [r for r in await self._repo.get_accounts(tokens) if is_manageable(r)]
concurrency = get_config("account.refresh.usage_concurrency", 50)
results = await run_batch(records, self._refresh_one, concurrency=concurrency)
results = await run_batch(
records,
lambda r: self._refresh_one(r, probe_all_modes=True),
concurrency=concurrency,
)
agg = RefreshResult()
for r in results:
agg.merge(r)
Expand All @@ -225,19 +237,26 @@ async def _refresh_one(
record: AccountRecord,
*,
apply_fallback: bool = False,
probe_all_modes: bool = False,
) -> RefreshResult:
"""Fetch all pool-supported modes from the usage API and persist them.

apply_fallback=True — used by scheduled/import paths: when API fails,
decrement REAL quotas or reset expired DEFAULT windows.
apply_fallback=False — used by manual/on-demand paths: if API fails, return
failed=1 immediately without touching stored data.
probe_all_modes=True — used by import/manual refresh paths to recover paid
accounts that were previously misclassified as basic.
"""
if record.is_deleted():
return RefreshResult()

try:
windows = await self._fetch_all_quotas(record.token, record.pool)
windows = await self._fetch_all_quotas(
record.token,
record.pool,
_ALL_MODE_IDS if probe_all_modes else None,
)
except UpstreamError as exc:
if await self._expire_invalid_credentials(record, exc):
return RefreshResult(checked=1, expired=1, failed=0)
Expand All @@ -250,7 +269,14 @@ async def _refresh_one(
# Scheduled/import path: apply conservative fallback.
return await self._apply_fallback(record)

# We got at least a response — apply real data per mode.
# We got at least a response — infer the effective pool before
# normalising quotas. Auto-detect/manual refresh can probe all modes
# even when the stored pool is stale, so using record.pool here would
# drop paid-only windows and keep the account stuck in the basic pool.
effective_pool = infer_pool(
windows, fallback=record.pool
) # type: ignore[arg-type]

qs = record.quota_set()
now = now_ms()
patches: dict[str, dict] = {}
Expand All @@ -259,7 +285,9 @@ async def _refresh_one(
for mode in ALL_MODES_FULL:
mode_id = int(mode)
if mode_id in windows:
window = normalize_quota_window(record.pool, mode_id, windows[mode_id])
window = normalize_quota_window(
effective_pool, mode_id, windows[mode_id]
)
if window is None:
continue
patches[_MODE_KEYS[mode_id]] = window.to_dict()
Expand All @@ -278,7 +306,7 @@ async def _refresh_one(
source=QuotaSource.ESTIMATED,
).to_dict()
elif existing.is_window_expired(now):
default = default_quota_window(record.pool, mode_id)
default = default_quota_window(effective_pool, mode_id)
if default is None:
continue
patches[_MODE_KEYS[mode_id]] = QuotaWindow(
Expand All @@ -293,15 +321,14 @@ async def _refresh_one(
if not patches:
return RefreshResult(checked=1, failed=0 if refreshed else 1)

# Infer pool type from live quota data and patch if it changed.
inferred = infer_pool(windows) # type: ignore[arg-type]
pool_patch = inferred if inferred != record.pool else None
# Patch pool type only when live data gives a confident replacement.
pool_patch = effective_pool if effective_pool != record.pool else None
if pool_patch:
logger.info(
"account pool updated from live quota: token={}... previous_pool={} current_pool={}",
record.token[:10],
record.pool,
inferred,
effective_pool,
)

from .commands import AccountPatch
Expand Down