From eba2046320ec643e7976296aabbc283f05aeaa52 Mon Sep 17 00:00:00 2001 From: xiaoyiluck666 <83876597+xiaoyiluck666@users.noreply.github.com> Date: Fri, 29 May 2026 13:34:10 +0800 Subject: [PATCH] fix: enrich OpenAI OAuth token refresh --- backend/cmd/server/wire_gen.go | 2 +- .../internal/service/openai_oauth_service.go | 54 +++++++++++------ .../openai_oauth_service_refresh_test.go | 7 +++ .../service/openai_privacy_service.go | 58 +++++++++++++++++++ .../service/openai_subscription_test.go | 42 ++++++++++++++ backend/internal/service/wire.go | 13 ++++- 6 files changed, 157 insertions(+), 19 deletions(-) create mode 100644 backend/internal/service/openai_subscription_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 465f5e25a7..441bcd6792 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -137,7 +137,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { httpUpstream := repository.NewHTTPUpstream(configConfig) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) openAIOAuthClient := repository.NewOpenAIOAuthClient() - openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) + openAIOAuthService := service.ProvideOpenAIOAuthService(proxyRepository, openAIOAuthClient, privacyClientFactory) oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) channelRepository := repository.NewChannelRepository(db) diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index dc094d43ce..0ee357a97c 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -278,11 +278,29 @@ func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *Ope tokenInfo.Email = info.Email } } + if strings.TrimSpace(tokenInfo.SubscriptionExpiresAt) == "" { + if expiresAt := fetchChatGPTSubscriptionExpiresAt(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, resolveChatGPTSubscriptionAccountID(tokenInfo, orgID)); expiresAt != "" { + tokenInfo.SubscriptionExpiresAt = expiresAt + } + } // 尝试设置隐私(关闭训练数据共享),best-effort tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL) } +func resolveChatGPTSubscriptionAccountID(tokenInfo *OpenAITokenInfo, orgID string) string { + for _, candidate := range []string{ + tokenInfo.ChatGPTAccountID, + tokenInfo.OrganizationID, + orgID, + } { + if trimmed := strings.TrimSpace(candidate); trimmed != "" { + return trimmed + } + } + return "" +} + // RefreshAccountToken refreshes token for an OpenAI OAuth account func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { if account.Platform != PlatformOpenAI { @@ -292,38 +310,40 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account") } + var proxyURL string + if account.ProxyID != nil { + proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID) + if err == nil && proxy != nil { + proxyURL = proxy.URL() + } + } + refreshToken := account.GetCredential("refresh_token") if refreshToken == "" { accessToken := account.GetCredential("access_token") if accessToken != "" { tokenInfo := &OpenAITokenInfo{ - AccessToken: accessToken, - RefreshToken: "", - IDToken: account.GetCredential("id_token"), - ClientID: account.GetCredential("client_id"), - Email: account.GetCredential("email"), - ChatGPTAccountID: account.GetCredential("chatgpt_account_id"), - ChatGPTUserID: account.GetCredential("chatgpt_user_id"), - OrganizationID: account.GetCredential("organization_id"), - PlanType: account.GetCredential("plan_type"), + AccessToken: accessToken, + RefreshToken: "", + IDToken: account.GetCredential("id_token"), + ClientID: account.GetCredential("client_id"), + Email: account.GetCredential("email"), + ChatGPTAccountID: account.GetCredential("chatgpt_account_id"), + ChatGPTUserID: account.GetCredential("chatgpt_user_id"), + OrganizationID: account.GetCredential("organization_id"), + PlanType: account.GetCredential("plan_type"), + SubscriptionExpiresAt: account.GetCredential("subscription_expires_at"), } if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil { tokenInfo.ExpiresAt = expiresAt.Unix() tokenInfo.ExpiresIn = int64(time.Until(*expiresAt).Seconds()) } + s.enrichTokenInfo(ctx, tokenInfo, proxyURL) return tokenInfo, nil } return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") } - var proxyURL string - if account.ProxyID != nil { - proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID) - if err == nil && proxy != nil { - proxyURL = proxy.URL() - } - } - clientID := account.GetCredential("client_id") return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) } diff --git a/backend/internal/service/openai_oauth_service_refresh_test.go b/backend/internal/service/openai_oauth_service_refresh_test.go index 84b68ea643..75588c8db4 100644 --- a/backend/internal/service/openai_oauth_service_refresh_test.go +++ b/backend/internal/service/openai_oauth_service_refresh_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/imroc/req/v3" "github.com/stretchr/testify/require" ) @@ -32,6 +33,11 @@ func (s *openaiOAuthClientRefreshStub) RefreshTokenWithClientID(ctx context.Cont func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccessToken(t *testing.T) { client := &openaiOAuthClientRefreshStub{} svc := NewOpenAIOAuthService(nil, client) + var privacyClientCalls int32 + svc.SetPrivacyClientFactory(func(proxyURL string) (*req.Client, error) { + atomic.AddInt32(&privacyClientCalls, 1) + return nil, errors.New("stop before request") + }) expiresAt := time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339) account := &Account{ @@ -51,6 +57,7 @@ func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccess require.Equal(t, "existing-access-token", info.AccessToken) require.Equal(t, "client-id-1", info.ClientID) require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh") + require.Positive(t, atomic.LoadInt32(&privacyClientCalls), "existing access token should still run enrichment") } func TestOpenAITokenRefresher_NeedsRefresh_SkipsAccountWithoutRefreshToken(t *testing.T) { diff --git a/backend/internal/service/openai_privacy_service.go b/backend/internal/service/openai_privacy_service.go index da6dbefc93..99cbb7267e 100644 --- a/backend/internal/service/openai_privacy_service.go +++ b/backend/internal/service/openai_privacy_service.go @@ -95,6 +95,8 @@ type ChatGPTAccountInfo struct { const chatGPTAccountsCheckURL = "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27" +var chatGPTSubscriptionsURL = "https://chatgpt.com/backend-api/subscriptions" + // fetchChatGPTAccountInfo calls ChatGPT backend-api to get account info (plan_type, etc.). // Used as fallback when id_token doesn't contain these fields (e.g., Mobile RT). // orgID is used to match the correct account when multiple accounts exist (e.g., personal + team). @@ -199,6 +201,62 @@ func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFac return info } +// fetchChatGPTSubscriptionExpiresAt reads the lightweight subscription endpoint used by +// ChatGPT/Codex clients. Some Plus accounts no longer expose entitlement.expires_at in +// accounts/check, but this endpoint still returns active_until. +func fetchChatGPTSubscriptionExpiresAt(ctx context.Context, clientFactory PrivacyClientFactory, accessToken, proxyURL, accountID string) string { + accountID = strings.TrimSpace(accountID) + if accessToken == "" || accountID == "" || clientFactory == nil { + return "" + } + + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + client, err := clientFactory(proxyURL) + if err != nil { + slog.Debug("chatgpt_subscription_client_error", "error", err.Error()) + return "" + } + + var result struct { + PlanType string `json:"plan_type"` + ActiveUntil string `json:"active_until"` + WillRenew bool `json:"will_renew"` + ID string `json:"id"` + } + resp, err := client.R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+accessToken). + SetHeader("Origin", "https://chatgpt.com"). + SetHeader("Referer", "https://chatgpt.com/"). + SetHeader("Accept", "application/json"). + SetSuccessResult(&result). + SetQueryParam("account_id", accountID). + Get(chatGPTSubscriptionsURL) + if err != nil { + slog.Debug("chatgpt_subscription_request_error", "error", err.Error()) + return "" + } + if !resp.IsSuccessState() { + slog.Debug("chatgpt_subscription_failed", "status", resp.StatusCode, "body", truncate(resp.String(), 200)) + return "" + } + + activeUntil := strings.TrimSpace(result.ActiveUntil) + if activeUntil == "" { + slog.Debug("chatgpt_subscription_no_active_until", "plan_type", result.PlanType, "has_subscription_id", strings.TrimSpace(result.ID) != "", "will_renew", result.WillRenew) + return "" + } + if _, err := time.Parse(time.RFC3339, activeUntil); err != nil { + slog.Debug("chatgpt_subscription_bad_active_until", "active_until", activeUntil, "error", err.Error()) + return "" + } + + slog.Info("chatgpt_subscription_success", "plan_type", result.PlanType, "subscription_expires_at", activeUntil, "account_id", accountID) + return activeUntil +} + // fillAccountInfo 从单个 account 对象中提取 plan_type 和 subscription_expires_at func fillAccountInfo(info *ChatGPTAccountInfo, acct map[string]any) { info.PlanType = extractPlanType(acct) diff --git a/backend/internal/service/openai_subscription_test.go b/backend/internal/service/openai_subscription_test.go new file mode 100644 index 0000000000..89df54dbac --- /dev/null +++ b/backend/internal/service/openai_subscription_test.go @@ -0,0 +1,42 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/imroc/req/v3" + "github.com/stretchr/testify/require" +) + +func TestFetchChatGPTSubscriptionExpiresAt(t *testing.T) { + const wantExpiresAt = "2026-06-10T02:52:15Z" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/backend-api/subscriptions", r.URL.Path) + require.Equal(t, "acc_123", r.URL.Query().Get("account_id")) + require.Equal(t, "Bearer access-token", r.Header.Get("Authorization")) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "plan_type": "plus", + "active_until": wantExpiresAt, + "will_renew": true, + "id": "sub_123", + }) + })) + defer server.Close() + + oldURL := chatGPTSubscriptionsURL + chatGPTSubscriptionsURL = server.URL + "/backend-api/subscriptions" + t.Cleanup(func() { chatGPTSubscriptionsURL = oldURL }) + + got := fetchChatGPTSubscriptionExpiresAt(context.Background(), func(proxyURL string) (*req.Client, error) { + return req.C().SetTimeout(5 * time.Second), nil + }, "access-token", "", "acc_123") + + require.Equal(t, wantExpiresAt, got) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index b22e10ae38..e0c9f5910a 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -45,6 +45,17 @@ func ProvideOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiToke return NewOAuthRefreshAPI(accountRepo, tokenCache) } +// ProvideOpenAIOAuthService creates OpenAIOAuthService with privacy/account enrichment support. +func ProvideOpenAIOAuthService( + proxyRepo ProxyRepository, + oauthClient OpenAIOAuthClient, + privacyClientFactory PrivacyClientFactory, +) *OpenAIOAuthService { + svc := NewOpenAIOAuthService(proxyRepo, oauthClient) + svc.SetPrivacyClientFactory(privacyClientFactory) + return svc +} + // ProvideTokenRefreshService creates and starts TokenRefreshService func ProvideTokenRefreshService( accountRepo AccountRepository, @@ -461,7 +472,7 @@ var ProviderSet = wire.NewSet( NewOpenAIGatewayService, wire.Bind(new(AccountRuntimeBlocker), new(*OpenAIGatewayService)), NewOAuthService, - NewOpenAIOAuthService, + ProvideOpenAIOAuthService, NewGeminiOAuthService, NewGeminiQuotaService, NewCompositeTokenCacheInvalidator,