Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions backend/internal/service/openai_ws_forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2477,6 +2477,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
}
debugEnabled := isOpenAIWSModeDebugEnabled()
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)

type openAIWSClientPayload struct {
payloadRaw []byte
Expand Down Expand Up @@ -2586,6 +2587,34 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
normalized = next
}
apiKey := getAPIKeyFromContext(c)
imageGenerationAllowed := GroupAllowsImageGeneration(apiKeyGroup(apiKey))
codexBridgeEnabled := isCodexCLI && imageGenerationAllowed && s.isCodexImageGenerationBridgeEnabled(ctx, account, apiKey)
if codexBridgeEnabled {
payloadMap := make(map[string]any)
if err := json.Unmarshal(normalized, &payloadMap); err != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", err)
}
bridgeModified := false
if ensureOpenAIResponsesImageGenerationTool(payloadMap) {
bridgeModified = true
logOpenAIWSModeInfo("ingress_ws_codex_image_tool_injected account_id=%d", account.ID)
}
if normalizeOpenAIResponsesImageGenerationTools(payloadMap) {
bridgeModified = true
}
if applyCodexImageGenerationBridgeInstructions(payloadMap) {
bridgeModified = true
logOpenAIWSModeInfo("ingress_ws_codex_image_bridge_instructions_added account_id=%d", account.ID)
}
if bridgeModified {
rebuilt, marshalErr := json.Marshal(payloadMap)
if marshalErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", marshalErr)
}
normalized = rebuilt
}
}
upstreamModel := normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel))
if modelMissing || upstreamModel != originalModel {
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
Expand All @@ -2595,7 +2624,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
normalized = next
}
imageIntent := IsImageGenerationIntent(openAIResponsesEndpoint, originalModel, normalized)
if imageIntent && !GroupAllowsImageGeneration(apiKeyGroup(getAPIKeyFromContext(c))) {
if imageIntent && !imageGenerationAllowed {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, ImageGenerationPermissionMessage(), nil)
}
imageBillingModel := ""
Expand Down Expand Up @@ -2694,7 +2723,6 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
}

isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey)
baseAcquireReq := openAIWSAcquireRequest{
Account: account,
Expand Down
136 changes: 136 additions & 0 deletions backend/internal/service/openai_ws_forwarder_ingress_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,142 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_FollowupCreateCa
require.Equal(t, "resp_omit_model_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String())
}

func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_InjectsCodexImageBridge(t *testing.T) {
gin.SetMode(gin.TestMode)

cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3

captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_codex_image_bridge","model":"gpt-5.5","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)

svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}

groupID := int64(3)
apiKey := &APIKey{
ID: 1,
UserID: 1,
GroupID: &groupID,
Group: &Group{
ID: groupID,
AllowImageGeneration: true,
},
}
account := &Account{
ID: 31,
Name: "openai-codex-image-ws",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test-token",
},
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_enabled": true,
"codex_image_generation_bridge": true,
},
}

serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()

rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
ginCtx.Request = req
ginCtx.Set("api_key", apiKey)

readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}

serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "test-token", firstMessage, nil)
}))
defer wsServer.Close()

dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()

writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"input":"draw a cat"}`))
cancelWrite()
require.NoError(t, err)

readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
msgType, message, err := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, err)
require.Equal(t, coderws.MessageText, msgType)
require.Equal(t, "resp_codex_image_bridge", gjson.GetBytes(message, "response.id").String())

_ = clientConn.Close(coderws.StatusNormalClosure, "done")

select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("等待 ingress websocket 结束超时")
}

require.Len(t, captureConn.writes, 1)
upstreamPayload := requestToJSONString(captureConn.writes[0])
require.True(t, gjson.Get(upstreamPayload, `tools.#(type=="image_generation")`).Exists())
require.Equal(t, "png", gjson.Get(upstreamPayload, `tools.#(type=="image_generation").output_format`).String())
require.Contains(t, gjson.Get(upstreamPayload, "instructions").String(), "image_generation")
}

func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) {
gin.SetMode(gin.TestMode)

Expand Down
Loading