Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 74 additions & 13 deletions backend/internal/service/openai_gateway_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4454,6 +4454,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}

needModelReplace := originalModel != mappedModel
streamOutputAccumulator := apicompat.NewBufferedResponseAccumulator()
streamImageOutputs := make([]json.RawMessage, 0, 1)
streamSeenImages := make(map[string]struct{})
resultWithUsage := func() *openaiStreamingResult {
return &openaiStreamingResult{
usage: usage,
Expand Down Expand Up @@ -4532,13 +4535,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if data, ok := extractOpenAISSEDataLine(line); ok {

// Replace model in response if needed.
// Fast path: most events do not contain model field values.
if needModelReplace && mappedModel != "" && strings.Contains(data, mappedModel) {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}

dataBytes := []byte(data)
if openAIStreamEventIsTerminal(data) {
sawTerminalEvent = true
Expand All @@ -4564,6 +4560,26 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
line = "data: " + data
eventType = strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
}
if imageOutput, ok := extractImageGenerationOutputFromSSEData(dataBytes, streamSeenImages); ok {
streamImageOutputs = append(streamImageOutputs, imageOutput)
}
if responsesStreamEventMayContributeToOutput(eventType) {
var streamEvent apicompat.ResponsesStreamEvent
if err := json.Unmarshal(dataBytes, &streamEvent); err == nil {
streamOutputAccumulator.ProcessEvent(&streamEvent)
}
}
if normalizedData, normalized := normalizeResponsesStreamingTerminalOutput(dataBytes, streamOutputAccumulator, streamImageOutputs); normalized {
dataBytes = normalizedData
data = string(normalizedData)
line = "data: " + data
eventType = strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
}
// Replace model in response if needed.
// Fast path: most events do not contain model field values.
if needModelReplace && mappedModel != "" && strings.Contains(line, mappedModel) {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
startsClientOutput := forceFlushFailedEvent || openAIStreamDataStartsClientOutput(data, eventType)

// 写入客户端(客户端断开后继续 drain 上游)
Expand Down Expand Up @@ -5099,6 +5115,45 @@ func extractCodexFinalResponse(body string) ([]byte, bool) {
return nil, false
}

// normalizeResponsesStreamingTerminalOutput rewrites the "response.output"
// field of a terminal Responses stream event so that it carries the output
// gathered while streaming.
//
// Only events whose "type" is response.completed, response.done,
// response.incomplete, response.cancelled or response.canceled are touched;
// any other event is returned unchanged with false.
//
// A terminal event is also left unchanged when "response.output" is already an
// array and either that array is non-empty or nothing has been accumulated
// (neither acc content nor image outputs). In every remaining case the field
// is replaced by the JSON produced by buildResponsesOutputJSON, or by an empty
// array ("[]") when that helper reports nothing to build.
//
// The second return value reports whether the payload was rewritten. If the
// rewrite itself fails, the original data is returned with false.
func normalizeResponsesStreamingTerminalOutput(data []byte, acc *apicompat.BufferedResponseAccumulator, imageOutputs []json.RawMessage) ([]byte, bool) {
	switch strings.TrimSpace(gjson.GetBytes(data, "type").String()) {
	case "response.completed", "response.done", "response.incomplete", "response.cancelled", "response.canceled":
		// Terminal event: continue to the normalization below.
	default:
		return data, false
	}

	existing := gjson.GetBytes(data, "response.output")
	accumulated := (acc != nil && acc.HasContent()) || len(imageOutputs) > 0

	// An upstream-provided array wins unless it is empty and the stream
	// produced content that would otherwise be lost.
	if existing.Exists() && existing.IsArray() {
		if len(existing.Array()) > 0 {
			return data, false
		}
		if !accumulated {
			return data, false
		}
	}

	replacement, built := buildResponsesOutputJSON(acc, imageOutputs)
	if !built {
		// Nothing to reconstruct: still normalize a missing/null/non-array
		// output to an empty JSON array.
		replacement = []byte("[]")
	}

	patched, err := sjson.SetRawBytes(data, "response.output", replacement)
	if err != nil {
		return data, false
	}
	return patched, true
}

// responsesStreamEventMayContributeToOutput reports whether eventType is one
// of the four stream event types selected for output accumulation:
// output text deltas, added output items, function-call argument deltas and
// reasoning summary text deltas. The comparison is exact (case-sensitive, no
// trimming), so callers must pass an already normalized event type.
func responsesStreamEventMayContributeToOutput(eventType string) bool {
	return eventType == "response.output_text.delta" ||
		eventType == "response.output_item.added" ||
		eventType == "response.function_call_arguments.delta" ||
		eventType == "response.reasoning_summary_text.delta"
}

// reconstructResponseOutputFromSSE scans raw SSE body text for delta events and
// returns a JSON-encoded output array reconstructed from accumulated deltas.
// Returns (nil, false) if no content was found in deltas.
Expand All @@ -5110,17 +5165,23 @@ func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) {
if imageOutput, ok := extractImageGenerationOutputFromSSEData(data, seenImages); ok {
imageOutputs = append(imageOutputs, imageOutput)
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal(data, &event); err == nil {
acc.ProcessEvent(&event)
eventType := strings.TrimSpace(gjson.GetBytes(data, "type").String())
if responsesStreamEventMayContributeToOutput(eventType) {
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal(data, &event); err == nil {
acc.ProcessEvent(&event)
}
}
})
if !acc.HasContent() && len(imageOutputs) == 0 {
return buildResponsesOutputJSON(acc, imageOutputs)
}

func buildResponsesOutputJSON(acc *apicompat.BufferedResponseAccumulator, imageOutputs []json.RawMessage) ([]byte, bool) {
if (acc == nil || !acc.HasContent()) && len(imageOutputs) == 0 {
return nil, false
}

var output []json.RawMessage
if acc.HasContent() {
if acc != nil && acc.HasContent() {
outputJSON, err := json.Marshal(acc.BuildOutput())
if err == nil {
_ = json.Unmarshal(outputJSON, &output)
Expand Down
79 changes: 79 additions & 0 deletions backend/internal/service/openai_gateway_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,85 @@ func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) {
require.Contains(t, rec.Body.String(), "response.completed")
}

// TestOpenAIStreamingNormalizesTerminalOutputFromDeltas feeds the streaming
// handler an upstream SSE body whose terminal response.completed event has
// "output": null, preceded by two output text deltas ("pon" and "g"). It
// asserts that the terminal event written to the client carries an output
// array with exactly one item whose first content text is "pong".
func TestOpenAIStreamingNormalizesTerminalOutputFromDeltas(t *testing.T) {
	gin.SetMode(gin.TestMode)

	svc := &OpenAIGatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				StreamDataIntervalTimeout: 0,
				StreamKeepaliveInterval:   0,
				MaxLineSize:               defaultMaxLineSize,
			},
		},
	}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	// Each SSE event is followed by an empty line, as on the wire.
	upstreamBody := strings.Join([]string{
		`data: {"type":"response.created","response":{"id":"resp_sdk_parse"}}`,
		"",
		`data: {"type":"response.output_text.delta","delta":"pon"}`,
		"",
		`data: {"type":"response.output_text.delta","delta":"g"}`,
		"",
		`data: {"type":"response.completed","response":{"id":"resp_sdk_parse","status":"completed","output":null,"usage":{"input_tokens":1,"output_tokens":1}}}`,
		"",
	}, "\n")
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
		Header:     http.Header{"X-Request-Id": []string{"rid-sdk-parse"}},
	}
	account := &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}

	result, err := svc.handleStreamingResponse(ginCtx.Request.Context(), upstream, ginCtx, account, time.Now(), "model", "model")
	require.NoError(t, err)
	require.NotNil(t, result)

	eventType, payload, found := extractOpenAISSETerminalEvent(recorder.Body.String())
	require.True(t, found)
	require.Equal(t, "response.completed", eventType)

	normalizedOutput := gjson.GetBytes(payload, "response.output")
	require.True(t, normalizedOutput.IsArray())
	require.Len(t, normalizedOutput.Array(), 1)
	require.Equal(t, "pong", gjson.GetBytes(payload, "response.output.0.content.0.text").String())
}

// TestOpenAIStreamingNormalizesTerminalOutputToEmptyArray feeds the streaming
// handler an upstream SSE body consisting solely of a response.completed event
// with "output": null and no preceding deltas. It asserts that the terminal
// event written to the client carries "output" as an empty JSON array.
func TestOpenAIStreamingNormalizesTerminalOutputToEmptyArray(t *testing.T) {
	gin.SetMode(gin.TestMode)

	svc := &OpenAIGatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				StreamDataIntervalTimeout: 0,
				StreamKeepaliveInterval:   0,
				MaxLineSize:               defaultMaxLineSize,
			},
		},
	}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	// A single terminal event followed by an empty line.
	upstreamBody := strings.Join([]string{
		`data: {"type":"response.completed","response":{"id":"resp_empty","status":"completed","output":null,"usage":{"input_tokens":1,"output_tokens":0}}}`,
		"",
	}, "\n")
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
		Header:     http.Header{"X-Request-Id": []string{"rid-empty-output"}},
	}
	account := &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}

	result, err := svc.handleStreamingResponse(ginCtx.Request.Context(), upstream, ginCtx, account, time.Now(), "model", "model")
	require.NoError(t, err)
	require.NotNil(t, result)

	eventType, payload, found := extractOpenAISSETerminalEvent(recorder.Body.String())
	require.True(t, found)
	require.Equal(t, "response.completed", eventType)

	normalizedOutput := gjson.GetBytes(payload, "response.output")
	require.True(t, normalizedOutput.IsArray())
	require.Len(t, normalizedOutput.Array(), 0)
}

func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Expand Down
Loading