Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions backend/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,8 @@ type GatewayConfig struct {
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
// OpenAIScheduler: OpenAI 高级调度器粘性逃逸配置
OpenAIScheduler GatewayOpenAISchedulerConfig `mapstructure:"openai_scheduler"`
// OpenAIHTTP2: OpenAI HTTP 上游协议策略(默认启用 HTTP/2,可按代理能力回退 HTTP/1.1)
OpenAIHTTP2 GatewayOpenAIHTTP2Config `mapstructure:"openai_http2"`
// ImageConcurrency: 图片生成独立并发限制配置(默认关闭)
Expand Down Expand Up @@ -948,6 +950,16 @@ type GatewayOpenAIWSSchedulerScoreWeights struct {
TTFT float64 `mapstructure:"ttft"`
}

// GatewayOpenAISchedulerConfig holds the sticky-escape settings of the OpenAI advanced scheduler.
// It is populated from the "openai_scheduler" key of the gateway configuration.
type GatewayOpenAISchedulerConfig struct {
// StickyEscapeEnabled: whether a session_hash sticky binding may be temporarily
// bypassed ("escaped") when the bound account's health degrades.
StickyEscapeEnabled bool `mapstructure:"sticky_escape_enabled"`
// StickyEscapeTTFTMs: sticky is skipped when the account's TTFT EWMA exceeds this
// threshold (milliseconds, per the field name). Validation requires a positive value.
StickyEscapeTTFTMs int `mapstructure:"sticky_escape_ttft_ms"`
// StickyEscapeErrorRate: sticky is skipped when the account's error-rate EWMA exceeds
// this threshold. Validation requires a value between 0 and 1 inclusive.
StickyEscapeErrorRate float64 `mapstructure:"sticky_escape_error_rate"`
}

// GatewayUsageRecordConfig 使用量记录异步队列配置
type GatewayUsageRecordConfig struct {
// WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限)
Expand Down Expand Up @@ -1369,6 +1381,15 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
if err := viper.Unmarshal(&cfg); err != nil {
return nil, fmt.Errorf("unmarshal config error: %w", err)
}
if cfg.Gateway.OpenAIScheduler.StickyEscapeTTFTMs == 0 {
cfg.Gateway.OpenAIScheduler.StickyEscapeTTFTMs = 15000
}
if cfg.Gateway.OpenAIScheduler.StickyEscapeErrorRate == 0 {
cfg.Gateway.OpenAIScheduler.StickyEscapeErrorRate = 0.5
}
if !cfg.Gateway.OpenAIScheduler.StickyEscapeEnabled && !viper.IsSet("gateway.openai_scheduler.sticky_escape_enabled") {
cfg.Gateway.OpenAIScheduler.StickyEscapeEnabled = true
}

cfg.RunMode = NormalizeRunMode(cfg.RunMode)
cfg.Server.Mode = strings.ToLower(strings.TrimSpace(cfg.Server.Mode))
Expand Down Expand Up @@ -2603,6 +2624,12 @@ func (c *Config) Validate() error {
if weightSum <= 0 {
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero")
}
if c.Gateway.OpenAIScheduler.StickyEscapeTTFTMs <= 0 {
return fmt.Errorf("gateway.openai_scheduler.sticky_escape_ttft_ms must be positive")
}
if c.Gateway.OpenAIScheduler.StickyEscapeErrorRate < 0 || c.Gateway.OpenAIScheduler.StickyEscapeErrorRate > 1 {
return fmt.Errorf("gateway.openai_scheduler.sticky_escape_error_rate must be between 0 and 1")
}
if c.Gateway.MaxLineSize < 0 {
return fmt.Errorf("gateway.max_line_size must be non-negative")
}
Expand Down
24 changes: 24 additions & 0 deletions backend/internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
if cfg.Gateway.OpenAIWS.StickySessionTTLSeconds != 3600 {
t.Fatalf("Gateway.OpenAIWS.StickySessionTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickySessionTTLSeconds)
}
if !cfg.Gateway.OpenAIScheduler.StickyEscapeEnabled {
t.Fatalf("Gateway.OpenAIScheduler.StickyEscapeEnabled = false, want true")
}
if cfg.Gateway.OpenAIScheduler.StickyEscapeTTFTMs != 15000 {
t.Fatalf("Gateway.OpenAIScheduler.StickyEscapeTTFTMs = %d, want 15000", cfg.Gateway.OpenAIScheduler.StickyEscapeTTFTMs)
}
if cfg.Gateway.OpenAIScheduler.StickyEscapeErrorRate != 0.5 {
t.Fatalf("Gateway.OpenAIScheduler.StickyEscapeErrorRate = %v, want 0.5", cfg.Gateway.OpenAIScheduler.StickyEscapeErrorRate)
}
if !cfg.Gateway.OpenAIWS.SessionHashReadOldFallback {
t.Fatalf("Gateway.OpenAIWS.SessionHashReadOldFallback = false, want true")
}
Expand Down Expand Up @@ -1705,6 +1714,21 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
},
wantErr: "gateway.openai_ws.scheduler_score_weights must not all be zero",
},
{
name: "sticky_escape_ttft_ms 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIScheduler.StickyEscapeTTFTMs = 0 },
wantErr: "gateway.openai_scheduler.sticky_escape_ttft_ms",
},
{
name: "sticky_escape_error_rate 不能小于 0",
mutate: func(c *Config) { c.Gateway.OpenAIScheduler.StickyEscapeErrorRate = -0.1 },
wantErr: "gateway.openai_scheduler.sticky_escape_error_rate",
},
{
name: "sticky_escape_error_rate 不能大于 1",
mutate: func(c *Config) { c.Gateway.OpenAIScheduler.StickyEscapeErrorRate = 1.1 },
wantErr: "gateway.openai_scheduler.sticky_escape_error_rate",
},
}

for _, tc := range cases {
Expand Down
105 changes: 90 additions & 15 deletions backend/internal/service/openai_account_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type OpenAIAccountScheduleRequest struct {
GroupID *int64
SessionHash string
StickyAccountID int64
PreserveStickyBinding bool
PreviousResponseID string
RequestedModel string
RequiredTransport OpenAIUpstreamTransport
Expand Down Expand Up @@ -241,6 +242,12 @@ type defaultOpenAIAccountScheduler struct {
stats *openAIAccountRuntimeStats
}

// openAIStickyEscapeConfig is the scheduler-internal form of the sticky-escape
// settings, with the thresholds held as float64 so they can be compared directly
// against runtime stats.
type openAIStickyEscapeConfig struct {
// enabled gates the sticky-escape checks; when false no escape is triggered.
enabled bool
// ttftMs is the TTFT threshold; a reported TTFT strictly above it triggers an escape.
ttftMs float64
// errorRate is the error-rate threshold; an error rate strictly above it triggers an escape.
errorRate float64
}

func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler {
if stats == nil {
stats = newOpenAIAccountRuntimeStats()
Expand Down Expand Up @@ -296,7 +303,7 @@ func (s *defaultOpenAIAccountScheduler) Select(
}
}

selection, err := s.selectBySessionHash(ctx, req)
selection, escapedSticky, err := s.selectBySessionHash(ctx, req)
if err != nil {
return nil, decision, err
}
Expand All @@ -307,6 +314,9 @@ func (s *defaultOpenAIAccountScheduler) Select(
decision.SelectedAccountType = selection.Account.Type
return selection, decision, nil
}
if escapedSticky {
req.PreserveStickyBinding = true
}

selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req)
decision.Layer = openAIAccountScheduleLayerLoadBalance
Expand All @@ -326,49 +336,59 @@ func (s *defaultOpenAIAccountScheduler) Select(
func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
ctx context.Context,
req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, error) {
) (*AccountSelectionResult, bool, error) {
sessionHash := strings.TrimSpace(req.SessionHash)
if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil {
return nil, nil
return nil, false, nil
}

accountID := req.StickyAccountID
if accountID <= 0 {
var err error
accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash)
if err != nil || accountID <= 0 {
return nil, nil
return nil, false, nil
}
}
if accountID <= 0 {
return nil, nil
return nil, false, nil
}
if req.ExcludedIDs != nil {
if _, excluded := req.ExcludedIDs[accountID]; excluded {
return nil, nil
return nil, false, nil
}
}

account, err := s.service.getSchedulableAccount(ctx, accountID)
if err != nil || account == nil {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
return nil, false, nil
}
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() || !account.IsSchedulable() {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
return nil, false, nil
}
if !s.isAccountRequestCompatible(ctx, account, req) {
return nil, nil
return nil, false, nil
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
return nil, false, nil
}
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact, req.RequiredCapability)
if account == nil || !s.isAccountTransportCompatible(account, req.RequiredTransport) {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
return nil, false, nil
}
escapeCfg := s.service.openAIStickyEscapeConfig()
if reason, errorRate, ttft, shouldEscape := s.shouldEscapeStickyAccount(accountID, escapeCfg); shouldEscape {
slog.Info("sticky_escape_triggered",
"account_id", accountID,
"reason", reason,
"error_rate", errorRate,
"ttft", ttft,
)
return nil, true, nil
}

result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
Expand All @@ -378,12 +398,22 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}, false, nil
}

cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
if s.service.concurrencyService != nil {
if escapeCfg.enabled && acquireErr == nil && result != nil && !result.Acquired {
errorRate, ttft, _ := s.stats.snapshot(accountID)
slog.Info("sticky_escape_triggered",
"account_id", accountID,
"reason", "concurrency_full",
"error_rate", errorRate,
"ttft", ttft,
)
return nil, true, nil
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
Expand All @@ -392,9 +422,23 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}, false, nil
}
return nil, false, nil
}

// shouldEscapeStickyAccount reports whether the sticky binding to accountID
// should be bypassed because the account's runtime stats exceed the thresholds
// in cfg.
//
// It returns the trigger reason ("ttft" or "error_rate"; empty when no escape
// is needed) together with the error rate and TTFT read from s.stats.snapshot,
// so the caller can log them. When escaping is disabled, the receiver or its
// stats are nil, or accountID is not positive, it returns zero values and
// false without consulting the stats.
func (s *defaultOpenAIAccountScheduler) shouldEscapeStickyAccount(accountID int64, cfg openAIStickyEscapeConfig) (reason string, errorRate float64, ttft float64, shouldEscape bool) {
	if !cfg.enabled || s == nil || s.stats == nil || accountID <= 0 {
		return "", 0, 0, false
	}
	errorRate, ttft, hasTTFT := s.stats.snapshot(accountID)
	// The TTFT check is evaluated first and only when snapshot reports that a
	// TTFT value is available; both comparisons are strict.
	if hasTTFT && ttft > cfg.ttftMs {
		return "ttft", errorRate, ttft, true
	}
	if errorRate > cfg.errorRate {
		return "error_rate", errorRate, ttft, true
	}
	return "", errorRate, ttft, false
}

type openAIAccountCandidateScore struct {
Expand Down Expand Up @@ -810,7 +854,7 @@ func (s *defaultOpenAIAccountScheduler) tryAcquireOpenAISelectionOrder(
return nil, compactBlocked, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
if req.SessionHash != "" && !req.PreserveStickyBinding {
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
}
return &AccountSelectionResult{
Expand Down Expand Up @@ -1305,6 +1349,37 @@ func (s *OpenAIGatewayService) openAIWSLBTopK() int {
return 7
}

// openAIStickyEscapeConfig resolves the effective sticky-escape settings for
// the scheduler.
//
// With a nil service or nil config the built-in fallbacks are returned
// (enabled, 15000 ms TTFT, 0.5 error rate). Otherwise values come from
// Gateway.OpenAIScheduler with these adjustments:
//   - when both numeric thresholds are zero the section is treated as never
//     populated: escaping is enabled and both fallbacks apply;
//   - a non-positive TTFT threshold falls back to 15000;
//   - an error rate outside [0, 1] falls back to 0.5.
func (s *OpenAIGatewayService) openAIStickyEscapeConfig() openAIStickyEscapeConfig {
	const (
		fallbackTTFTMs    = 15000
		fallbackErrorRate = 0.5
	)

	resolved := openAIStickyEscapeConfig{
		enabled:   true,
		ttftMs:    fallbackTTFTMs,
		errorRate: fallbackErrorRate,
	}
	if s == nil || s.cfg == nil {
		return resolved
	}

	raw := s.cfg.Gateway.OpenAIScheduler
	// Both thresholds at their zero value is read as "section left unset".
	unset := raw.StickyEscapeTTFTMs == 0 && raw.StickyEscapeErrorRate == 0

	resolved.enabled = raw.StickyEscapeEnabled || unset
	if raw.StickyEscapeTTFTMs > 0 {
		resolved.ttftMs = float64(raw.StickyEscapeTTFTMs)
	}
	// An explicit 0 error rate is honoured only when the TTFT threshold was set.
	if rate := raw.StickyEscapeErrorRate; !unset && !(rate < 0 || rate > 1) {
		resolved.errorRate = rate
	}
	return resolved
}

func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView {
if s != nil && s.cfg != nil {
return GatewayOpenAIWSSchedulerScoreWeightsView{
Expand Down
Loading
Loading