Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions controller/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ func GetAllChannels(c *gin.Context) {
return
}

// buildFetchModelsHeaders constructs the HTTP headers used when fetching the
// list of available models from a channel's upstream provider. It applies
// channel-level header_override entries on top of the default auth headers,
// expanding the {api_key} placeholder. An override entry whose value is the
// empty string is treated as an explicit suppression marker and removes the
// header instead of setting it to an empty value.
func buildFetchModelsHeaders(channel *model.Channel, key string) (http.Header, error) {
var headers http.Header
switch channel.Type {
Expand All @@ -206,6 +212,11 @@ func buildFetchModelsHeaders(channel *model.Channel, key string) (http.Header, e
if strings.Contains(str, "{api_key}") {
str = strings.ReplaceAll(str, "{api_key}", key)
}
// An empty value explicitly suppresses the header.
if str == "" {
headers.Del(k)
continue
}
headers.Set(k, str)
}

Expand Down
47 changes: 44 additions & 3 deletions relay/channel/api_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,20 @@ func shouldSkipPassthroughHeader(name string) bool {
return false
}

// applyHeaderOverridePlaceholders resolves the placeholders supported in a
// header_override template value. It returns the resolved value, a boolean
// indicating whether the entry should be retained, and any error encountered.
//
// Supported placeholders:
// - {api_key}: replaced with the channel's API key.
// - {client_header:<name>}: replaced verbatim with the value of the named
// incoming request header. The placeholder must be the entire template;
// {api_key} is not interpolated inside client-supplied content. Missing or
// empty client headers cause the entry to be dropped (returns retain=false).
//
// An empty resolved value (after {api_key} expansion) is preserved as an
// explicit suppression marker so downstream consumers can delete the header
// from the upstream request.
func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey string) (string, bool, error) {
trimmed := strings.TrimSpace(template)
if strings.HasPrefix(trimmed, clientHeaderPlaceholderPrefix) {
Expand Down Expand Up @@ -154,9 +168,10 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str
if strings.Contains(template, "{api_key}") {
template = strings.ReplaceAll(template, "{api_key}", apiKey)
}
if strings.TrimSpace(template) == "" {
return "", false, nil
}
// An empty template is treated as an explicit suppression marker:
// the entry is included with an empty value so that downstream consumers
// can delete the header from the upstream request, rather than letting a
// value previously written by the channel adaptor leak through.
return template, true, nil
}

Expand Down Expand Up @@ -270,15 +285,31 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
return headerOverride, nil
}

// ResolveHeaderOverride returns the fully resolved header_override map for the
// current relay request. It is a public wrapper around processHeaderOverride
// for use by channel adaptors (notably AWS Bedrock) that need to apply the
// override on a header collection they own rather than on a *http.Request.
func ResolveHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
return processHeaderOverride(info, c)
}

// applyHeaderOverrideToRequest writes the resolved header_override map onto an
// outgoing *http.Request. An empty value is treated as an explicit suppression
// marker and removes the header from the request rather than forwarding an
// empty string (which would let a previously-set header value leak through on
// case-insensitive lookups). When the Host header is set, the request's Host
// field is updated as well so net/http honours it.
func applyHeaderOverrideToRequest(req *http.Request, headerOverride map[string]string) {
if req == nil {
return
}
for key, value := range headerOverride {
// An empty value is an explicit suppression marker: remove the header
// from the upstream request rather than forwarding the empty string.
if value == "" {
req.Header.Del(key)
continue
}
req.Header.Set(key, value)
// set Host in req
if strings.EqualFold(key, "Host") {
Expand Down Expand Up @@ -351,6 +382,11 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
return resp, nil
}

// DoWssRequest opens a WebSocket connection to the upstream provider for the
// current relay request. Channel-level header_override entries are applied
// after SetupRequestHeader so the operator-configured values win over any
// adaptor defaults. An override value of the empty string suppresses the
// header rather than forwarding an empty value upstream.
func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
Expand All @@ -368,6 +404,11 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
return nil, err
}
for key, value := range headerOverride {
// An empty value explicitly suppresses the header upstream.
if value == "" {
targetHeader.Del(key)
continue
}
targetHeader.Set(key, value)
}
targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
Expand Down
71 changes: 67 additions & 4 deletions relay/channel/api_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T)
require.Equal(t, "trace-123", headers["x-upstream-trace"])
}

func TestProcessHeaderOverride_RuntimeOverrideIsFinalHeaderMap(t *testing.T) {
func TestProcessHeaderOverride_ChannelOverrideWinsOverRuntime(t *testing.T) {
t.Parallel()

gin.SetMode(gin.TestMode)
Expand All @@ -105,10 +105,13 @@ func TestProcessHeaderOverride_RuntimeOverrideIsFinalHeaderMap(t *testing.T) {

headers, err := processHeaderOverride(info, ctx)
require.NoError(t, err)
require.Equal(t, "runtime-value", headers["x-static"])
// Channel-level header_override entries take precedence over the runtime
// override map (set by upstream features such as channel affinity rules).
// This makes the admin UI the authoritative source for header policy.
require.Equal(t, "legacy-value", headers["x-static"])
require.Equal(t, "legacy-only", headers["x-legacy"])
// Runtime-only entries are still included via the union merge.
require.Equal(t, "runtime-only", headers["x-runtime"])
_, exists := headers["x-legacy"]
require.False(t, exists)
}

func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
Expand Down Expand Up @@ -191,3 +194,63 @@ func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing.
require.Equal(t, "sess-123", upstreamReq.Header.Get("Session_id"))
require.Empty(t, upstreamReq.Header.Get("X-Codex-Beta-Features"))
}


// TestProcessHeaderOverride_EmptyValueIsExplicitSuppression verifies that
// configuring a header_override entry with an empty string value causes the
// resulting override map to include the key with an empty value, so that
// downstream consumers can interpret it as "delete this header upstream".
//
// Regression test for: header_override "anthropic-beta": "" silently became a
// no-op, allowing client-supplied beta flags to leak through to upstreams
// that reject them (notably AWS Bedrock).
func TestProcessHeaderOverride_EmptyValueIsExplicitSuppression(t *testing.T) {
t.Parallel()

gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

info := &relaycommon.RelayInfo{
IsChannelTest: false,
ChannelMeta: &relaycommon.ChannelMeta{
HeadersOverride: map[string]any{
"anthropic-beta": "",
},
},
}

headers, err := processHeaderOverride(info, ctx)
require.NoError(t, err)

// The key MUST be present (so consumers know to act on it),
// and the value MUST be empty (so consumers know it's a suppression).
value, exists := headers["anthropic-beta"]
require.True(t, exists, "empty header_override entry must be retained as a suppression marker")
require.Equal(t, "", value)
}

// TestApplyHeaderOverrideToRequest_EmptyValueDeletesHeader verifies that
// applyHeaderOverrideToRequest removes the header from the outgoing request
// when the override value is the empty string, instead of forwarding an
// empty value or, worse, leaving a previously-set value in place.
func TestApplyHeaderOverrideToRequest_EmptyValueDeletesHeader(t *testing.T) {
t.Parallel()

req := httptest.NewRequest(http.MethodPost, "https://example.com/v1/messages", nil)
// Simulate a value already written by the channel adaptor (e.g. by
// CommonClaudeHeadersOperation copying anthropic-beta from the client).
req.Header.Set("anthropic-beta", "prompt-caching-2024-07-31")
require.Equal(t, "prompt-caching-2024-07-31", req.Header.Get("anthropic-beta"))

overrides := map[string]string{
"anthropic-beta": "",
"x-trace-id": "abc-123",
}
applyHeaderOverrideToRequest(req, overrides)

require.Empty(t, req.Header.Get("anthropic-beta"), "empty override value must remove the header")
require.NotContains(t, req.Header, "Anthropic-Beta")
require.Equal(t, "abc-123", req.Header.Get("x-trace-id"))
}
13 changes: 13 additions & 0 deletions relay/channel/aws/relay-aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
return client, nil
}

// doAwsClientRequest dispatches the relay request to AWS Bedrock through the
// SigV4-aware AWS SDK client. Channel-level header_override entries are
// applied after SetupRequestHeader so they win over any adaptor defaults; an
// override value of the empty string suppresses the header upstream, including
// any value previously written by SetupRequestHeader (this is how operators
// strip Anthropic beta flags that Bedrock rejects).
func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, requestBody io.Reader) (any, error) {
awsCli, err := newAwsClient(c, info)
if err != nil {
Expand All @@ -112,6 +118,13 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
return nil, err
}
for key, value := range headerOverride {
// An empty value explicitly suppresses the header upstream, including
// any value previously written by SetupRequestHeader (e.g. anthropic-beta
// copied from the client request, which Bedrock often rejects).
if value == "" {
requestHeader.Del(key)
continue
}
requestHeader.Set(key, value)
}

Expand Down
52 changes: 42 additions & 10 deletions relay/common/override.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,13 @@ func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} {
return info.ChannelMeta.HeadersOverride
}

// sanitizeHeaderOverrideMap normalizes a raw header_override map by
// canonicalizing keys, trimming whitespace from values, and dropping entries
// with empty keys. Empty values are intentionally preserved as explicit
// suppression markers; downstream consumers will Del() the matching header
// from the outgoing request. Passthrough rule keys ("*", "re:...",
// "regex:...") historically also use empty values and are preserved by the
// same fall-through behavior.
func sanitizeHeaderOverrideMap(source map[string]interface{}) map[string]interface{} {
if len(source) == 0 {
return map[string]interface{}{}
Expand All @@ -395,12 +402,11 @@ func sanitizeHeaderOverrideMap(source map[string]interface{}) map[string]interfa
continue
}
normalizedValue := strings.TrimSpace(fmt.Sprintf("%v", value))
if normalizedValue == "" {
if isHeaderPassthroughRuleKeyForOverride(normalizedKey) {
target[normalizedKey] = ""
}
continue
}
// An empty value is preserved as an explicit suppression marker:
// downstream consumers will Del() the header instead of forwarding it.
// Passthrough rule keys ("*", "re:...", "regex:...") historically use
// an empty value as well, and that contract is preserved by the same
// fall-through behavior.
target[normalizedKey] = normalizedValue
}
return target
Expand All @@ -417,14 +423,40 @@ func isHeaderPassthroughRuleKeyForOverride(key string) bool {
return strings.HasPrefix(key, "re:") || strings.HasPrefix(key, "regex:")
}

// GetEffectiveHeaderOverride returns the merged header_override map that
// should be applied to the upstream request. When UseRuntimeHeadersOverride
// is set, runtime overrides (from upstream features such as channel affinity
// rules) are merged first, then the channel-level header_override is layered
// on top. Channel-level entries win for keys defined in both layers because
// the admin UI is the authoritative source of header policy; this includes
// empty-string entries, which act as explicit suppression markers.
func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} {
if info == nil {
return map[string]interface{}{}
}
if info.UseRuntimeHeadersOverride {
return sanitizeHeaderOverrideMap(info.RuntimeHeadersOverride)
}
return sanitizeHeaderOverrideMap(getHeaderOverrideMap(info))
channelOverride := sanitizeHeaderOverrideMap(getHeaderOverrideMap(info))
if !info.UseRuntimeHeadersOverride {
return channelOverride
}
// Merge channel-level override on top of the runtime override map. Runtime
// overrides come from upstream features such as channel affinity rules,
// which inject pass-through headers (e.g. claude-cli "anthropic-beta")
// into the request. The channel-level header_override is set explicitly by
// the operator in the admin UI and represents the more intentional source
// of truth, so its entries (including empty-string suppression markers)
// must win when both define the same key.
runtimeOverride := sanitizeHeaderOverrideMap(info.RuntimeHeadersOverride)
if len(channelOverride) == 0 {
return runtimeOverride
}
merged := make(map[string]interface{}, len(runtimeOverride)+len(channelOverride))
for k, v := range runtimeOverride {
merged[k] = v
}
for k, v := range channelOverride {
merged[k] = v
}
return merged
}

func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
Expand Down
24 changes: 18 additions & 6 deletions relay/common/override_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1974,26 +1974,38 @@ func TestApplyParamOverrideWithRelayInfoSetHeaderMapRewritesAnthropicBeta(t *tes
}
}

func TestGetEffectiveHeaderOverrideUsesRuntimeOverrideAsFinalResult(t *testing.T) {
func TestGetEffectiveHeaderOverrideMergesChannelOnTopOfRuntime(t *testing.T) {
info := &RelayInfo{
UseRuntimeHeadersOverride: true,
RuntimeHeadersOverride: map[string]interface{}{
"x-runtime": "runtime-only",
"x-runtime": "runtime-only",
"anthropic-beta": "claude-code-20250219",
},
ChannelMeta: &ChannelMeta{
HeadersOverride: map[string]interface{}{
"X-Static": "static-value",
"X-Deleted": "should-not-exist",
"X-Static": "static-value",
"anthropic-beta": "",
},
},
}

effective := GetEffectiveHeaderOverride(info)
// Runtime-only entries are still included in the merged map.
if effective["x-runtime"] != "runtime-only" {
t.Fatalf("expected x-runtime from runtime override, got: %v", effective["x-runtime"])
}
if _, exists := effective["x-static"]; exists {
t.Fatalf("expected runtime override to be final and not merge channel headers")
// Channel-level overrides win for keys defined in both layers, including
// the empty-string suppression marker that downstream consumers will use
// to delete the header from the upstream request.
if effective["x-static"] != "static-value" {
t.Fatalf("expected x-static from channel override, got: %v", effective["x-static"])
}
v, exists := effective["anthropic-beta"]
if !exists {
t.Fatalf("expected anthropic-beta to be retained as suppression marker")
}
if v != "" {
t.Fatalf("expected channel-level empty value to win over runtime, got: %v", v)
}
}

Expand Down