From e9a2db8e80b70091e64fcd88a3cda96db80d22c1 Mon Sep 17 00:00:00 2001 From: haichuan Date: Thu, 28 May 2026 18:02:18 +0800 Subject: [PATCH] fix: normalize responses streaming terminal output --- .../service/openai_gateway_service.go | 87 ++++++++++++++++--- .../service/openai_gateway_service_test.go | 79 +++++++++++++++++ 2 files changed, 153 insertions(+), 13 deletions(-) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f93cc221057..8b7e837b6ba 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4454,6 +4454,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } needModelReplace := originalModel != mappedModel + streamOutputAccumulator := apicompat.NewBufferedResponseAccumulator() + streamImageOutputs := make([]json.RawMessage, 0, 1) + streamSeenImages := make(map[string]struct{}) resultWithUsage := func() *openaiStreamingResult { return &openaiStreamingResult{ usage: usage, @@ -4532,13 +4535,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } // Extract data from SSE line (supports both "data: " and "data:" formats) if data, ok := extractOpenAISSEDataLine(line); ok { - - // Replace model in response if needed. - // Fast path: most events do not contain model field values. - if needModelReplace && mappedModel != "" && strings.Contains(data, mappedModel) { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) - } - dataBytes := []byte(data) if openAIStreamEventIsTerminal(data) { sawTerminalEvent = true @@ -4564,6 +4560,26 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp line = "data: " + data eventType = strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String()) } + if imageOutput, ok := extractImageGenerationOutputFromSSEData(dataBytes, streamSeenImages); ok { + streamImageOutputs = append(streamImageOutputs, imageOutput) + } + if responsesStreamEventMayContributeToOutput(eventType) { + var streamEvent apicompat.ResponsesStreamEvent + if err := json.Unmarshal(dataBytes, &streamEvent); err == nil { + streamOutputAccumulator.ProcessEvent(&streamEvent) + } + } + if normalizedData, normalized := normalizeResponsesStreamingTerminalOutput(dataBytes, streamOutputAccumulator, streamImageOutputs); normalized { + dataBytes = normalizedData + data = string(normalizedData) + line = "data: " + data + eventType = strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String()) + } + // Replace model in response if needed. + // Fast path: most events do not contain model field values. + if needModelReplace && mappedModel != "" && strings.Contains(line, mappedModel) { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + } startsClientOutput := forceFlushFailedEvent || openAIStreamDataStartsClientOutput(data, eventType) // 写入客户端(客户端断开后继续 drain 上游) @@ -5099,6 +5115,45 @@ func extractCodexFinalResponse(body string) ([]byte, bool) { return nil, false } +func normalizeResponsesStreamingTerminalOutput(data []byte, acc *apicompat.BufferedResponseAccumulator, imageOutputs []json.RawMessage) ([]byte, bool) { + eventType := strings.TrimSpace(gjson.GetBytes(data, "type").String()) + switch eventType { + case "response.completed", "response.done", "response.incomplete", "response.cancelled", "response.canceled": + default: + return data, false + } + + output := gjson.GetBytes(data, "response.output") + hasAccumulatedOutput := (acc != nil && acc.HasContent()) || len(imageOutputs) > 0 + if output.Exists() && output.IsArray() { + if len(output.Array()) > 0 || !hasAccumulatedOutput { + return data, false + } + } + + outputJSON := []byte("[]") + if reconstructed, ok := buildResponsesOutputJSON(acc, imageOutputs); ok { + outputJSON = reconstructed + } + updated, err := sjson.SetRawBytes(data, "response.output", outputJSON) + if err != nil { + return data, false + } + return updated, true +} + +func responsesStreamEventMayContributeToOutput(eventType string) bool { + switch eventType { + case "response.output_text.delta", + "response.output_item.added", + "response.function_call_arguments.delta", + "response.reasoning_summary_text.delta": + return true + default: + return false + } +} + // reconstructResponseOutputFromSSE scans raw SSE body text for delta events and // returns a JSON-encoded output array reconstructed from accumulated deltas. // Returns (nil, false) if no content was found in deltas. @@ -5110,17 +5165,23 @@ func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) { if imageOutput, ok := extractImageGenerationOutputFromSSEData(data, seenImages); ok { imageOutputs = append(imageOutputs, imageOutput) } - var event apicompat.ResponsesStreamEvent - if err := json.Unmarshal(data, &event); err == nil { - acc.ProcessEvent(&event) + eventType := strings.TrimSpace(gjson.GetBytes(data, "type").String()) + if responsesStreamEventMayContributeToOutput(eventType) { + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal(data, &event); err == nil { + acc.ProcessEvent(&event) + } } }) - if !acc.HasContent() && len(imageOutputs) == 0 { + return buildResponsesOutputJSON(acc, imageOutputs) +} + +func buildResponsesOutputJSON(acc *apicompat.BufferedResponseAccumulator, imageOutputs []json.RawMessage) ([]byte, bool) { + if (acc == nil || !acc.HasContent()) && len(imageOutputs) == 0 { return nil, false } - var output []json.RawMessage - if acc.HasContent() { + if acc != nil && acc.HasContent() { outputJSON, err := json.Marshal(acc.BuildOutput()) if err == nil { _ = json.Unmarshal(outputJSON, &output) diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 8bed920d3ff..c642fcd4d0b 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -1233,6 +1233,85 @@ func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) { require.Contains(t, rec.Body.String(), "response.completed") } +func TestOpenAIStreamingNormalizesTerminalOutputFromDeltas(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_sdk_parse"}}`, + "", + `data: {"type":"response.output_text.delta","delta":"pon"}`, + "", + `data: {"type":"response.output_text.delta","delta":"g"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_sdk_parse","status":"completed","output":null,"usage":{"input_tokens":1,"output_tokens":1}}}`, + "", + }, "\n"))), + Header: http.Header{"X-Request-Id": []string{"rid-sdk-parse"}}, + } + + result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model") + require.NoError(t, err) + require.NotNil(t, result) + + terminalType, terminalPayload, ok := extractOpenAISSETerminalEvent(rec.Body.String()) + require.True(t, ok) + require.Equal(t, "response.completed", terminalType) + output := gjson.GetBytes(terminalPayload, "response.output") + require.True(t, output.IsArray()) + require.Len(t, output.Array(), 1) + require.Equal(t, "pong", gjson.GetBytes(terminalPayload, "response.output.0.content.0.text").String()) +} + +func TestOpenAIStreamingNormalizesTerminalOutputToEmptyArray(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_empty","status":"completed","output":null,"usage":{"input_tokens":1,"output_tokens":0}}}`, + "", + }, "\n"))), + Header: http.Header{"X-Request-Id": []string{"rid-empty-output"}}, + } + + result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model") + require.NoError(t, err) + require.NotNil(t, result) + + terminalType, terminalPayload, ok := extractOpenAISSETerminalEvent(rec.Body.String()) + require.True(t, ok) + require.Equal(t, "response.completed", terminalType) + output := gjson.GetBytes(terminalPayload, "response.output") + require.True(t, output.IsArray()) + require.Len(t, output.Array(), 0) +} + func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{