diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index df9dcefc601..0df7d09e84a 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -885,6 +885,12 @@ type GatewayOpenAIWSConfig struct { StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"` // PrewarmGenerateEnabled: 是否启用 WSv2 generate=false 预热(默认 false) PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"` + // ClientReadLimitBytes: 入站客户端 WS 单帧读取上限。 + ClientReadLimitBytes int64 `mapstructure:"client_read_limit_bytes"` + // HTTPBridgeEnabled: 首包过大时,保持客户端 WS,改用 HTTP Responses 上游。 + HTTPBridgeEnabled bool `mapstructure:"http_bridge_enabled"` + // HTTPBridgeThresholdBytes: 触发 HTTP bridge 的入站 WS payload 阈值。 + HTTPBridgeThresholdBytes int64 `mapstructure:"http_bridge_threshold_bytes"` // Feature 开关:v2 优先于 v1 ResponsesWebsockets bool `mapstructure:"responses_websockets"` @@ -1806,6 +1812,9 @@ func setDefaults() { viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict") viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true) viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false) + viper.SetDefault("gateway.openai_ws.client_read_limit_bytes", 64*1024*1024) + viper.SetDefault("gateway.openai_ws.http_bridge_enabled", true) + viper.SetDefault("gateway.openai_ws.http_bridge_threshold_bytes", 15*1024*1024) viper.SetDefault("gateway.openai_ws.responses_websockets", false) viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true) viper.SetDefault("gateway.openai_ws.max_conns_per_account", 128) @@ -2543,6 +2552,15 @@ func (c *Config) Validate() error { if c.Gateway.OpenAIWS.PrewarmCooldownMS < 0 { return fmt.Errorf("gateway.openai_ws.prewarm_cooldown_ms must be non-negative") } + if c.Gateway.OpenAIWS.ClientReadLimitBytes <= 0 { + return fmt.Errorf("gateway.openai_ws.client_read_limit_bytes must be positive") + } + if c.Gateway.OpenAIWS.HTTPBridgeThresholdBytes < 0 { + return fmt.Errorf("gateway.openai_ws.http_bridge_threshold_bytes must be non-negative") + } + if c.Gateway.OpenAIWS.HTTPBridgeEnabled && c.Gateway.OpenAIWS.HTTPBridgeThresholdBytes == 0 { + return fmt.Errorf("gateway.openai_ws.http_bridge_threshold_bytes must be positive when http_bridge_enabled is true") + } if c.Gateway.OpenAIWS.FallbackCooldownSeconds < 0 { return fmt.Errorf("gateway.openai_ws.fallback_cooldown_seconds must be non-negative") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 1eae5ed9595..9478b5102da 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -134,6 +134,15 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) { if cfg.Gateway.OpenAIWS.PrewarmCooldownMS != 300 { t.Fatalf("Gateway.OpenAIWS.PrewarmCooldownMS = %d, want 300", cfg.Gateway.OpenAIWS.PrewarmCooldownMS) } + if cfg.Gateway.OpenAIWS.ClientReadLimitBytes != 64*1024*1024 { + t.Fatalf("Gateway.OpenAIWS.ClientReadLimitBytes = %d, want %d", cfg.Gateway.OpenAIWS.ClientReadLimitBytes, 64*1024*1024) + } + if !cfg.Gateway.OpenAIWS.HTTPBridgeEnabled { + t.Fatalf("Gateway.OpenAIWS.HTTPBridgeEnabled = false, want true") + } + if cfg.Gateway.OpenAIWS.HTTPBridgeThresholdBytes != 15*1024*1024 { + t.Fatalf("Gateway.OpenAIWS.HTTPBridgeThresholdBytes = %d, want %d", cfg.Gateway.OpenAIWS.HTTPBridgeThresholdBytes, 15*1024*1024) + } if cfg.Gateway.OpenAIWS.RetryBackoffInitialMS != 120 { t.Fatalf("Gateway.OpenAIWS.RetryBackoffInitialMS = %d, want 120", cfg.Gateway.OpenAIWS.RetryBackoffInitialMS) } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 0aa477b08cd..f3d4caf08db 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -1167,7 +1167,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { defer func() { _ = wsConn.CloseNow() }() - wsConn.SetReadLimit(16 * 1024 * 1024) + wsConn.SetReadLimit(service.ResolveOpenAIWSClientReadLimitBytes(h.cfg)) ctx := c.Request.Context() readCtx, cancel := context.WithTimeout(ctx, 30*time.Second) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 89ddaa7d1c5..b1a9559452b 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -256,6 +256,9 @@ type OpenAIForwardResult struct { ImageOutputSizes []string ImageSizeSource string ImageSizeBreakdown map[string]int + + wsReplayInput []json.RawMessage + wsReplayInputExists bool } type OpenAIWSRetryMetricsSnapshot struct { diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index dd2f45facbd..66c66134b0e 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1560,6 +1560,38 @@ func openAIWSRawItemsHasFunctionCallOutput(items []json.RawMessage) bool { return false } +func openAIWSRawItemsHaveToolCallContextForOutputs(items []json.RawMessage) bool { + if len(items) == 0 { + return false + } + contextCallIDs := make(map[string]struct{}) + outputCallIDs := make(map[string]struct{}) + for _, item := range items { + itemType := gjson.GetBytes(item, "type").String() + callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) + switch { + case isCodexToolCallContextItemType(itemType): + if callID != "" { + contextCallIDs[callID] = struct{}{} + } + case isCodexToolCallOutputItemType(itemType): + if callID == "" { + return false + } + outputCallIDs[callID] = struct{}{} + } + } + if len(outputCallIDs) == 0 || len(contextCallIDs) == 0 { + return false + } + for callID := range outputCallIDs { + if _, ok := contextCallIDs[callID]; !ok { + return false + } + } + return true +} + func openAIWSRawPayloadHasToolCallOutput(payload []byte) bool { if len(payload) == 0 { return false @@ -2664,6 +2696,27 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( }, nil } + writeClientMessage := func(message []byte) error { + writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + defer cancel() + return clientConn.Write(writeCtx, coderws.MessageText, message) + } + + readClientMessage := func() ([]byte, error) { + msgType, payload, readErr := clientConn.Read(ctx) + if readErr != nil { + return nil, readErr + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + fmt.Sprintf("unsupported websocket client message type: %s", msgType.String()), + nil, + ) + } + return payload, nil + } + firstPayload, err := parseClientPayload(firstClientMessage) if err != nil { return err @@ -2672,25 +2725,152 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) stateStore := s.getOpenAIWSStateStore() groupID := getOpenAIGroupIDFromContext(c) - sessionHash := s.GenerateSessionHash(c, firstPayload.rawForHash) - if turnState == "" && stateStore != nil && sessionHash != "" { - if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { - turnState = savedTurnState + storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() + sessionHash := "" + preferredConnID := "" + storeDisabled := false + refreshIngressRouteState := func(payload openAIWSClientPayload) { + sessionHash = s.GenerateSessionHash(c, payload.rawForHash) + if turnState == "" && stateStore != nil && sessionHash != "" { + if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { + turnState = savedTurnState + } } - } - preferredConnID := "" - if stateStore != nil && firstPayload.previousResponseID != "" { - if connID, ok := stateStore.GetResponseConn(firstPayload.previousResponseID); ok { - preferredConnID = connID + preferredConnID = "" + if stateStore != nil && payload.previousResponseID != "" { + if connID, ok := stateStore.GetResponseConn(payload.previousResponseID); ok { + preferredConnID = connID + } + } + + storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(payload.payloadRaw, account) + if stateStore != nil && storeDisabled && payload.previousResponseID == "" && sessionHash != "" { + if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok { + preferredConnID = connID + } } } + refreshIngressRouteState(firstPayload) - storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account) - storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() - if stateStore != nil && storeDisabled && firstPayload.previousResponseID == "" && sessionHash != "" { - if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok { - preferredConnID = connID + if s.shouldBridgeOpenAIWSHTTP(firstPayload.payloadBytes, firstPayload.previousResponseID) { + logOpenAIWSModeInfo( + "ingress_ws_http_bridge_start account_id=%d account_type=%s payload_bytes=%d threshold_bytes=%d has_session_hash=%v store_disabled=%v", + account.ID, + account.Type, + firstPayload.payloadBytes, + s.openAIWSHTTPBridgeThresholdBytes(), + sessionHash != "", + storeDisabled, + ) + currentBridgePayload := firstPayload + var bridgeReplayInput []json.RawMessage + bridgeReplayInputExists := false + for turn := 1; ; turn++ { + if turn > 1 && hooks != nil && hooks.BeforeRequest != nil { + if err := hooks.BeforeRequest(turn, currentBridgePayload.payloadRaw, currentBridgePayload.originalModel); err != nil { + return err + } + } + if hooks != nil && hooks.BeforeTurn != nil { + if err := hooks.BeforeTurn(turn); err != nil { + return err + } + } + if turnState != "" && c != nil && c.Request != nil { + c.Request.Header.Set(openAIWSTurnStateHeader, turnState) + } + bridgePayloadRaw := currentBridgePayload.payloadRaw + bridgePayloadBytes := currentBridgePayload.payloadBytes + needsBridgeReplay := currentBridgePayload.previousResponseID != "" || openAIWSRawPayloadHasToolCallOutput(currentBridgePayload.payloadRaw) + turnReplayInput, turnReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence( + bridgeReplayInput, + bridgeReplayInputExists, + currentBridgePayload.payloadRaw, + needsBridgeReplay, + ) + if replayInputErr != nil { + return fmt.Errorf("build websocket http bridge replay input: %w", replayInputErr) + } + if needsBridgeReplay && turnReplayInputExists { + updatedPayload, setInputErr := setOpenAIWSPayloadInputSequence( + currentBridgePayload.payloadRaw, + turnReplayInput, + true, + ) + if setInputErr != nil { + return fmt.Errorf("set websocket http bridge replay input: %w", setInputErr) + } + bridgePayloadRaw = updatedPayload + bridgePayloadBytes = len(updatedPayload) + logOpenAIWSModeInfo( + "ingress_ws_http_bridge_replay_input account_id=%d turn=%d input_items=%d previous_response_id_present=%v has_tool_output=%v", + account.ID, + turn, + len(turnReplayInput), + currentBridgePayload.previousResponseID != "", + openAIWSRawPayloadHasToolCallOutput(currentBridgePayload.payloadRaw), + ) + } + result, bridgeErr := s.proxyOpenAIWSHTTPBridgeTurn( + ctx, + c, + account, + token, + bridgePayloadRaw, + bridgePayloadBytes, + currentBridgePayload.originalModel, + currentBridgePayload.imageBillingModel, + currentBridgePayload.imageSizeTier, + currentBridgePayload.imageInputSize, + turn, + writeClientMessage, + ) + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turn, result, bridgeErr) + } + if bridgeErr != nil { + return bridgeErr + } + if result == nil { + return errors.New("websocket http bridge turn result is nil") + } + bridgeReplayInput = cloneOpenAIWSRawMessages(turnReplayInput) + bridgeReplayInputExists = turnReplayInputExists + if result.wsReplayInputExists { + bridgeReplayInput = append(bridgeReplayInput, cloneOpenAIWSRawMessages(result.wsReplayInput)...) + bridgeReplayInputExists = true + } + if bridgeTurnState := strings.TrimSpace(result.ResponseHeaders.Get(openAIWSTurnStateHeader)); bridgeTurnState != "" { + turnState = bridgeTurnState + if stateStore != nil && sessionHash != "" { + stateStore.BindSessionTurnState(groupID, sessionHash, bridgeTurnState, s.openAIWSSessionStickyTTL()) + } + } + responseID := strings.TrimSpace(result.RequestID) + if responseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) + } + nextClientMessage, readErr := readClientMessage() + if readErr != nil { + if isOpenAIWSClientDisconnectError(readErr) { + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "ingress_ws_http_bridge_client_closed account_id=%d close_status=%s close_reason=%s", + account.ID, + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + return nil + } + return fmt.Errorf("read client websocket request: %w", readErr) + } + nextPayload, parseErr := parseClientPayload(nextClientMessage) + if parseErr != nil { + return parseErr + } + currentBridgePayload = nextPayload } } @@ -2844,27 +3024,6 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( return lease, nil } - writeClientMessage := func(message []byte) error { - writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) - defer cancel() - return clientConn.Write(writeCtx, coderws.MessageText, message) - } - - readClientMessage := func() ([]byte, error) { - msgType, payload, readErr := clientConn.Read(ctx) - if readErr != nil { - return nil, readErr - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - return nil, NewOpenAIWSClientCloseError( - coderws.StatusPolicyViolation, - fmt.Sprintf("unsupported websocket client message type: %s", msgType.String()), - nil, - ) - } - return payload, nil - } - sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string, imageBillingModel string, imageSizeTier string, imageInputSize string) (*OpenAIForwardResult, error) { if lease == nil { return nil, errors.New("upstream websocket lease is nil") @@ -2901,6 +3060,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( eventCount := 0 tokenEventCount := 0 terminalEventCount := 0 + replayCollector := &openAIWSToolCallReplayCollector{} firstEventType := "" lastEventType := "" needModelReplace := false @@ -3031,6 +3191,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( upstreamMessage = corrected } } + replayCollector.AddEvent(eventType, upstreamMessage) if err := writeClientMessage(upstreamMessage); err != nil { if isOpenAIWSClientDisconnectError(err) { clientDisconnected = true @@ -3094,6 +3255,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( Duration: time.Since(turnStart), FirstTokenMs: firstTokenMs, } + if replayInput := replayCollector.Items(); len(replayInput) > 0 { + result.wsReplayInput = replayInput + result.wsReplayInputExists = true + } if imageCount > 0 { result.ImageCount = imageCount result.ImageSize = imageSizeTier @@ -3487,9 +3652,12 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( if forcePreferredConn { // 携带 function_call_output 的请求不能丢弃 previous_response_id: // 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use, - // 丢弃后会导致 "No tool call found for function call output" 400 错误。 + // 除非 replay input 已经包含与每个 tool_result 匹配的 tool_use 上下文。 hasFCOutput := hasFunctionCallOutput - if !turnPrevRecoveryTried && currentPreviousResponseID != "" && !hasFCOutput { + hasReplayToolContext := hasFCOutput && + currentTurnReplayInputExists && + openAIWSRawItemsHaveToolCallContextForOutputs(currentTurnReplayInput) + if !turnPrevRecoveryTried && currentPreviousResponseID != "" && (!hasFCOutput || hasReplayToolContext) { updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) if dropErr != nil || !removed { reason := "not_removed" @@ -3521,11 +3689,13 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ) } else { logOpenAIWSModeInfo( - "ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s", + "ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s has_function_call_output=%v has_replay_tool_context=%v", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + hasFCOutput, + hasReplayToolContext, ) turnPrevRecoveryTried = true currentPayload = updatedWithInput @@ -3537,12 +3707,18 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } } if hasFCOutput && currentPreviousResponseID != "" { + reason := "function_call_output_missing_replay_context" + if hasReplayToolContext { + reason = "function_call_output_replay_not_applied" + } logOpenAIWSModeInfo( - "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=function_call_output action=fail_close previous_response_id=%s", + "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s action=fail_close previous_response_id=%s has_replay_tool_context=%v", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + reason, truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + hasReplayToolContext, ) } resetSessionLease(true) @@ -3622,6 +3798,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( lastTurnPayload = cloneOpenAIWSPayloadBytes(currentPayload) lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput) lastTurnReplayInputExists = currentTurnReplayInputExists + if result.wsReplayInputExists { + lastTurnReplayInput = append(lastTurnReplayInput, cloneOpenAIWSRawMessages(result.wsReplayInput)...) + lastTurnReplayInputExists = true + } nextStrictState, strictStateErr := buildOpenAIWSIngressPreviousTurnStrictState(currentPayload) if strictStateErr != nil { lastTurnStrictState = nil diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index b7f1bc4f78d..069a9dee1f3 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -2305,6 +2305,161 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStr require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailReplaysFunctionCallOutputWithContext(t *testing.T) { + gin.SetMode(gin.TestMode) + prevPreflightPingIdle := openAIWSIngressPreflightPingIdle + openAIWSIngressPreflightPingIdle = 0 + defer func() { + openAIWSIngressPreflightPingIdle = prevPreflightPingIdle + }() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSPreflightFailConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_replay_ctx_1","model":"gpt-5.1","output":[{"type":"function_call","id":"fc_replay_1","call_id":"call_replay_1","name":"shell","arguments":"{}"}],"usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_replay_ctx_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 128, + Name: "openai-ingress-preflight-replay-function-output-with-context", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"call tool"}]}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_ping_replay_ctx_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_replay_ctx_1","input":[{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_ping_replay_ctx_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket function_call_output 自包含重放后结束超时") + } + + require.Equal(t, 2, dialer.DialCount(), "带完整 tool 上下文的 function_call_output 应在 ping 失败后换新连接重放") + require.Equal(t, 1, firstConn.WriteCount()) + require.GreaterOrEqual(t, firstConn.PingCount(), 1) + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) + secondConn.mu.Unlock() + require.Len(t, secondWrites, 1) + secondWrite := requestToJSONString(secondWrites[0]) + require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists()) + require.Equal(t, 3, len(gjson.Get(secondWrite, "input").Array())) + require.Equal(t, "message", gjson.Get(secondWrite, "input.0.type").String()) + require.Equal(t, "function_call", gjson.Get(secondWrite, "input.1.type").String()) + require.Equal(t, "call_replay_1", gjson.Get(secondWrite, "input.1.call_id").String()) + require.Equal(t, "function_call_output", gjson.Get(secondWrite, "input.2.type").String()) + require.Equal(t, "call_replay_1", gjson.Get(secondWrite, "input.2.call_id").String()) +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenFunctionCallOutputNeedsPreviousResponseID(t *testing.T) { gin.SetMode(gin.TestMode) prevPreflightPingIdle := openAIWSIngressPreflightPingIdle diff --git a/backend/internal/service/openai_ws_http_bridge.go b/backend/internal/service/openai_ws_http_bridge.go new file mode 100644 index 00000000000..1f0f32a0575 --- /dev/null +++ b/backend/internal/service/openai_ws_http_bridge.go @@ -0,0 +1,387 @@ +package service + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +const ( + openAIWSClientReadLimitBytesDefault int64 = 64 * 1024 * 1024 + openAIWSHTTPBridgeThresholdBytesDefault int64 = 15 * 1024 * 1024 + openAIWSHTTPBridgeErrorBodyLimitBytes = 64 * 1024 +) + +func ResolveOpenAIWSClientReadLimitBytes(cfg *config.Config) int64 { + if cfg == nil || cfg.Gateway.OpenAIWS.ClientReadLimitBytes <= 0 { + return openAIWSClientReadLimitBytesDefault + } + return cfg.Gateway.OpenAIWS.ClientReadLimitBytes +} + +func (s *OpenAIGatewayService) openAIWSHTTPBridgeEnabled() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.HTTPBridgeEnabled +} + +func (s *OpenAIGatewayService) openAIWSHTTPBridgeThresholdBytes() int64 { + if s == nil || s.cfg == nil || s.cfg.Gateway.OpenAIWS.HTTPBridgeThresholdBytes <= 0 { + return openAIWSHTTPBridgeThresholdBytesDefault + } + return s.cfg.Gateway.OpenAIWS.HTTPBridgeThresholdBytes +} + +func (s *OpenAIGatewayService) shouldBridgeOpenAIWSHTTP(payloadBytes int, previousResponseID string) bool { + if !s.openAIWSHTTPBridgeEnabled() { + return false + } + if strings.TrimSpace(previousResponseID) != "" { + return false + } + threshold := s.openAIWSHTTPBridgeThresholdBytes() + return threshold > 0 && int64(payloadBytes) >= threshold +} + +func prepareOpenAIWSHTTPBridgeBody(payload []byte) ([]byte, error) { + var body map[string]any + if err := json.Unmarshal(payload, &body); err != nil { + return nil, err + } + if body == nil { + return nil, errors.New("response.create payload must be a JSON object") + } + delete(body, "type") + delete(body, "generate") + delete(body, "previous_response_id") + body["stream"] = true + return json.Marshal(body) +} + +type openAIWSToolCallReplayCollector struct { + items []json.RawMessage + seen map[string]struct{} +} + +func (c *openAIWSToolCallReplayCollector) AddEvent(eventType string, message []byte) { + switch strings.TrimSpace(eventType) { + case "response.output_item.done": + c.addItem(gjson.GetBytes(message, "item")) + case "response.completed", "response.done": + output := gjson.GetBytes(message, "response.output") + if !output.IsArray() { + return + } + for _, item := range output.Array() { + c.addItem(item) + } + } +} + +func (c *openAIWSToolCallReplayCollector) Items() []json.RawMessage { + return cloneOpenAIWSRawMessages(c.items) +} + +func (c *openAIWSToolCallReplayCollector) addItem(item gjson.Result) { + if !item.Exists() || item.Type != gjson.JSON { + return + } + raw := strings.TrimSpace(item.Raw) + if raw == "" || !strings.HasPrefix(raw, "{") { + return + } + if !isCodexToolCallContextItemType(item.Get("type").String()) { + return + } + key := strings.TrimSpace(item.Get("id").String()) + if key == "" { + key = strings.TrimSpace(item.Get("call_id").String()) + } + if key == "" { + key = raw + } + if c.seen == nil { + c.seen = make(map[string]struct{}) + } + if _, ok := c.seen[key]; ok { + return + } + c.seen[key] = struct{}{} + c.items = append(c.items, json.RawMessage(raw)) +} + +func buildOpenAIWSHTTPBridgeErrorEvent(statusCode int, message string) []byte { + message = strings.TrimSpace(message) + if message == "" { + message = http.StatusText(statusCode) + } + if message == "" { + message = "upstream request failed" + } + event := map[string]any{ + "type": "error", + "status": statusCode, + "error": map[string]any{ + "type": "upstream_error", + "message": message, + }, + } + body, err := json.Marshal(event) + if err != nil { + return []byte(`{"type":"error","error":{"type":"upstream_error","message":"upstream request failed"}}`) + } + return body +} + +func (s *OpenAIGatewayService) proxyOpenAIWSHTTPBridgeTurn( + ctx context.Context, + c *gin.Context, + account *Account, + token string, + payload []byte, + payloadBytes int, + originalModel string, + imageBillingModel string, + imageSizeTier string, + imageInputSize string, + turn int, + writeClientMessage func([]byte) error, +) (*OpenAIForwardResult, error) { + if s == nil { + return nil, errors.New("service is nil") + } + if s.httpUpstream == nil { + return nil, errors.New("openai http upstream is nil") + } + if account == nil { + return nil, errors.New("account is nil") + } + if writeClientMessage == nil { + return nil, errors.New("client websocket writer is nil") + } + + body, err := prepareOpenAIWSHTTPBridgeBody(payload) + if err != nil { + return nil, fmt.Errorf("prepare http bridge body: %w", err) + } + + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) + releaseUpstreamCtx() + if err != nil { + return nil, err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + if c != nil { + c.Set("openai_passthrough", true) + c.Set("openai_ws_http_bridge", true) + } + + turnStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + _ = writeClientMessage(buildOpenAIWSHTTPBridgeErrorEvent(http.StatusBadGateway, "Upstream request failed")) + return nil, fmt.Errorf("upstream http bridge request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, openAIWSHTTPBridgeErrorBodyLimitBytes)) + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if upstreamMsg == "" { + upstreamMsg = http.StatusText(resp.StatusCode) + } + _ = writeClientMessage(buildOpenAIWSHTTPBridgeErrorEvent(resp.StatusCode, upstreamMsg)) + return nil, fmt.Errorf("upstream http bridge error: status=%d message=%s", resp.StatusCode, upstreamMsg) + } + + responseID := "" + usage := OpenAIUsage{} + imageCounter := newOpenAIImageOutputCounter() + var firstTokenMs *int + reqStream := openAIWSPayloadBoolFromRaw(body, "stream", true) + eventCount := 0 + tokenEventCount := 0 + terminalEventCount := 0 + replayCollector := &openAIWSToolCallReplayCollector{} + firstEventType := "" + lastEventType := "" + sawDone := false + wroteDownstream := false + clientDisconnected := false + mappedModel := "" + needModelReplace := false + var mappedModelBytes []byte + if originalModel != "" { + mappedModel = normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel)) + needModelReplace = mappedModel != "" && mappedModel != originalModel + if needModelReplace { + mappedModelBytes = []byte(mappedModel) + } + } + + resultWithUsage := func() *OpenAIForwardResult { + imageCount := imageCounter.Count() + result := &OpenAIForwardResult{ + RequestID: responseID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ServiceTier: extractOpenAIServiceTierFromBody(body), + ReasoningEffort: extractOpenAIReasoningEffortFromBody(body, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + ResponseHeaders: cloneHeader(resp.Header), + Duration: time.Since(turnStart), + FirstTokenMs: firstTokenMs, + } + if replayInput := replayCollector.Items(); len(replayInput) > 0 { + result.wsReplayInput = replayInput + result.wsReplayInputExists = true + } + if imageCount > 0 { + result.ImageCount = imageCount + result.ImageSize = imageSizeTier + result.ImageInputSize = imageInputSize + result.ImageOutputSizes = imageCounter.Sizes() + result.BillingModel = imageBillingModel + } + return result + } + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) + defer putSSEScannerBuf64K(scanBuf) + + for scanner.Scan() { + line := scanner.Text() + data, ok := extractOpenAISSEDataLine(line) + if !ok { + continue + } + trimmedData := strings.TrimSpace(data) + if trimmedData == "" { + continue + } + if trimmedData == "[DONE]" { + sawDone = true + continue + } + + upstreamMessage := []byte(trimmedData) + eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(upstreamMessage) + if responseID == "" && eventResponseID != "" { + responseID = eventResponseID + } + if eventType != "" { + eventCount++ + if firstEventType == "" { + firstEventType = eventType + } + lastEventType = eventType + } + if isOpenAIWSTokenEvent(eventType) { + tokenEventCount++ + if firstTokenMs == nil { + ms := int(time.Since(turnStart).Milliseconds()) + firstTokenMs = &ms + } + } + if openAIWSEventShouldParseUsage(eventType) { + parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage) + } + imageCounter.AddSSEData(upstreamMessage) + + if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && strings.Contains(trimmedData, mappedModel) { + upstreamMessage = replaceOpenAIWSMessageModel(upstreamMessage, mappedModel, originalModel) + } + if s.toolCorrector != nil && openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) { + if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(upstreamMessage); changed { + upstreamMessage = corrected + } + } + replayCollector.AddEvent(eventType, upstreamMessage) + + if !clientDisconnected { + if err := writeClientMessage(upstreamMessage); err != nil { + if isOpenAIWSClientDisconnectError(err) { + clientDisconnected = true + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err) + logOpenAIWSModeInfo( + "ingress_ws_http_bridge_client_disconnected_drain account_id=%d turn=%d close_status=%s close_reason=%s", + account.ID, + turn, + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + } else { + return nil, wrapOpenAIWSIngressTurnError( + "write_client", + fmt.Errorf("write client websocket event: %w", err), + wroteDownstream, + ) + } + } else { + wroteDownstream = true + } + } + + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage) + s.persistOpenAIWSRateLimitSignal(ctx, account, resp.Header, upstreamMessage, errCodeRaw, errTypeRaw, errMsgRaw) + errMessage := strings.TrimSpace(errMsgRaw) + if errMessage == "" { + errMessage = "upstream error event" + } + return resultWithUsage(), errors.New(errMessage) + } + if isOpenAIWSTerminalEvent(eventType) { + terminalEventCount++ + firstTokenMsValue := -1 + if firstTokenMs != nil { + firstTokenMsValue = *firstTokenMs + } + logOpenAIWSModeInfo( + "ingress_ws_http_bridge_turn_completed account_id=%d turn=%d response_id=%s payload_bytes=%d duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + payloadBytes, + time.Since(turnStart).Milliseconds(), + eventCount, + tokenEventCount, + terminalEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + firstTokenMsValue, + clientDisconnected, + ) + return resultWithUsage(), nil + } + } + if err := scanner.Err(); err != nil { + return resultWithUsage(), fmt.Errorf("read upstream http bridge stream: %w", err) + } + if sawDone && eventCount > 0 { + return resultWithUsage(), nil + } + return resultWithUsage(), errors.New("upstream http bridge stream ended before terminal event") +} diff --git a/backend/internal/service/openai_ws_http_bridge_test.go b/backend/internal/service/openai_ws_http_bridge_test.go new file mode 100644 index 00000000000..0a1d6b564bb --- /dev/null +++ b/backend/internal/service/openai_ws_http_bridge_test.go @@ -0,0 +1,461 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestPrepareOpenAIWSHTTPBridgeBodyStripsWSFields(t *testing.T) { + body, err := prepareOpenAIWSHTTPBridgeBody([]byte(`{"type":"response.create","generate":true,"model":"gpt-5","stream":false,"previous_response_id":"resp_prev","input":"hi"}`)) + require.NoError(t, err) + require.False(t, gjson.GetBytes(body, "type").Exists()) + require.False(t, gjson.GetBytes(body, "generate").Exists()) + require.False(t, gjson.GetBytes(body, "previous_response_id").Exists()) + require.Equal(t, "gpt-5", gjson.GetBytes(body, "model").String()) + require.True(t, gjson.GetBytes(body, "stream").Bool()) + require.Equal(t, "hi", gjson.GetBytes(body, "input").String()) +} + +func TestOpenAIWSHTTPBridgeDecisionKeepsSmallFramesOnWS(t *testing.T) { + svc := &OpenAIGatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + HTTPBridgeEnabled: true, + HTTPBridgeThresholdBytes: 100, + }, + }, + }, + } + + require.False(t, svc.shouldBridgeOpenAIWSHTTP(99, "")) + require.True(t, svc.shouldBridgeOpenAIWSHTTP(100, "")) + require.False(t, svc.shouldBridgeOpenAIWSHTTP(1000, "resp_existing")) + + svc.cfg.Gateway.OpenAIWS.HTTPBridgeEnabled = false + require.False(t, svc.shouldBridgeOpenAIWSHTTP(1000, "")) +} + +func TestOpenAIWSHTTPBridgeRelaysSSEFramesAsWebSocketMessages(t *testing.T) { + gin.SetMode(gin.TestMode) + + sseBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_bridge","model":"gpt-5"}}`, + "", + `data: {"type":"response.output_text.delta","response":{"id":"resp_bridge"},"delta":"ok"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_bridge","model":"gpt-5","usage":{"input_tokens":3,"output_tokens":2}}}`, + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "x-request-id": []string{"rid_bridge"}, + }, + Body: io.NopCloser(strings.NewReader(sseBody)), + }} + svc := &OpenAIGatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + OpenAIWS: config.GatewayOpenAIWSConfig{ + HTTPBridgeEnabled: true, + HTTPBridgeThresholdBytes: 1, + }, + }, + }, + httpUpstream: upstream, + toolCorrector: NewCodexToolCorrector(), + } + account := &Account{ + ID: 7, + Name: "api-key", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Status: StatusActive, + } + payload := []byte(`{"type":"response.create","generate":true,"model":"gpt-5","stream":true,"input":"hi"}`) + + type bridgeResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan bridgeResult, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover}) + if err != nil { + resultCh <- bridgeResult{err: err} + return + } + defer func() { _ = conn.CloseNow() }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + ginCtx.Request = req + + writeClient := func(message []byte) error { + writeCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + defer cancel() + return conn.Write(writeCtx, coderws.MessageText, message) + } + result, bridgeErr := svc.proxyOpenAIWSHTTPBridgeTurn( + r.Context(), + ginCtx, + account, + "sk-test", + payload, + len(payload), + "gpt-5", + "", + "", + "", + 1, + writeClient, + ) + resultCh <- bridgeResult{result: result, err: bridgeErr} + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { _ = clientConn.CloseNow() }() + + readEvent := func() []byte { + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + msgType, event, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return event + } + + created := readEvent() + delta := readEvent() + completed := readEvent() + + require.Equal(t, "response.created", gjson.GetBytes(created, "type").String()) + require.Equal(t, "response.output_text.delta", gjson.GetBytes(delta, "type").String()) + require.Equal(t, "response.completed", gjson.GetBytes(completed, "type").String()) + + select { + case bridge := <-resultCh: + require.NoError(t, bridge.err) + require.NotNil(t, bridge.result) + require.Equal(t, "resp_bridge", bridge.result.RequestID) + require.Equal(t, 3, bridge.result.Usage.InputTokens) + require.Equal(t, 2, bridge.result.Usage.OutputTokens) + require.True(t, bridge.result.OpenAIWSMode) + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for bridge result") + } + + require.NotNil(t, upstream.lastReq) + require.Equal(t, http.MethodPost, upstream.lastReq.Method) + require.False(t, gjson.GetBytes(upstream.lastBody, "type").Exists()) + require.False(t, gjson.GetBytes(upstream.lastBody, "generate").Exists()) + require.True(t, gjson.GetBytes(upstream.lastBody, "stream").Bool()) +} + +func TestOpenAIWSHTTPBridgeAcceptsFirstFrameAboveLegacy16MiB(t *testing.T) { + gin.SetMode(gin.TestMode) + + sseBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_large_bridge","model":"gpt-5"}}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_large_bridge","model":"gpt-5","usage":{"input_tokens":9,"output_tokens":1}}}`, + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "x-request-id": []string{"rid_large_bridge"}, + }, + Body: io.NopCloser(strings.NewReader(sseBody)), + }} + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + OpenAIWS: config.GatewayOpenAIWSConfig{ + Enabled: true, + APIKeyEnabled: true, + ResponsesWebsocketsV2: true, + ClientReadLimitBytes: 64 * 1024 * 1024, + HTTPBridgeEnabled: true, + HTTPBridgeThresholdBytes: 15 * 1024 * 1024, + }, + }, + } + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + toolCorrector: NewCodexToolCorrector(), + } + account := &Account{ + ID: 9, + Name: "api-key", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{"api_key": "sk-upstream"}, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + Concurrency: 1, + Status: StatusActive, + } + + payload := []byte(`{"type":"response.create","generate":true,"model":"gpt-5","stream":true,"input":"` + strings.Repeat("x", 17*1024*1024) + `"}`) + require.Greater(t, len(payload), 16*1024*1024) + require.Less(t, int64(len(payload)), ResolveOpenAIWSClientReadLimitBytes(cfg)) + + errCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover}) + if err != nil { + errCh <- err + return + } + defer func() { _ = conn.CloseNow() }() + conn.SetReadLimit(ResolveOpenAIWSClientReadLimitBytes(cfg)) + + readCtx, cancelRead := context.WithTimeout(r.Context(), 10*time.Second) + msgType, firstMessage, err := conn.Read(readCtx) + cancelRead() + if err != nil { + errCh <- err + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + errCh <- NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "unexpected client websocket message type", nil) + return + } + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "codex_cli_rs/0.135.0") + ginCtx.Request = req + + proxyCtx, cancelProxy := context.WithTimeout(r.Context(), 20*time.Second) + defer cancelProxy() + errCh <- svc.ProxyResponsesWebSocketFromClient(proxyCtx, ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 5*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { _ = clientConn.CloseNow() }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 20*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, payload) + cancelWrite() + require.NoError(t, err) + + var eventTypes []string + for { + readCtx, cancelRead := context.WithTimeout(context.Background(), 10*time.Second) + msgType, event, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + + eventType := gjson.GetBytes(event, "type").String() + eventTypes = append(eventTypes, eventType) + if eventType == "response.completed" { + break + } + } + require.Contains(t, eventTypes, "response.created") + require.Contains(t, eventTypes, "response.completed") + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case proxyErr := <-errCh: + require.NoError(t, proxyErr) + case <-time.After(10 * time.Second): + t.Fatal("timed out waiting for websocket bridge proxy to finish") + } + + require.NotNil(t, upstream.lastReq) + require.Equal(t, http.MethodPost, upstream.lastReq.Method) + require.Greater(t, len(upstream.lastBody), 16*1024*1024) + require.False(t, gjson.GetBytes(upstream.lastBody, "type").Exists()) + require.False(t, gjson.GetBytes(upstream.lastBody, "generate").Exists()) + require.True(t, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + require.Equal(t, "gpt-5", gjson.GetBytes(upstream.lastBody, "model").String()) +} + +func TestOpenAIWSHTTPBridgeKeepsContinuationFramesOnHTTPWithoutPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + firstSSEBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_bridge_first","model":"gpt-5.1","output":[{"type":"function_call","id":"fc_bridge_1","call_id":"call_bridge_1","name":"shell","arguments":"{}"}],"usage":{"input_tokens":9,"output_tokens":1}}}`, + "", + }, "\n") + secondSSEBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_bridge_second","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`, + "", + }, "\n") + upstream := &httpUpstreamRecorder{responses: []*http.Response{ + { + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + }, + Body: io.NopCloser(strings.NewReader(firstSSEBody)), + }, + { + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + }, + Body: io.NopCloser(strings.NewReader(secondSSEBody)), + }, + }} + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.HTTPBridgeEnabled = true + cfg.Gateway.OpenAIWS.HTTPBridgeThresholdBytes = 1 + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{} + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + account := &Account{ + ID: 19, + Name: "api-key-bridge-handoff", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{"api_key": "sk-upstream"}, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + Concurrency: 1, + Status: StatusActive, + Schedulable: true, + } + + errCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover}) + if err != nil { + errCh <- err + return + } + defer func() { _ = conn.CloseNow() }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, err := conn.Read(readCtx) + cancelRead() + if err != nil { + errCh <- err + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + errCh <- NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "unexpected client websocket message type", nil) + return + } + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "codex_cli_rs/0.135.0") + ginCtx.Request = req + + errCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { _ = clientConn.CloseNow() }() + + writeMessage := func(payload string) { + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + defer cancelWrite() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + defer cancelRead() + msgType, event, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return event + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":true,"input":"first"}`) + firstTurnEvent := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(firstTurnEvent, "type").String()) + require.Equal(t, "resp_bridge_first", gjson.GetBytes(firstTurnEvent, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_bridge_first","input":[{"type":"function_call_output","call_id":"call_bridge_1","output":"ok"}]}`) + secondTurnEvent := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(secondTurnEvent, "type").String()) + require.Equal(t, "resp_bridge_second", gjson.GetBytes(secondTurnEvent, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case proxyErr := <-errCh: + require.NoError(t, proxyErr) + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for websocket bridge proxy to finish") + } + + require.Len(t, upstream.bodies, 2, "进入 HTTP bridge 后同一客户端 WS 连接内应保持 HTTP/SSE bridge") + require.False(t, gjson.GetBytes(upstream.bodies[0], "previous_response_id").Exists()) + require.False(t, gjson.GetBytes(upstream.bodies[1], "previous_response_id").Exists()) + secondInput := gjson.GetBytes(upstream.bodies[1], "input").Array() + require.Len(t, secondInput, 3) + require.Equal(t, "first", secondInput[0].String()) + require.Equal(t, "function_call", secondInput[1].Get("type").String()) + require.Equal(t, "call_bridge_1", secondInput[1].Get("call_id").String()) + require.Equal(t, "function_call_output", secondInput[2].Get("type").String()) + require.Equal(t, "call_bridge_1", secondInput[2].Get("call_id").String()) + require.Equal(t, 0, captureDialer.DialCount()) + require.Empty(t, captureConn.writes) +}