Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/cmd/server/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 37 additions & 17 deletions backend/internal/service/openai_oauth_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,29 @@ func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *Ope
tokenInfo.Email = info.Email
}
}
if strings.TrimSpace(tokenInfo.SubscriptionExpiresAt) == "" {
if expiresAt := fetchChatGPTSubscriptionExpiresAt(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, resolveChatGPTSubscriptionAccountID(tokenInfo, orgID)); expiresAt != "" {
tokenInfo.SubscriptionExpiresAt = expiresAt
}
}

// 尝试设置隐私(关闭训练数据共享),best-effort
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
}

func resolveChatGPTSubscriptionAccountID(tokenInfo *OpenAITokenInfo, orgID string) string {
for _, candidate := range []string{
tokenInfo.ChatGPTAccountID,
tokenInfo.OrganizationID,
orgID,
} {
if trimmed := strings.TrimSpace(candidate); trimmed != "" {
return trimmed
}
}
return ""
}

// RefreshAccountToken refreshes token for an OpenAI OAuth account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if account.Platform != PlatformOpenAI {
Expand All @@ -292,38 +310,40 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
}

var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}

refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" {
accessToken := account.GetCredential("access_token")
if accessToken != "" {
tokenInfo := &OpenAITokenInfo{
AccessToken: accessToken,
RefreshToken: "",
IDToken: account.GetCredential("id_token"),
ClientID: account.GetCredential("client_id"),
Email: account.GetCredential("email"),
ChatGPTAccountID: account.GetCredential("chatgpt_account_id"),
ChatGPTUserID: account.GetCredential("chatgpt_user_id"),
OrganizationID: account.GetCredential("organization_id"),
PlanType: account.GetCredential("plan_type"),
AccessToken: accessToken,
RefreshToken: "",
IDToken: account.GetCredential("id_token"),
ClientID: account.GetCredential("client_id"),
Email: account.GetCredential("email"),
ChatGPTAccountID: account.GetCredential("chatgpt_account_id"),
ChatGPTUserID: account.GetCredential("chatgpt_user_id"),
OrganizationID: account.GetCredential("organization_id"),
PlanType: account.GetCredential("plan_type"),
SubscriptionExpiresAt: account.GetCredential("subscription_expires_at"),
}
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
tokenInfo.ExpiresAt = expiresAt.Unix()
tokenInfo.ExpiresIn = int64(time.Until(*expiresAt).Seconds())
}
s.enrichTokenInfo(ctx, tokenInfo, proxyURL)
return tokenInfo, nil
}
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
}

var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}

clientID := account.GetCredential("client_id")
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
}
Expand Down
7 changes: 7 additions & 0 deletions backend/internal/service/openai_oauth_service_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
)

Expand All @@ -32,6 +33,11 @@ func (s *openaiOAuthClientRefreshStub) RefreshTokenWithClientID(ctx context.Cont
func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccessToken(t *testing.T) {
client := &openaiOAuthClientRefreshStub{}
svc := NewOpenAIOAuthService(nil, client)
var privacyClientCalls int32
svc.SetPrivacyClientFactory(func(proxyURL string) (*req.Client, error) {
atomic.AddInt32(&privacyClientCalls, 1)
return nil, errors.New("stop before request")
})

expiresAt := time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339)
account := &Account{
Expand All @@ -51,6 +57,7 @@ func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccess
require.Equal(t, "existing-access-token", info.AccessToken)
require.Equal(t, "client-id-1", info.ClientID)
require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh")
require.Positive(t, atomic.LoadInt32(&privacyClientCalls), "existing access token should still run enrichment")
}

func TestOpenAITokenRefresher_NeedsRefresh_SkipsAccountWithoutRefreshToken(t *testing.T) {
Expand Down
58 changes: 58 additions & 0 deletions backend/internal/service/openai_privacy_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ type ChatGPTAccountInfo struct {

const chatGPTAccountsCheckURL = "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27"

var chatGPTSubscriptionsURL = "https://chatgpt.com/backend-api/subscriptions"

// fetchChatGPTAccountInfo calls ChatGPT backend-api to get account info (plan_type, etc.).
// Used as fallback when id_token doesn't contain these fields (e.g., Mobile RT).
// orgID is used to match the correct account when multiple accounts exist (e.g., personal + team).
Expand Down Expand Up @@ -199,6 +201,62 @@ func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFac
return info
}

// fetchChatGPTSubscriptionExpiresAt reads the lightweight subscription endpoint used by
// ChatGPT/Codex clients. Some Plus accounts no longer expose entitlement.expires_at in
// accounts/check, but this endpoint still returns active_until.
func fetchChatGPTSubscriptionExpiresAt(ctx context.Context, clientFactory PrivacyClientFactory, accessToken, proxyURL, accountID string) string {
accountID = strings.TrimSpace(accountID)
if accessToken == "" || accountID == "" || clientFactory == nil {
return ""
}

ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()

client, err := clientFactory(proxyURL)
if err != nil {
slog.Debug("chatgpt_subscription_client_error", "error", err.Error())
return ""
}

var result struct {
PlanType string `json:"plan_type"`
ActiveUntil string `json:"active_until"`
WillRenew bool `json:"will_renew"`
ID string `json:"id"`
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Origin", "https://chatgpt.com").
SetHeader("Referer", "https://chatgpt.com/").
SetHeader("Accept", "application/json").
SetSuccessResult(&result).
SetQueryParam("account_id", accountID).
Get(chatGPTSubscriptionsURL)
if err != nil {
slog.Debug("chatgpt_subscription_request_error", "error", err.Error())
return ""
}
if !resp.IsSuccessState() {
slog.Debug("chatgpt_subscription_failed", "status", resp.StatusCode, "body", truncate(resp.String(), 200))
return ""
}

activeUntil := strings.TrimSpace(result.ActiveUntil)
if activeUntil == "" {
slog.Debug("chatgpt_subscription_no_active_until", "plan_type", result.PlanType, "has_subscription_id", strings.TrimSpace(result.ID) != "", "will_renew", result.WillRenew)
return ""
}
if _, err := time.Parse(time.RFC3339, activeUntil); err != nil {
slog.Debug("chatgpt_subscription_bad_active_until", "active_until", activeUntil, "error", err.Error())
return ""
}

slog.Info("chatgpt_subscription_success", "plan_type", result.PlanType, "subscription_expires_at", activeUntil, "account_id", accountID)
return activeUntil
}

// fillAccountInfo 从单个 account 对象中提取 plan_type 和 subscription_expires_at
func fillAccountInfo(info *ChatGPTAccountInfo, acct map[string]any) {
info.PlanType = extractPlanType(acct)
Expand Down
42 changes: 42 additions & 0 deletions backend/internal/service/openai_subscription_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package service

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
)

func TestFetchChatGPTSubscriptionExpiresAt(t *testing.T) {
const wantExpiresAt = "2026-06-10T02:52:15Z"

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/backend-api/subscriptions", r.URL.Path)
require.Equal(t, "acc_123", r.URL.Query().Get("account_id"))
require.Equal(t, "Bearer access-token", r.Header.Get("Authorization"))

w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"plan_type": "plus",
"active_until": wantExpiresAt,
"will_renew": true,
"id": "sub_123",
})
}))
defer server.Close()

oldURL := chatGPTSubscriptionsURL
chatGPTSubscriptionsURL = server.URL + "/backend-api/subscriptions"
t.Cleanup(func() { chatGPTSubscriptionsURL = oldURL })

got := fetchChatGPTSubscriptionExpiresAt(context.Background(), func(proxyURL string) (*req.Client, error) {
return req.C().SetTimeout(5 * time.Second), nil
}, "access-token", "", "acc_123")

require.Equal(t, wantExpiresAt, got)
}
13 changes: 12 additions & 1 deletion backend/internal/service/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ func ProvideOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiToke
return NewOAuthRefreshAPI(accountRepo, tokenCache)
}

// ProvideOpenAIOAuthService creates OpenAIOAuthService with privacy/account enrichment support.
func ProvideOpenAIOAuthService(
proxyRepo ProxyRepository,
oauthClient OpenAIOAuthClient,
privacyClientFactory PrivacyClientFactory,
) *OpenAIOAuthService {
svc := NewOpenAIOAuthService(proxyRepo, oauthClient)
svc.SetPrivacyClientFactory(privacyClientFactory)
return svc
}

// ProvideTokenRefreshService creates and starts TokenRefreshService
func ProvideTokenRefreshService(
accountRepo AccountRepository,
Expand Down Expand Up @@ -461,7 +472,7 @@ var ProviderSet = wire.NewSet(
NewOpenAIGatewayService,
wire.Bind(new(AccountRuntimeBlocker), new(*OpenAIGatewayService)),
NewOAuthService,
NewOpenAIOAuthService,
ProvideOpenAIOAuthService,
NewGeminiOAuthService,
NewGeminiQuotaService,
NewCompositeTokenCacheInvalidator,
Expand Down
Loading