Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/internal/handler/openai_gateway_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
accountReleaseFunc()
}
}()
service.ClearParsedRequestBodyCache(c)
return h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
}()
forwardDurationMs := time.Since(forwardStart).Milliseconds()
Expand Down
140 changes: 140 additions & 0 deletions backend/internal/service/openai_failover_cached_body_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package service

import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)

func TestOpenAIGatewayService_Forward_FailoverReparsesCachedBodyForNextAccount(t *testing.T) {
gin.SetMode(gin.TestMode)

tests := []struct {
name string
requestModel string
firstMapping map[string]any
secondMapping map[string]any
clearCache bool
wantFirst string
wantSecond string
}{
{
name: "both accounts have mapping",
firstMapping: map[string]any{"alias-model": "base-model-a"},
secondMapping: map[string]any{"alias-model": "base-model-b"},
clearCache: true,
wantFirst: "base-model-a",
wantSecond: "base-model-b",
},
{
name: "first account has mapping second account has none",
requestModel: "gpt-5.4-high",
firstMapping: map[string]any{"gpt-5.4-high": "gpt-5.4"},
clearCache: true,
wantFirst: "gpt-5.4",
wantSecond: "gpt-5.4",
},
{
name: "first account has no mapping second account has mapping",
secondMapping: map[string]any{"alias-model": "base-model-b"},
clearCache: true,
wantFirst: "alias-model",
wantSecond: "base-model-b",
},
{
name: "dirty cache is reparsed when mappings differ",
firstMapping: map[string]any{"alias-model": "base-model-a"},
secondMapping: map[string]any{"alias-model": "base-model-b"},
wantFirst: "base-model-a",
wantSecond: "base-model-b",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
requestModel := tt.requestModel
if requestModel == "" {
requestModel = "alias-model"
}
body := []byte(`{"model":"` + requestModel + `","stream":false,"instructions":"cache-test","input":"hello"}`)

rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")

upstream := &httpUpstreamRecorder{responses: []*http.Response{
{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-failover-a"}},
Body: io.NopCloser(strings.NewReader(`{"error":{"type":"rate_limit_error","message":"rate limited"}}`)),
},
{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-ok-b"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_123","status":"completed","model":"ok","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}

firstAccount := openAIFailoverCachedBodyTestAccount(1, "account-a", tt.firstMapping)
secondAccount := openAIFailoverCachedBodyTestAccount(2, "account-b", tt.secondMapping)

_, err := svc.Forward(context.Background(), c, firstAccount, body)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.True(t, errors.As(err, &failoverErr))
require.Len(t, upstream.bodies, 1)
require.Equal(t, tt.wantFirst, gjson.GetBytes(upstream.bodies[0], "model").String())

if tt.clearCache {
ClearParsedRequestBodyCache(c)
}
result, err := svc.Forward(context.Background(), c, secondAccount, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Len(t, upstream.bodies, 2)
require.Equal(t, tt.wantSecond, gjson.GetBytes(upstream.bodies[1], "model").String())
})
}
}

func TestGetOpenAIRequestBodyMap_IgnoresDirtyContextCacheModel(t *testing.T) {
gin.SetMode(gin.TestMode)

rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Set(OpenAIParsedRequestBodyKey, map[string]any{"model": "base-model-a", "stream": true})

got, err := getOpenAIRequestBodyMap(c, []byte(`{"model":"alias-model","stream":false}`))
require.NoError(t, err)
require.Equal(t, "alias-model", got["model"])
require.Equal(t, false, got["stream"])
}

func openAIFailoverCachedBodyTestAccount(id int64, name string, mapping map[string]any) *Account {
credentials := map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-account"}
if mapping != nil {
credentials["model_mapping"] = mapping
}
return &Account{
ID: id,
Name: name,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: credentials,
Status: StatusActive,
Schedulable: true,
RateMultiplier: f64p(1),
}
}
16 changes: 15 additions & 1 deletion backend/internal/service/openai_gateway_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6819,7 +6819,7 @@ func isEmptyBase64DataURI(raw string) bool {
func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) {
if c != nil {
if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok {
if reqBody, ok := cached.(map[string]any); ok && reqBody != nil {
if reqBody, ok := cached.(map[string]any); ok && reqBody != nil && openAIRequestBodyCacheMatchesRawModel(reqBody, body) {
return reqBody, nil
}
}
Expand All @@ -6835,6 +6835,20 @@ func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error
return reqBody, nil
}

func openAIRequestBodyCacheMatchesRawModel(reqBody map[string]any, body []byte) bool {
rawModel := gjson.GetBytes(body, "model")
if !rawModel.Exists() || rawModel.Type != gjson.String {
return true
}
cachedModel, ok := reqBody["model"].(string)
return ok && cachedModel == rawModel.String()
}

// ClearParsedRequestBodyCache drops the per-request parsed OpenAI body cache.
func ClearParsedRequestBodyCache(c *gin.Context) {
releaseOpenAIParsedRequestBody(c)
}

func releaseOpenAIParsedRequestBody(c *gin.Context) {
if c == nil {
return
Expand Down
Loading