Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 49 additions & 3 deletions backend/internal/service/antigravity_quota_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package service

import (
"context"
"fmt"
"log"
"strings"
"time"

"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
Expand All @@ -22,13 +25,24 @@ func (f *AntigravityQuotaFetcher) CanFetch(account *Account) bool {
if account.Platform != PlatformAntigravity {
return false
}
accessToken := account.GetCredential("access_token")
return accessToken != ""
if account.Type != AccountTypeOAuth {
return false
}
if account.GetCredential("access_token") != "" {
return true
}
return account.GetCredential("refresh_token") != ""
}

// FetchQuota 获取 Antigravity 账户额度信息
func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error) {
accessToken := account.GetCredential("access_token")
accessToken, err := f.resolveAccessToken(ctx, account, proxyURL)
if err != nil {
return nil, err
}
if accessToken == "" {
return nil, fmt.Errorf("no antigravity access_token available")
}
projectID := account.GetCredential("project_id")

client := antigravity.NewClient(proxyURL)
Expand All @@ -48,6 +62,38 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
}, nil
}

func (f *AntigravityQuotaFetcher) resolveAccessToken(ctx context.Context, account *Account, proxyURL string) (string, error) {
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
expiresAt := account.GetCredentialAsTime("expires_at")
if accessToken != "" && (expiresAt == nil || time.Until(*expiresAt) > 3*time.Minute) {
return accessToken, nil
}

refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
if refreshToken == "" {
if accessToken != "" {
return accessToken, nil
}
return "", fmt.Errorf("no antigravity refresh_token available")
}

client := antigravity.NewClient(proxyURL)
refreshResp, err := client.RefreshToken(ctx, refreshToken)
if err != nil {
if accessToken != "" {
log.Printf("[antigravity] token refresh failed during quota fetch, fallback to existing access_token: %v", err)
return accessToken, nil
}
return "", fmt.Errorf("refresh access_token failed: %w", err)
}

if strings.TrimSpace(refreshResp.AccessToken) == "" {
return "", fmt.Errorf("refreshed access_token is empty")
}

return refreshResp.AccessToken, nil
}

// buildUsageInfo 将 API 响应转换为 UsageInfo
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
now := time.Now()
Expand Down
74 changes: 62 additions & 12 deletions backend/internal/service/token_refresh_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)

// TokenRefreshService OAuth token自动刷新服务
Expand Down Expand Up @@ -150,6 +151,16 @@ func (s *TokenRefreshService) processRefresh() {

// 检查是否需要刷新
if !refresher.NeedsRefresh(account, refreshWindow) {
if account.Platform == PlatformAntigravity && account.Status == StatusError {
errMsg := strings.ToLower(account.ErrorMessage)
if strings.Contains(errMsg, "token refresh failed") || strings.Contains(errMsg, "invalid_client") {
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
slog.Warn("token_refresh.clear_stale_error_failed", "account_id", account.ID, "error", clearErr)
} else {
slog.Info("token_refresh.cleared_stale_antigravity_error", "account_id", account.ID)
}
}
}
break // 不需要刷新,跳过
}

Expand Down Expand Up @@ -195,7 +206,40 @@ func (s *TokenRefreshService) processRefresh() {
// listActiveAccounts 获取所有active状态的账号
// 使用ListActive确保刷新所有活跃账号的token(包括临时禁用的)
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account, error) {
return s.accountRepo.ListActive(ctx)
active, err := s.accountRepo.ListActive(ctx)
if err != nil {
return nil, err
}

// Antigravity 历史 refresh 错误账号也需要纳入刷新:
// client_secret 修复后,这些账号应可自动恢复为 active。
merged := make([]Account, 0, len(active)+8)
seen := make(map[int64]struct{}, len(active)+8)
for _, acc := range active {
merged = append(merged, acc)
seen[acc.ID] = struct{}{}
}

params := pagination.PaginationParams{Page: 1, PageSize: 100}
for {
errAccounts, pageInfo, listErr := s.accountRepo.ListWithFilters(ctx, params, PlatformAntigravity, "", StatusError, "", 0)
if listErr != nil {
return nil, listErr
}
for _, acc := range errAccounts {
if _, ok := seen[acc.ID]; ok {
continue
}
merged = append(merged, acc)
seen[acc.ID] = struct{}{}
}
if pageInfo == nil || params.Page >= pageInfo.Pages || len(errAccounts) == 0 {
break
}
params.Page++
}

return merged, nil
}

// refreshWithRetry 带重试的刷新
Expand All @@ -218,17 +262,23 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
}

if err == nil {
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
if account.Platform == PlatformAntigravity &&
account.Status == StatusError &&
strings.Contains(account.ErrorMessage, "missing_project_id:") {
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
slog.Warn("token_refresh.clear_account_error_failed",
"account_id", account.ID,
"error", clearErr,
)
} else {
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
// Antigravity 账户:刷新成功后清理历史 refresh 错误状态。
// 常见场景:client_secret 修复后,旧的 invalid_client 错误会残留在账号状态中。
if account.Platform == PlatformAntigravity && account.Status == StatusError {
errMsg := strings.ToLower(account.ErrorMessage)
if strings.Contains(errMsg, "missing_project_id:") ||
strings.Contains(errMsg, "token refresh failed") ||
strings.Contains(errMsg, "invalid_client") {
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
slog.Warn("token_refresh.clear_account_error_failed",
"account_id", account.ID,
"error", clearErr,
)
} else {
slog.Info("token_refresh.cleared_antigravity_error", "account_id", account.ID)
account.Status = StatusActive
account.ErrorMessage = ""
}
}
}
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
Expand Down
Loading