diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index dd2f45facb..22ed214b25 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -2477,6 +2477,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( wsPath = normalizeOpenAIWSLogValue(parsedURL.Path) } debugEnabled := isOpenAIWSModeDebugEnabled() + isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) type openAIWSClientPayload struct { payloadRaw []byte @@ -2586,6 +2587,34 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } normalized = next } + apiKey := getAPIKeyFromContext(c) + imageGenerationAllowed := GroupAllowsImageGeneration(apiKeyGroup(apiKey)) + codexBridgeEnabled := isCodexCLI && imageGenerationAllowed && s.isCodexImageGenerationBridgeEnabled(ctx, account, apiKey) + if codexBridgeEnabled { + payloadMap := make(map[string]any) + if err := json.Unmarshal(normalized, &payloadMap); err != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", err) + } + bridgeModified := false + if ensureOpenAIResponsesImageGenerationTool(payloadMap) { + bridgeModified = true + logOpenAIWSModeInfo("ingress_ws_codex_image_tool_injected account_id=%d", account.ID) + } + if normalizeOpenAIResponsesImageGenerationTools(payloadMap) { + bridgeModified = true + } + if applyCodexImageGenerationBridgeInstructions(payloadMap) { + bridgeModified = true + logOpenAIWSModeInfo("ingress_ws_codex_image_bridge_instructions_added account_id=%d", account.ID) + } + if bridgeModified { + rebuilt, marshalErr := json.Marshal(payloadMap) + if marshalErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", marshalErr) + } + normalized = rebuilt + } + } upstreamModel := normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel)) if modelMissing || upstreamModel != originalModel { next, setErr := applyPayloadMutation(normalized, "model", upstreamModel) @@ -2595,7 +2624,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( normalized = next } imageIntent := IsImageGenerationIntent(openAIResponsesEndpoint, originalModel, normalized) - if imageIntent && !GroupAllowsImageGeneration(apiKeyGroup(getAPIKeyFromContext(c))) { + if imageIntent && !imageGenerationAllowed { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, ImageGenerationPermissionMessage(), nil) } imageBillingModel := "" @@ -2694,7 +2723,6 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } } - isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey) baseAcquireReq := openAIWSAcquireRequest{ Account: account, diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index b7f1bc4f78..5e4b70c2a8 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -298,6 +298,142 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_FollowupCreateCa require.Equal(t, "resp_omit_model_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String()) } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_InjectsCodexImageBridge(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_codex_image_bridge","model":"gpt-5.5","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + groupID := int64(3) + apiKey := &APIKey{ + ID: 1, + UserID: 1, + GroupID: &groupID, + Group: &Group{ + ID: groupID, + AllowImageGeneration: true, + }, + } + account := &Account{ + ID: 31, + Name: "openai-codex-image-ws", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test-token", + }, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "codex_image_generation_bridge": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + ginCtx.Request = req + ginCtx.Set("api_key", apiKey) + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "test-token", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"input":"draw a cat"}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + msgType, message, err := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, err) + require.Equal(t, coderws.MessageText, msgType) + require.Equal(t, "resp_codex_image_bridge", gjson.GetBytes(message, "response.id").String()) + + _ = clientConn.Close(coderws.StatusNormalClosure, "done") + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Len(t, captureConn.writes, 1) + upstreamPayload := requestToJSONString(captureConn.writes[0]) + require.True(t, gjson.Get(upstreamPayload, `tools.#(type=="image_generation")`).Exists()) + require.Equal(t, "png", gjson.Get(upstreamPayload, `tools.#(type=="image_generation").output_format`).String()) + require.Contains(t, gjson.Get(upstreamPayload, "instructions").String(), "image_generation") +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) { gin.SetMode(gin.TestMode)