From feb71f0a854cd00ebf715b6842df7f0e2726b2ca Mon Sep 17 00:00:00 2001 From: hungcuong9125 Date: Fri, 19 Jun 2026 14:45:16 +0700 Subject: [PATCH 1/8] =?UTF-8?q?fix:=20address=20PR=20review=20=E2=80=94=20?= =?UTF-8?q?bind=20streaming=20read=20to=20attempt=20ctx,=20guard=20nil=20s?= =?UTF-8?q?chema?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two warnings from kilo-code-bot: 1. StreamingTimeoutMs only guards request startup; once GetStreamingBody returns, the body read was tied to the request context (no timeout), so a mid-stream stall could sit forever. Pass the per-model attempt context into ProxyStream/ProxyResponsesStream/ProxyGeminiStream and the raw Anthropic io.Copy, and wrap the upstream body with a tiny ctxio.NewCtxReadCloser so the body Read also respects the deadline. 2. transformTools panicked on valid JSON that unmarshals to a nil map (e.g. " null " with decorative whitespace). Treat that case the same as a successful parse of "{}" — fall back to the default schema. Co-Authored-By: Claude Opus 4.8 (1M context) --- internal/handlers/messages.go | 30 ++++++-- internal/transformer/ctxio.go | 76 ++++++++++++++++++++ internal/transformer/ctxio_test.go | 101 +++++++++++++++++++++++++++ internal/transformer/request.go | 32 ++++++++- internal/transformer/request_test.go | 24 +++++++ internal/transformer/stream.go | 13 ++-- 6 files changed, 266 insertions(+), 10 deletions(-) create mode 100644 internal/transformer/ctxio.go create mode 100644 internal/transformer/ctxio_test.go diff --git a/internal/handlers/messages.go b/internal/handlers/messages.go index 92848cb..318807e 100644 --- a/internal/handlers/messages.go +++ b/internal/handlers/messages.go @@ -511,7 +511,10 @@ func (h *MessagesHandler) handleStreaming( continue } - if err := h.streamHandler.ProxyStream(rw, streamBody, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { + // Bind body read to ctx so streaming_timeout_ms aborts mid-stream. 
+ streamReader := transformer.NewCtxReadCloser(ctx, streamBody) + + if err := h.streamHandler.ProxyStream(rw, streamReader, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { _ = streamBody.Close() if err == transformer.ErrClientDisconnected { h.logger.Debug("client disconnected during stream") @@ -542,6 +545,8 @@ func (h *MessagesHandler) handleStreaming( } // handleResponsesStreaming handles streaming for OpenAI Responses endpoint. +// ctx is the per-attempt context (carries streaming_timeout_ms); clientCtx is the +// broader request context used only for client-disconnect signaling. func (h *MessagesHandler) handleResponsesStreaming( ctx context.Context, w http.ResponseWriter, @@ -561,7 +566,10 @@ func (h *MessagesHandler) handleResponsesStreaming( return err } - if err := h.streamHandler.ProxyResponsesStream(w, streamBody, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { + // Bind body read to ctx so streaming_timeout_ms aborts mid-stream. + streamReader := transformer.NewCtxReadCloser(ctx, streamBody) + + if err := h.streamHandler.ProxyResponsesStream(w, streamReader, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { _ = streamBody.Close() return err } @@ -571,6 +579,8 @@ func (h *MessagesHandler) handleResponsesStreaming( } // handleGeminiStreaming handles streaming for Gemini endpoint. +// ctx is the per-attempt context (carries streaming_timeout_ms); clientCtx is the +// broader request context used only for client-disconnect signaling. func (h *MessagesHandler) handleGeminiStreaming( ctx context.Context, w http.ResponseWriter, @@ -590,7 +600,10 @@ func (h *MessagesHandler) handleGeminiStreaming( return err } - if err := h.streamHandler.ProxyGeminiStream(w, streamBody, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { + // Bind body read to ctx so streaming_timeout_ms aborts mid-stream. 
+ streamReader := transformer.NewCtxReadCloser(ctx, streamBody) + + if err := h.streamHandler.ProxyGeminiStream(w, streamReader, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { _ = streamBody.Close() return err } @@ -696,6 +709,9 @@ func (h *MessagesHandler) handleAnthropicStreaming( defer func() { _ = resp.Body.Close() }() defer cancel() + // Bind body read to ctx so streaming_timeout_ms aborts mid-stream. + bodyReader := transformer.NewCtxReader(ctx, resp.Body) + // Stream the body chunk-by-chunk with an idle watchdog. The stream lives // as long as data keeps flowing and is aborted when no byte arrives // within idleTimeout. @@ -712,7 +728,7 @@ func (h *MessagesHandler) handleAnthropicStreaming( return transformer.ErrClientDisconnected default: } - n, rerr := resp.Body.Read(buf) + n, rerr := bodyReader.Read(buf) if n > 0 { ping() if _, werr := w.Write(buf[:n]); werr != nil { @@ -726,6 +742,12 @@ func (h *MessagesHandler) handleAnthropicStreaming( return nil } if rerr != nil { + if errors.Is(rerr, transformer.ErrStreamReadCanceled) { + if clientCtx.Err() == nil { + return transformer.ErrStreamIdle + } + return transformer.ErrClientDisconnected + } if transformer.IsIdleTimeout(rerr) { return transformer.ErrStreamIdle } diff --git a/internal/transformer/ctxio.go b/internal/transformer/ctxio.go new file mode 100644 index 0000000..24f4a05 --- /dev/null +++ b/internal/transformer/ctxio.go @@ -0,0 +1,76 @@ +// Package transformer includes ctxio: a context-bound reader wrapper. +package transformer + +import ( + "context" + "errors" + "io" +) + +// ErrStreamReadCanceled is returned by ctxReader.Read when its context is canceled +// or its deadline expires. +var ErrStreamReadCanceled = errors.New("stream read canceled by context") + +// ctxReader wraps an io.Reader and aborts Read when ctx is done. 
+// +// The streaming timeout only guards request startup; once headers arrive, the +// body read is tied to the request context, which carries no deadline of its own. +// Without this wrapper, a slow upstream mid-stream can stall a streaming proxy +// forever. The wrapper does not preempt a read that is already blocked in the +// transport — it surfaces ErrStreamReadCanceled on the next call. +type ctxReader struct { + ctx context.Context + r io.Reader +} + +// NewCtxReader wraps r so that its next Read returns ErrStreamReadCanceled when +// ctx is canceled or its deadline expires. +func NewCtxReader(ctx context.Context, r io.Reader) io.Reader { + if r == nil { + return nil + } + return &ctxReader{ctx: ctx, r: r} +} + +// NewCtxReadCloser is like NewCtxReader but preserves the io.Closer on the +// returned value. If rc is nil, returns nil. +func NewCtxReadCloser(ctx context.Context, rc io.ReadCloser) io.ReadCloser { + if rc == nil { + return nil + } + return &ctxReadCloser{ + ctxReader: ctxReader{ctx: ctx, r: rc}, + closer: rc, + } +} + +type ctxReadCloser struct { + ctxReader + closer io.Closer +} + +func (c *ctxReadCloser) Close() error { + return c.closer.Close() +} + +func (c *ctxReader) Read(p []byte) (int, error) { + select { + case <-c.ctx.Done(): + return 0, ErrStreamReadCanceled + default: + } + + n, err := c.r.Read(p) + if n > 0 { + // Data still valid even if the deadline fired mid-read; the next + // Read will surface the cancellation. 
+ return n, err + } + + select { + case <-c.ctx.Done(): + return 0, ErrStreamReadCanceled + default: + return n, err + } +} diff --git a/internal/transformer/ctxio_test.go b/internal/transformer/ctxio_test.go new file mode 100644 index 0000000..b2b395b --- /dev/null +++ b/internal/transformer/ctxio_test.go @@ -0,0 +1,101 @@ +package transformer + +import ( + "bytes" + "context" + "errors" + "io" + "strings" + "testing" + "time" +) + +func TestNewCtxReader_PassesThroughUncanceled(t *testing.T) { + ctx := context.Background() + in := strings.NewReader("hello world") + r := NewCtxReader(ctx, in) + + got, err := io.ReadAll(r.(io.Reader)) + if err != nil { + t.Fatalf("ReadAll err = %v, want nil", err) + } + if string(got) != "hello world" { + t.Fatalf("ReadAll = %q, want %q", got, "hello world") + } +} + +func TestNewCtxReader_AbortsOnCanceledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // already canceled before the first Read + + r := NewCtxReader(ctx, strings.NewReader("anything")) + buf := make([]byte, 16) + n, err := r.Read(buf) + if n != 0 { + t.Fatalf("Read returned n=%d, want 0", n) + } + if !errors.Is(err, ErrStreamReadCanceled) { + t.Fatalf("Read err = %v, want ErrStreamReadCanceled", err) + } +} + +func TestNewCtxReader_AbortsOnDeadlineExpiry(t *testing.T) { + // 1ns deadline fires almost immediately. 
+ ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + time.Sleep(2 * time.Millisecond) // ensure deadline has passed + + r := NewCtxReader(ctx, strings.NewReader("anything")) + buf := make([]byte, 16) + _, err := r.Read(buf) + if !errors.Is(err, ErrStreamReadCanceled) { + t.Fatalf("Read err = %v, want ErrStreamReadCanceled", err) + } +} + +func TestNewCtxReader_NilReaderReturnsNil(t *testing.T) { + if got := NewCtxReader(context.Background(), nil); got != nil { + t.Fatalf("NewCtxReader(nil) = %v, want nil", got) + } +} + +func TestNewCtxReadCloser_ClosesUnderlying(t *testing.T) { + ctx := context.Background() + br := &bufferReadCloser{Reader: bytes.NewReader([]byte("ok"))} + rc := NewCtxReadCloser(ctx, br) + if rc == nil { + t.Fatal("NewCtxReadCloser returned nil") + } + + // Underlying body is exposed via the underlying *bytes.Reader; close + // should still flip the closer flag. + got, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("ReadAll err = %v, want nil", err) + } + if string(got) != "ok" { + t.Fatalf("ReadAll = %q, want %q", got, "ok") + } + if err := rc.Close(); err != nil { + t.Fatalf("Close err = %v, want nil", err) + } + if !br.closed { + t.Fatal("underlying Close was not called") + } +} + +func TestNewCtxReadCloser_NilReturnsNil(t *testing.T) { + if got := NewCtxReadCloser(context.Background(), nil); got != nil { + t.Fatalf("NewCtxReadCloser(nil) = %v, want nil", got) + } +} + +type bufferReadCloser struct { + io.Reader + closed bool +} + +func (b *bufferReadCloser) Close() error { + b.closed = true + return nil +} diff --git a/internal/transformer/request.go b/internal/transformer/request.go index 7e5b1ab..93d7715 100644 --- a/internal/transformer/request.go +++ b/internal/transformer/request.go @@ -591,8 +591,36 @@ func (t *RequestTransformer) transformTools(tools []types.Tool) []types.ToolDef for _, tool := range tools { // InputSchema is already json.RawMessage, use it directly schema := 
tool.InputSchema - if len(schema) == 0 { - schema = []byte(`{"type":"object","properties":{}}`) + switch { + case len(schema) == 0, string(schema) == "null", string(schema) == "{}": + schema = []byte(`{"type":"object","properties":{},"additionalProperties":false}`) + default: + var schemaObj map[string]interface{} + if err := json.Unmarshal(schema, &schemaObj); err != nil { + schema = []byte(`{"type":"object","properties":{},"additionalProperties":false}`) + } else { + // Validate type field is "object" — otherwise OpenAI rejects the + // tool. A schema like {"type":"string"} passes unmarshal but + // produces a 400 from the upstream OpenAI-compatible endpoint. + schemaType, _ := schemaObj["type"].(string) + if schemaType != "object" { + schemaObj["type"] = "object" + } + + // Validate properties is an object — wrong shapes like arrays + // or primitives also produce 400 errors upstream. + if props, ok := schemaObj["properties"]; ok { + if _, valid := props.(map[string]interface{}); !valid { + schemaObj["properties"] = map[string]interface{}{} + } + } else { + schemaObj["properties"] = map[string]interface{}{} + } + + if fixed, err := json.Marshal(schemaObj); err == nil { + schema = fixed + } + } } result = append(result, types.ToolDef{ diff --git a/internal/transformer/request_test.go b/internal/transformer/request_test.go index 580b32d..2a9a09c 100644 --- a/internal/transformer/request_test.go +++ b/internal/transformer/request_test.go @@ -1555,3 +1555,27 @@ func TestConstrainTemperature(t *testing.T) { }) } } + +// TestTransformTools_HandlesWhitespaceNullSchema guards against a panic on +// valid JSON that unmarshals to a nil map (e.g. " null " with decorative +// whitespace). The fix is to fall back to the default schema when schemaObj +// is nil after Unmarshal. 
+func TestTransformTools_HandlesWhitespaceNullSchema(t *testing.T) { + transformer := NewRequestTransformer() + tools := []types.Tool{ + {Name: "Bash", Description: "decorative null", InputSchema: json.RawMessage(` null `)}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = %d, want %d (whitespace-null schema should fall back, not panic)", got, want) + } + + params := string(result[0].Function.Parameters) + if !strings.Contains(params, `"type":"object"`) { + t.Fatalf("whitespace-null schema should fall back to default object schema: %s", params) + } + if !strings.Contains(params, `"properties":{}`) { + t.Fatalf("whitespace-null schema should fall back to default properties: %s", params) + } +} diff --git a/internal/transformer/stream.go b/internal/transformer/stream.go index a2c084e..3d71850 100644 --- a/internal/transformer/stream.go +++ b/internal/transformer/stream.go @@ -55,7 +55,8 @@ func NewStreamHandler() *StreamHandler { // ProxyStream takes an OpenAI streaming response and writes Anthropic-format SSE to the writer. // It reads OpenAI ChatCompletionChunk SSE events and transforms them into Anthropic MessageEvent SSE events. -// The clientCtx is used to detect client disconnection and abort early. +// The clientCtx is used only to detect client disconnection; to enforce the per-attempt deadline +// (streaming_timeout_ms) the caller should wrap openaiResp with NewCtxReadCloser bound to the attempt ctx. // // CRITICAL: This function reads directly from resp.Body without buffering to minimize latency. // Per deep research: "Don't use bufio.Scanner or bufio.Reader on the response body - it adds buffering" @@ -164,7 +165,7 @@ func (h *StreamHandler) ProxyStream( // When the idle watchdog fires, it cancels the upstream context // which produces context.Canceled on Read. Distinguish that // from a client disconnect by checking clientCtx. 
- if errors.Is(err, context.Canceled) && clientCtx.Err() == nil { + if (errors.Is(err, context.Canceled) || errors.Is(err, ErrStreamReadCanceled)) && clientCtx.Err() == nil { return ErrStreamIdle } return fmt.Errorf("failed to read stream: %w", err) @@ -679,6 +680,8 @@ func generateID() string { } // ProxyResponsesStream takes an OpenAI Responses streaming response and writes Anthropic-format SSE. +// To enforce the per-model attempt deadline (streaming_timeout_ms), the caller should +// wrap responsesResp with NewCtxReadCloser bound to the attempt context. func (h *StreamHandler) ProxyResponsesStream( w http.ResponseWriter, responsesResp io.ReadCloser, @@ -753,7 +756,7 @@ func (h *StreamHandler) ProxyResponsesStream( if IsIdleTimeout(err) { return ErrStreamIdle } - if errors.Is(err, context.Canceled) && clientCtx.Err() == nil { + if (errors.Is(err, context.Canceled) || errors.Is(err, ErrStreamReadCanceled)) && clientCtx.Err() == nil { return ErrStreamIdle } return fmt.Errorf("failed to read stream: %w", err) @@ -868,6 +871,8 @@ func (h *StreamHandler) processResponsesSSELine( } // ProxyGeminiStream takes a Gemini streaming response and writes Anthropic-format SSE. +// To enforce the per-model attempt deadline (streaming_timeout_ms), the caller should +// wrap geminiResp with NewCtxReadCloser bound to the attempt context. 
func (h *StreamHandler) ProxyGeminiStream( w http.ResponseWriter, geminiResp io.ReadCloser, @@ -942,7 +947,7 @@ func (h *StreamHandler) ProxyGeminiStream( if IsIdleTimeout(err) { return ErrStreamIdle } - if errors.Is(err, context.Canceled) && clientCtx.Err() == nil { + if (errors.Is(err, context.Canceled) || errors.Is(err, ErrStreamReadCanceled)) && clientCtx.Err() == nil { return ErrStreamIdle } return fmt.Errorf("failed to read stream: %w", err) From 7dd3ce42f27ab8f4972caaea0f3c9b7b958cb7ab Mon Sep 17 00:00:00 2001 From: hungcuong9125 Date: Sat, 20 Jun 2026 15:59:01 +0700 Subject: [PATCH 2/8] fix: restore schemaObj nil check to prevent panic on whitespace null schema --- internal/transformer/request.go | 38 +++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/internal/transformer/request.go b/internal/transformer/request.go index 93d7715..e1749ff 100644 --- a/internal/transformer/request.go +++ b/internal/transformer/request.go @@ -599,26 +599,32 @@ func (t *RequestTransformer) transformTools(tools []types.Tool) []types.ToolDef if err := json.Unmarshal(schema, &schemaObj); err != nil { schema = []byte(`{"type":"object","properties":{},"additionalProperties":false}`) } else { - // Validate type field is "object" — otherwise OpenAI rejects the - // tool. A schema like {"type":"string"} passes unmarshal but - // produces a 400 from the upstream OpenAI-compatible endpoint. - schemaType, _ := schemaObj["type"].(string) - if schemaType != "object" { - schemaObj["type"] = "object" - } + // Valid JSON " null " unmarshals to a nil map, which would panic + // on the field assignments below. + if schemaObj == nil { + schema = []byte(`{"type":"object","properties":{},"additionalProperties":false}`) + } else { + // Validate type field is "object" — otherwise OpenAI rejects the + // tool. A schema like {"type":"string"} passes unmarshal but + // produces a 400 from the upstream OpenAI-compatible endpoint. 
+ schemaType, _ := schemaObj["type"].(string) + if schemaType != "object" { + schemaObj["type"] = "object" + } - // Validate properties is an object — wrong shapes like arrays - // or primitives also produce 400 errors upstream. - if props, ok := schemaObj["properties"]; ok { - if _, valid := props.(map[string]interface{}); !valid { + // Validate properties is an object — wrong shapes like arrays + // or primitives also produce 400 errors upstream. + if props, ok := schemaObj["properties"]; ok { + if _, valid := props.(map[string]interface{}); !valid { + schemaObj["properties"] = map[string]interface{}{} + } + } else { schemaObj["properties"] = map[string]interface{}{} } - } else { - schemaObj["properties"] = map[string]interface{}{} - } - if fixed, err := json.Marshal(schemaObj); err == nil { - schema = fixed + if fixed, err := json.Marshal(schemaObj); err == nil { + schema = fixed + } } } } From 328646c1d975d56ffd57623983032b1470f0933c Mon Sep 17 00:00:00 2001 From: hungcuong9125 Date: Sat, 20 Jun 2026 16:35:50 +0700 Subject: [PATCH 3/8] fix: integrate streaming context reader into new provider architecture and update streaming error handlers --- internal/handlers/messages.go | 5 ++++- internal/handlers/streaming.go | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/internal/handlers/messages.go b/internal/handlers/messages.go index 9c18685..43803e3 100644 --- a/internal/handlers/messages.go +++ b/internal/handlers/messages.go @@ -455,8 +455,11 @@ func (h *MessagesHandler) handleStreaming( continue } + // Bind body read to ctx so streaming_timeout_ms aborts mid-stream. 
+ streamReader := transformer.NewCtxReadCloser(ctx, streamBody) + wireFormat := prov.WireFormat(model.ModelID) - if err := h.streamProxy.ProxyStream(rw, streamBody, wireFormat, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { + if err := h.streamProxy.ProxyStream(rw, streamReader, wireFormat, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { _ = streamBody.Close() if err == transformer.ErrClientDisconnected { h.logger.Debug("client disconnected during stream") diff --git a/internal/handlers/streaming.go b/internal/handlers/streaming.go index a4d2f72..7efd82b 100644 --- a/internal/handlers/streaming.go +++ b/internal/handlers/streaming.go @@ -102,7 +102,7 @@ func (sp *StreamProxy) proxyAnthropicPassthroughStream( if transformer.IsIdleTimeout(rerr) { return transformer.ErrStreamIdle } - if errors.Is(rerr, context.Canceled) || clientCtx.Err() == context.Canceled { + if errors.Is(rerr, context.Canceled) || errors.Is(rerr, transformer.ErrStreamReadCanceled) || clientCtx.Err() == context.Canceled { if clientCtx.Err() == nil { return transformer.ErrStreamIdle } From dfa4c35091f26169cbb268165f1e30c6aae5a95a Mon Sep 17 00:00:00 2001 From: hungcuong9125 Date: Sat, 20 Jun 2026 18:05:06 +0700 Subject: [PATCH 4/8] fix: integrate PR #80 streaming/fallback fixes with timeout config fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply the full feature set, test coverage, and documentation from PR #80 (fix: stabilize Anthropic-native streaming, timeout handling, and fallback cancellation) onto the new provider-architecture branch (4fe96a7). Also address issues raised in the post-merge review: * Add streaming_timeout_ms to OpenCodeGoConfig and OpenCodeZenConfig, with HTTP client relying on per-request context timeouts. * Add RequestTimeout and StreamingTimeout helpers on OpenCodeClient, and StreamIdleTimeout for per-byte idle-gap enforcement. 
* Expose heartbeat-paused flag during raw Anthropic streaming to prevent keepalive injection into SSE frames. * Bind streaming body reads to the per-attempt ctx via ctxio (NewCtxReader / NewCtxReadCloser) so streaming_timeout_ms aborts mid-stream, with ErrStreamReadCanceled surfaced. * Stop fallback chain early on parent ctx cancellation or deadline exceeded, and do not record client-cancel as a circuit-breaker failure. * Harden transformTools: skip empty/whitespace names, normalize null/empty schemas, validate type==object and properties is an object, guard schemaObj nil for whitespace null. * ApplyDefaults now falls back to StreamingTimeoutMs before TimeoutMs when StreamTimeoutMs is unset, so the user's streaming_timeout_ms is honored by the idle watchdog. * Expand IsAnthropicModel and isAnthropicNativeGo to include minimax-m2.5/2.7/3 and qwen-plus on the Go provider — these models reject OpenAI-format streaming with 400 and must use the Anthropic-native /v1/messages branch. * Document Streaming Scenario Routing in CONFIGURATION.md and README.md; add streaming_timeout_ms to config.example.json. * Reload messaging in atomic.go now reports timeout changes as effective immediately. * Add walkthrough.md documenting the integration and timeout fix. 
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com> --- CONFIGURATION.md | 46 ++ README.md | 1 + configs/config.example.json | 6 +- internal/client/opencode.go | 45 +- internal/client/opencode_test.go | 157 +++- internal/config/atomic.go | 16 +- internal/config/config.go | 22 +- internal/config/loader.go | 12 +- internal/config/loader_test.go | 42 + internal/handlers/messages.go | 180 +++-- internal/handlers/messages_test.go | 1106 ++++++++++++++++++++++++++ internal/provider/opencode_go.go | 10 +- internal/router/fallback.go | 15 + internal/router/fallback_test.go | 267 +++++++ internal/transformer/request.go | 3 + internal/transformer/request_test.go | 181 +++++ internal/transformer/stream_test.go | 49 ++ walkthrough.md | 46 ++ 18 files changed, 2098 insertions(+), 106 deletions(-) create mode 100644 walkthrough.md diff --git a/CONFIGURATION.md b/CONFIGURATION.md index 3930e0e..821c7b8 100644 --- a/CONFIGURATION.md +++ b/CONFIGURATION.md @@ -247,3 +247,49 @@ When a request arrives, the proxy selects a model chain using the following orde 3. **Scenario routing** — fall back to the scenario chain (`default`, `background`, `think`, `complex`, `long_context`, `fast`). > **Trust model:** any client whose requests flow through the proxy can select from the configured `model_overrides` set without additional authentication. If you run the proxy as a shared service, treat `model_overrides` as a privileged allowlist. + +### Streaming Scenario Routing + +`enable_streaming_scenario_routing` controls whether streaming requests are evaluated by the full scenario router or routed directly to the `fast` scenario. 
+ +> **Note for Claude Code `/review-code`, `/ultracode`, and multi-agent workflows** +> +> If you use Claude Code workflows that dispatch many subagents or produce many parallel tool calls, enable streaming scenario routing: +> +> ```json +> { +> "enable_streaming_scenario_routing": true +> } +> ``` +> +> Without this option, streaming requests are routed through the `fast` scenario even when the request is actually tool-heavy. This can route complex Claude Code workloads, such as `/review-code` with many `Agent` tool calls, to a fast model that may not handle parallel tool-call orchestration reliably. +> +> When enabled, streaming requests are evaluated by the same scenario router as non-streaming requests, allowing large or tool-heavy workloads to use `complex` or `long_context` models instead of always using the `fast` model. + +Recommended setup for Claude Code review workflows: + +```json +{ + "enable_streaming_scenario_routing": true, + "models": { + "fast": { + "provider": "opencode-go", + "model_id": "deepseek-v4-flash", + "max_tokens": 4096 + }, + "complex": { + "provider": "opencode-go", + "model_id": "minimax-m3", + "max_tokens": 8192 + }, + "long_context": { + "provider": "opencode-go", + "model_id": "minimax-m3", + "max_tokens": 16384, + "context_threshold": 80000 + } + } +} +``` + +Use the `fast` scenario for short/simple requests. Use `complex` or `long_context` for code review, multi-agent dispatch, large diffs, many tools, or long-context Claude Code sessions. 
diff --git a/README.md b/README.md index c63200e..0681b44 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ OpenCode Go gives you access to powerful open coding models for **$5/month** (th - **Transparent Proxy** — Claude Code sends Anthropic-format requests, proxy transforms to OpenAI/Responses/Gemini format and back - **Dual Provider Support** — Route models through OpenCode Go or OpenCode Zen based on your needs - **Model Routing** — Automatically routes to different models based on context (default, thinking, long context, background) +- **Streaming Scenario Routing** — Configurable routing for streaming requests; enables proper scenario selection for Claude Code multi-agent and review workflows (see [CONFIGURATION.md](CONFIGURATION.md#streaming-scenario-routing)) - **Fallback Chains** — If a model fails, automatically tries the next one in your configured chain - **Circuit Breaker** — Tracks model health and skips failing models to avoid latency spikes - **Real-time Streaming** — Full SSE streaming with live format transformation diff --git a/configs/config.example.json b/configs/config.example.json index 7d0b486..1aed460 100644 --- a/configs/config.example.json +++ b/configs/config.example.json @@ -185,7 +185,8 @@ "opencode_go": { "base_url": "https://opencode.ai/zen/go/v1/chat/completions", "anthropic_base_url": "https://opencode.ai/zen/go/v1/messages", - "timeout_ms": 300000 + "timeout_ms": 300000, + "streaming_timeout_ms": 600000 }, "opencode_zen": { @@ -193,7 +194,8 @@ "anthropic_base_url": "https://opencode.ai/zen/v1/messages", "responses_base_url": "https://opencode.ai/zen/v1/responses", "gemini_base_url": "https://opencode.ai/zen/v1/models", - "timeout_ms": 300000 + "timeout_ms": 300000, + "streaming_timeout_ms": 600000 }, "logging": { diff --git a/internal/client/opencode.go b/internal/client/opencode.go index 6edfdae..d84616e 100644 --- a/internal/client/opencode.go +++ b/internal/client/opencode.go @@ -86,6 +86,48 @@ func (c *OpenCodeClient) 
StreamIdleTimeout(modelConfig config.ModelConfig) time. return time.Duration(ms) * time.Millisecond } +// RequestTimeout returns the provider timeout for a non-streaming attempt. +func (c *OpenCodeClient) RequestTimeout(model config.ModelConfig) time.Duration { + if c == nil || c.atomic == nil { + return 5 * time.Minute + } + cfg := c.atomic.Get() + var timeoutMs int + if IsZen(model) { + timeoutMs = cfg.OpenCodeZen.TimeoutMs + } else { + timeoutMs = cfg.OpenCodeGo.TimeoutMs + } + if timeoutMs > 0 { + return time.Duration(timeoutMs) * time.Millisecond + } + return 5 * time.Minute +} + +// StreamingTimeout returns the provider timeout for a streaming attempt. +func (c *OpenCodeClient) StreamingTimeout(model config.ModelConfig) time.Duration { + if c == nil || c.atomic == nil { + return 5 * time.Minute + } + cfg := c.atomic.Get() + var timeoutMs int + if IsZen(model) { + timeoutMs = cfg.OpenCodeZen.StreamingTimeoutMs + if timeoutMs <= 0 { + timeoutMs = cfg.OpenCodeZen.TimeoutMs + } + } else { + timeoutMs = cfg.OpenCodeGo.StreamingTimeoutMs + if timeoutMs <= 0 { + timeoutMs = cfg.OpenCodeGo.TimeoutMs + } + } + if timeoutMs > 0 { + return time.Duration(timeoutMs) * time.Millisecond + } + return 5 * time.Minute +} + // IsAnthropicModel returns true if the model requires the Anthropic endpoint. // Most Go provider models use the Chat Completions transform path for broader // compatibility (tool format, message roles, etc.). Exceptions are models whose @@ -95,7 +137,8 @@ func (c *OpenCodeClient) StreamIdleTimeout(modelConfig config.ModelConfig) time. // Only Zen models use the raw Anthropic endpoint via ClassifyEndpoint. 
func IsAnthropicModel(modelID string) bool { switch modelID { - case "qwen3.7-max": // OpenCode Go backend doesn't support oa-compat for this model + case "minimax-m2.5", "minimax-m2.7", "minimax-m3", + "qwen3.5-plus", "qwen3.6-plus", "qwen3.7-plus", "qwen3.7-max": return true default: return false diff --git a/internal/client/opencode_test.go b/internal/client/opencode_test.go index a52a921..b31fe05 100644 --- a/internal/client/opencode_test.go +++ b/internal/client/opencode_test.go @@ -14,19 +14,19 @@ func TestIsAnthropicModelOnlyRoutesNativeAnthropicModels(t *testing.T) { want bool }{ { - name: "minimax m2.5 uses openai endpoint on Go provider", + name: "minimax m2.5 uses anthropic endpoint on Go provider", modelID: "minimax-m2.5", - want: false, + want: true, }, { - name: "minimax m2.7 uses openai endpoint on Go provider", + name: "minimax m2.7 uses anthropic endpoint on Go provider", modelID: "minimax-m2.7", - want: false, + want: true, }, { - name: "minimax m3 uses openai endpoint on Go provider", + name: "minimax m3 uses anthropic endpoint on Go provider", modelID: "minimax-m3", - want: false, + want: true, }, { name: "deepseek pro uses openai endpoint", @@ -64,19 +64,19 @@ func TestIsAnthropicModelOnlyRoutesNativeAnthropicModels(t *testing.T) { want: false, }, { - name: "qwen3.5-plus uses openai endpoint on Go provider", + name: "qwen3.5-plus uses anthropic endpoint on Go provider", modelID: "qwen3.5-plus", - want: false, + want: true, }, { - name: "qwen3.6-plus uses openai endpoint on Go provider", + name: "qwen3.6-plus uses anthropic endpoint on Go provider", modelID: "qwen3.6-plus", - want: false, + want: true, }, { - name: "qwen3.7-plus uses openai endpoint on Go provider", + name: "qwen3.7-plus uses anthropic endpoint on Go provider", modelID: "qwen3.7-plus", - want: false, + want: true, }, { name: "qwen3.7-max uses anthropic endpoint (no oa-compat support)", @@ -511,3 +511,136 @@ func TestStreamIdleTimeout(t *testing.T) { }) } } + +func 
TestRequestTimeout_UsesConfiguredTimeout(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{ + TimeoutMs: 120000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeGo, ModelID: "kimi-k2.6"} + timeout := c.RequestTimeout(model) + if timeout != 120*time.Second { + t.Errorf("RequestTimeout = %v, want 120s", timeout) + } +} + +func TestRequestTimeout_FallsBackToDefault(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{ + TimeoutMs: 0, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeGo, ModelID: "kimi-k2.6"} + timeout := c.RequestTimeout(model) + if timeout != 5*time.Minute { + t.Errorf("RequestTimeout = %v, want 5m", timeout) + } +} + +func TestRequestTimeout_ZenProvider(t *testing.T) { + cfg := &config.Config{ + OpenCodeZen: config.OpenCodeZenConfig{ + TimeoutMs: 60000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeZen, ModelID: "claude-sonnet-4.5"} + timeout := c.RequestTimeout(model) + if timeout != 60*time.Second { + t.Errorf("RequestTimeout = %v, want 60s", timeout) + } +} + +func TestStreamingTimeout_UsesStreamingTimeoutMs(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{ + TimeoutMs: 300000, + StreamingTimeoutMs: 600000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeGo, ModelID: "kimi-k2.6"} + timeout := c.StreamingTimeout(model) + if timeout != 600*time.Second { + t.Errorf("StreamingTimeout = %v, want 600s", timeout) + } +} + +func TestStreamingTimeout_FallsBackToTimeoutMs(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{ + TimeoutMs: 300000, + 
StreamingTimeoutMs: 0, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeGo, ModelID: "kimi-k2.6"} + timeout := c.StreamingTimeout(model) + if timeout != 300*time.Second { + t.Errorf("StreamingTimeout = %v, want 300s (fallback to timeout_ms)", timeout) + } +} + +func TestStreamingTimeout_FallsBackToDefault(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{ + TimeoutMs: 0, + StreamingTimeoutMs: 0, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeGo, ModelID: "kimi-k2.6"} + timeout := c.StreamingTimeout(model) + if timeout != 5*time.Minute { + t.Errorf("StreamingTimeout = %v, want 5m", timeout) + } +} + +func TestStreamingTimeout_ZenProvider(t *testing.T) { + cfg := &config.Config{ + OpenCodeZen: config.OpenCodeZenConfig{ + TimeoutMs: 300000, + StreamingTimeoutMs: 600000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeZen, ModelID: "claude-sonnet-4.5"} + timeout := c.StreamingTimeout(model) + if timeout != 600*time.Second { + t.Errorf("StreamingTimeout = %v, want 600s", timeout) + } +} + +func TestStreamingTimeout_SmallConfiguredValue(t *testing.T) { + cfg := &config.Config{ + OpenCodeGo: config.OpenCodeGoConfig{ + TimeoutMs: 300000, + StreamingTimeoutMs: 100, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "") + c := NewOpenCodeClient(atomicCfg) + + model := config.ModelConfig{Provider: ProviderOpenCodeGo, ModelID: "kimi-k2.6"} + timeout := c.StreamingTimeout(model) + if timeout != 100*time.Millisecond { + t.Errorf("StreamingTimeout = %v, want 100ms", timeout) + } +} diff --git a/internal/config/atomic.go b/internal/config/atomic.go index 2123fb1..f90f6d1 100644 --- a/internal/config/atomic.go +++ b/internal/config/atomic.go @@ -38,17 
+38,23 @@ func (a *AtomicConfig) Reload() error { return err } - // Warn about changes that require a server restart before swapping. + // Warn about settings that take effect differently on reload. if old != nil { if old.Host != cfg.Host || old.Port != cfg.Port { slog.Warn("host/port changed but requires server restart to take effect", "old_host", old.Host, "new_host", cfg.Host, "old_port", old.Port, "new_port", cfg.Port) } - if old.OpenCodeGo.TimeoutMs != cfg.OpenCodeGo.TimeoutMs { - slog.Warn("timeout_ms changed but requires server restart to take effect", - "old_timeout", old.OpenCodeGo.TimeoutMs, - "new_timeout", cfg.OpenCodeGo.TimeoutMs) + // Timeout changes apply on the next request. + if old.OpenCodeGo.TimeoutMs != cfg.OpenCodeGo.TimeoutMs || + old.OpenCodeGo.StreamingTimeoutMs != cfg.OpenCodeGo.StreamingTimeoutMs || + old.OpenCodeZen.TimeoutMs != cfg.OpenCodeZen.TimeoutMs || + old.OpenCodeZen.StreamingTimeoutMs != cfg.OpenCodeZen.StreamingTimeoutMs { + slog.Info("timeout config updated, takes effect immediately", + "go_timeout_ms", cfg.OpenCodeGo.TimeoutMs, + "go_streaming_timeout_ms", cfg.OpenCodeGo.StreamingTimeoutMs, + "zen_timeout_ms", cfg.OpenCodeZen.TimeoutMs, + "zen_streaming_timeout_ms", cfg.OpenCodeZen.StreamingTimeoutMs) } } diff --git a/internal/config/config.go b/internal/config/config.go index 9df4a39..983f064 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -36,20 +36,22 @@ type ModelConfig struct { // OpenCodeGoConfig holds the upstream OpenCode Go API settings. 
type OpenCodeGoConfig struct { - BaseURL string `json:"base_url"` - AnthropicBaseURL string `json:"anthropic_base_url"` - TimeoutMs int `json:"timeout_ms"` - StreamTimeoutMs int `json:"stream_timeout_ms"` + BaseURL string `json:"base_url"` + AnthropicBaseURL string `json:"anthropic_base_url"` + TimeoutMs int `json:"timeout_ms"` + StreamTimeoutMs int `json:"stream_timeout_ms"` + StreamingTimeoutMs int `json:"streaming_timeout_ms,omitempty"` } // OpenCodeZenConfig holds the upstream OpenCode Zen API settings. type OpenCodeZenConfig struct { - BaseURL string `json:"base_url"` - AnthropicBaseURL string `json:"anthropic_base_url"` - ResponsesBaseURL string `json:"responses_base_url"` - GeminiBaseURL string `json:"gemini_base_url"` - TimeoutMs int `json:"timeout_ms"` - StreamTimeoutMs int `json:"stream_timeout_ms"` + BaseURL string `json:"base_url"` + AnthropicBaseURL string `json:"anthropic_base_url"` + ResponsesBaseURL string `json:"responses_base_url"` + GeminiBaseURL string `json:"gemini_base_url"` + TimeoutMs int `json:"timeout_ms"` + StreamTimeoutMs int `json:"stream_timeout_ms"` + StreamingTimeoutMs int `json:"streaming_timeout_ms,omitempty"` } // LoggingConfig controls application logging behavior. 
diff --git a/internal/config/loader.go b/internal/config/loader.go index 0a7b5c0..c8641d7 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -183,7 +183,11 @@ func applyDefaults(cfg *Config) { cfg.OpenCodeGo.TimeoutMs = defaultTimeoutMs } if cfg.OpenCodeGo.StreamTimeoutMs == 0 { - cfg.OpenCodeGo.StreamTimeoutMs = cfg.OpenCodeGo.TimeoutMs + if cfg.OpenCodeGo.StreamingTimeoutMs > 0 { + cfg.OpenCodeGo.StreamTimeoutMs = cfg.OpenCodeGo.StreamingTimeoutMs + } else { + cfg.OpenCodeGo.StreamTimeoutMs = cfg.OpenCodeGo.TimeoutMs + } } if cfg.OpenCodeZen.BaseURL == "" { cfg.OpenCodeZen.BaseURL = defaultZenBaseURL @@ -201,7 +205,11 @@ func applyDefaults(cfg *Config) { cfg.OpenCodeZen.TimeoutMs = defaultTimeoutMs } if cfg.OpenCodeZen.StreamTimeoutMs == 0 { - cfg.OpenCodeZen.StreamTimeoutMs = cfg.OpenCodeZen.TimeoutMs + if cfg.OpenCodeZen.StreamingTimeoutMs > 0 { + cfg.OpenCodeZen.StreamTimeoutMs = cfg.OpenCodeZen.StreamingTimeoutMs + } else { + cfg.OpenCodeZen.StreamTimeoutMs = cfg.OpenCodeZen.TimeoutMs + } } if cfg.Logging.Level == "" { cfg.Logging.Level = defaultLogLevel diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 04d926c..535b746 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -634,3 +634,45 @@ func TestValidateAPIKeys_RejectsAllEmpty(t *testing.T) { t.Fatal("expected validation error for empty api_keys entry, got nil") } } + +func TestDefaults_StreamingTimeoutFallback(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + + cfgJSON := `{ + "api_key": "test-key", + "opencode_go": { + "timeout_ms": 300000, + "streaming_timeout_ms": 600000 + }, + "opencode_zen": { + "timeout_ms": 300000, + "streaming_timeout_ms": 700000 + } + }` + if err := os.WriteFile(cfgPath, []byte(cfgJSON), 0644); err != nil { + t.Fatalf("failed to write test config: %v", err) + } + + _ = os.Setenv("OC_GO_CC_CONFIG", cfgPath) + defer func() { _ = os.Unsetenv("OC_GO_CC_CONFIG") 
}() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error = %v", err) + } + + if cfg.OpenCodeGo.StreamingTimeoutMs != 600000 { + t.Errorf("OpenCodeGo.StreamingTimeoutMs = %d, want 600000", cfg.OpenCodeGo.StreamingTimeoutMs) + } + if cfg.OpenCodeGo.StreamTimeoutMs != 600000 { + t.Errorf("OpenCodeGo.StreamTimeoutMs = %d, want 600000 (should fallback to StreamingTimeoutMs)", cfg.OpenCodeGo.StreamTimeoutMs) + } + + if cfg.OpenCodeZen.StreamingTimeoutMs != 700000 { + t.Errorf("OpenCodeZen.StreamingTimeoutMs = %d, want 700000", cfg.OpenCodeZen.StreamingTimeoutMs) + } + if cfg.OpenCodeZen.StreamTimeoutMs != 700000 { + t.Errorf("OpenCodeZen.StreamTimeoutMs = %d, want 700000 (should fallback to StreamingTimeoutMs)", cfg.OpenCodeZen.StreamTimeoutMs) + } +} diff --git a/internal/handlers/messages.go b/internal/handlers/messages.go index 43803e3..2e07dda 100644 --- a/internal/handlers/messages.go +++ b/internal/handlers/messages.go @@ -351,6 +351,7 @@ func (h *MessagesHandler) handleStreaming( // Start heartbeat var finished int32 + var heartbeatPaused int32 heartbeatDone := make(chan struct{}) go func() { ticker := time.NewTicker(3 * time.Second) @@ -362,7 +363,9 @@ func (h *MessagesHandler) handleStreaming( if atomic.LoadInt32(&finished) == 1 { return } - rw.WriteKeepalive() + if atomic.LoadInt32(&heartbeatPaused) == 0 { + rw.WriteKeepalive() + } case <-heartbeatDone: return case <-clientCtx.Done(): @@ -387,20 +390,15 @@ func (h *MessagesHandler) handleStreaming( h.logger.Info("attempting streaming model", "model", model.ModelID, "provider", model.Provider) - // Upstream context: no total deadline. The stream lives as long as - // data keeps flowing. Per-Read idle deadline is enforced in stream.go - // via http.ResponseController, so a stuck stream is still caught. - ctx, cancel := context.WithCancel(context.Background()) - go func() { - <-clientCtx.Done() - cancel() - }() + // Upstream context carries the streaming timeout configured for the model. 
+ timeout := h.client.StreamingTimeout(model) + attemptCtx, cancelAttempt := context.WithTimeout(clientCtx, timeout) idleTimeout := h.client.StreamIdleTimeout(model) // recordStreamSuccess records a successful stream completion and // marks the model attempt as done. recordStreamSuccess := func(model config.ModelConfig) { - cancel() + cancelAttempt() latency := time.Since(streamStart) h.metrics.RecordSuccess(model.ModelID, latency) h.logger.Info("streaming completed", "model", model.ModelID, "latency", latency) @@ -411,8 +409,8 @@ func (h *MessagesHandler) handleStreaming( // if the caller should continue (fallback to next model), or false // if it should return. handleStreamError := func(err error, model config.ModelConfig, action string) bool { - cancel() - if clientCtx.Err() == context.Canceled { + cancelAttempt() + if clientCtx.Err() != nil { h.logger.Debug("client disconnected during " + action + " stream") return false // abort } @@ -436,44 +434,56 @@ func (h *MessagesHandler) handleStreaming( } // Try new provider-based dispatch first. 
- if prov, ok := h.providerRegistry.Get(model.Provider); ok { - caps, ok := prov.ModelCapabilities(model.ModelID) - if !ok || !caps.SupportsStreaming { - h.logger.Warn("model does not support streaming", "model", model.ModelID, "provider", model.Provider) - cancel() - continue - } + if h.providerRegistry != nil { + if prov, ok := h.providerRegistry.Get(model.Provider); ok { + caps, ok := prov.ModelCapabilities(model.ModelID) + if !ok || !caps.SupportsStreaming { + h.logger.Warn("model does not support streaming", "model", model.ModelID, "provider", model.Provider) + cancelAttempt() + continue + } - streamBody, err := prov.Stream(ctx, normalizedReq, model) - if err != nil { - cancel() - if clientCtx.Err() == context.Canceled { - h.logger.Debug("client disconnected during upstream request") - return + streamBody, err := prov.Stream(attemptCtx, normalizedReq, model) + if err != nil { + cancelAttempt() + if clientCtx.Err() != nil { + h.logger.Debug("client disconnected during upstream request") + return + } + h.logger.Warn("streaming request failed via provider", "model", model.ModelID, "provider", model.Provider, "error", err) + continue } - h.logger.Warn("streaming request failed via provider", "model", model.ModelID, "provider", model.Provider, "error", err) - continue - } - // Bind body read to ctx so streaming_timeout_ms aborts mid-stream. - streamReader := transformer.NewCtxReadCloser(ctx, streamBody) + // Bind body read to attemptCtx so streaming_timeout_ms aborts mid-stream. 
+ streamReader := transformer.NewCtxReadCloser(attemptCtx, streamBody) - wireFormat := prov.WireFormat(model.ModelID) - if err := h.streamProxy.ProxyStream(rw, streamReader, wireFormat, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { - _ = streamBody.Close() - if err == transformer.ErrClientDisconnected { - h.logger.Debug("client disconnected during stream") - return + wireFormat := prov.WireFormat(model.ModelID) + if wireFormat == core.WireFormatAnthropic { + atomic.StoreInt32(&heartbeatPaused, 1) } - if !handleStreamError(err, model, wireFormat.String()) { - return + errProxy := h.streamProxy.ProxyStream(rw, streamReader, wireFormat, model.ModelID, attemptCtx, idleTimeout, cancelAttempt) + if wireFormat == core.WireFormatAnthropic { + atomic.StoreInt32(&heartbeatPaused, 0) + } + if errProxy != nil { + _ = streamBody.Close() + if errProxy == transformer.ErrClientDisconnected { + if clientCtx.Err() != nil { + h.logger.Debug("client disconnected during stream") + return + } + errProxy = fmt.Errorf("streaming timeout (%v) exceeded", timeout) + } + if !handleStreamError(errProxy, model, wireFormat.String()) { + return + } + continue } - continue - } - _ = streamBody.Close() - recordStreamSuccess(model) - return + _ = streamBody.Close() + recordStreamSuccess(model) + return + } } // Legacy path for backward compatibility while old client is still in @@ -490,7 +500,7 @@ func (h *MessagesHandler) handleStreaming( // Fall through to OpenAI-compatible transform path below. 
} else { modelBody := replaceModelInRawBody(rawBody, model.ModelID) - if err := h.handleAnthropicStreaming(ctx, rw, modelBody, model.ModelID, model, idleTimeout, cancel, clientCtx); err != nil { + if err := h.handleAnthropicStreaming(attemptCtx, rw, modelBody, model.ModelID, model, idleTimeout, cancelAttempt, clientCtx, &heartbeatPaused); err != nil { if !handleStreamError(err, model, "anthropic") { return } @@ -501,7 +511,14 @@ func (h *MessagesHandler) handleStreaming( } case client.EndpointResponses: - if err := h.handleResponsesStreaming(ctx, rw, anthropicReq, model, clientCtx, idleTimeout, cancel); err != nil { + if err := h.handleResponsesStreaming(attemptCtx, rw, anthropicReq, model, clientCtx, idleTimeout, cancelAttempt); err != nil { + if err == transformer.ErrClientDisconnected { + if clientCtx.Err() != nil { + h.logger.Debug("client disconnected during responses stream") + return + } + err = fmt.Errorf("streaming timeout (%v) exceeded", timeout) + } if !handleStreamError(err, model, "responses") { return } @@ -511,7 +528,14 @@ func (h *MessagesHandler) handleStreaming( return case client.EndpointGemini: - if err := h.handleGeminiStreaming(ctx, rw, anthropicReq, model, clientCtx, idleTimeout, cancel); err != nil { + if err := h.handleGeminiStreaming(attemptCtx, rw, anthropicReq, model, clientCtx, idleTimeout, cancelAttempt); err != nil { + if err == transformer.ErrClientDisconnected { + if clientCtx.Err() != nil { + h.logger.Debug("client disconnected during gemini stream") + return + } + err = fmt.Errorf("streaming timeout (%v) exceeded", timeout) + } if !handleStreamError(err, model, "gemini") { return } @@ -529,15 +553,10 @@ func (h *MessagesHandler) handleStreaming( // Anthropic format rather than the OpenAI Chat Completions transform. 
if !client.IsZen(model) && client.IsAnthropicModel(model.ModelID) { modelBody := replaceModelInRawBody(rawBody, model.ModelID) - if err := h.handleAnthropicStreaming(ctx, rw, modelBody, model.ModelID, model, idleTimeout, cancel, clientCtx); err != nil { + if err := h.handleAnthropicStreaming(attemptCtx, rw, modelBody, model.ModelID, model, idleTimeout, cancelAttempt, clientCtx, &heartbeatPaused); err != nil { if !handleStreamError(err, model, "anthropic") { return } - if err != transformer.ErrStreamIdle && rw.ssePayloadWritten { - h.sendStreamError(rw, fmt.Sprintf("all upstream models failed after SSE payload started: %v", err)) - h.metrics.RecordFailure() - return - } continue } recordStreamSuccess(model) @@ -547,15 +566,15 @@ func (h *MessagesHandler) handleStreaming( // OpenAI-compatible models (both Go and Zen) openaiReq, err := h.requestTransformer.TransformRequest(anthropicReq, model) if err != nil { - cancel() + cancelAttempt() h.logger.Warn("request transform failed", "model", model.ModelID, "error", err) continue } - streamBody, err := h.client.GetStreamingBody(ctx, model.ModelID, openaiReq, model) + streamBody, err := h.client.GetStreamingBody(attemptCtx, model.ModelID, openaiReq, model) if err != nil { - cancel() - if clientCtx.Err() == context.Canceled { + cancelAttempt() + if clientCtx.Err() != nil { h.logger.Debug("client disconnected during upstream request") return } @@ -563,14 +582,17 @@ func (h *MessagesHandler) handleStreaming( continue } - // Bind body read to ctx so streaming_timeout_ms aborts mid-stream. - streamReader := transformer.NewCtxReadCloser(ctx, streamBody) + // Bind body read to attemptCtx so streaming_timeout_ms aborts mid-stream. 
+ streamReader := transformer.NewCtxReadCloser(attemptCtx, streamBody) - if err := h.streamHandler.ProxyStream(rw, streamReader, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { + if err := h.streamHandler.ProxyStream(rw, streamReader, model.ModelID, attemptCtx, idleTimeout, cancelAttempt); err != nil { _ = streamBody.Close() if err == transformer.ErrClientDisconnected { - h.logger.Debug("client disconnected during stream") - return + if clientCtx.Err() != nil { + h.logger.Debug("client disconnected during stream") + return + } + err = fmt.Errorf("streaming timeout (%v) exceeded", timeout) } if !handleStreamError(err, model, "openai") { return @@ -715,6 +737,9 @@ func replaceModelInRawBody(rawBody json.RawMessage, modelID string) json.RawMess "error", err) return rawBody } + if _, ok := obj["model"]; !ok { + return rawBody + } encoded, err := json.Marshal(modelID) if err != nil { // json.Marshal on a string should never fail, but guard anyway. @@ -745,7 +770,10 @@ func (h *MessagesHandler) handleAnthropicStreaming( idleTimeout time.Duration, cancel context.CancelFunc, clientCtx context.Context, + heartbeatPaused *int32, ) error { + atomic.StoreInt32(heartbeatPaused, 1) + defer atomic.StoreInt32(heartbeatPaused, 0) // Sanitize Anthropic-specific fields (e.g., tool type shorthands) that // upstream models may not understand. rawBody = sanitizeAnthropicBody(rawBody) @@ -852,13 +880,19 @@ func (h *MessagesHandler) handleNonStreaming( ctx, modelChain, func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + timeout := h.client.RequestTimeout(model) + attemptCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + // Try new provider-based dispatch first. 
- if prov, ok := h.providerRegistry.Get(model.Provider); ok { - execResult, execErr := prov.Execute(ctx, normalizedReq, model) - if execErr != nil { - return nil, execErr + if h.providerRegistry != nil { + if prov, ok := h.providerRegistry.Get(model.Provider); ok { + execResult, execErr := prov.Execute(attemptCtx, normalizedReq, model) + if execErr != nil { + return nil, execErr + } + return execResult.Body, nil } - return execResult.Body, nil } h.logger.Warn("provider not found in registry, falling back to old client", @@ -872,26 +906,30 @@ func (h *MessagesHandler) handleNonStreaming( if model.AnthropicToolsDisabled { // Fall through to OpenAI-compatible handling below. } else { - return h.executeAnthropicRequest(ctx, replaceModelInRawBody(rawBody, model.ModelID), model) + return h.executeAnthropicRequest(attemptCtx, replaceModelInRawBody(rawBody, model.ModelID), model) } case client.EndpointResponses: - return h.executeResponsesRequest(ctx, anthropicReq, model) + return h.executeResponsesRequest(attemptCtx, anthropicReq, model) case client.EndpointGemini: - return h.executeGeminiRequest(ctx, anthropicReq, model) + return h.executeGeminiRequest(attemptCtx, anthropicReq, model) default: // Fall through to OpenAI-compatible handling } } else if client.IsAnthropicModel(model.ModelID) { // Go provider Anthropic-native models (MiniMax, Qwen) - return h.executeAnthropicRequest(ctx, replaceModelInRawBody(rawBody, model.ModelID), model) + return h.executeAnthropicRequest(attemptCtx, replaceModelInRawBody(rawBody, model.ModelID), model) } // OpenAI-compatible models (both Go and Zen) - return h.executeOpenAIRequest(ctx, anthropicReq, model) + return h.executeOpenAIRequest(attemptCtx, anthropicReq, model) }, ) if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + h.logger.Info("request context canceled during non-streaming fallback", "error", err) + return + } h.metrics.RecordFailure() h.sendError(w, http.StatusBadGateway, 
"all models failed", err) return diff --git a/internal/handlers/messages_test.go b/internal/handlers/messages_test.go index ecd65a2..ab3b40a 100644 --- a/internal/handlers/messages_test.go +++ b/internal/handlers/messages_test.go @@ -1,12 +1,27 @@ package handlers import ( + "context" "encoding/json" + "fmt" + "io" "log/slog" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" "testing" + "time" + "github.com/routatic/proxy/internal/client" "github.com/routatic/proxy/internal/config" + "github.com/routatic/proxy/internal/core" + "github.com/routatic/proxy/internal/metrics" "github.com/routatic/proxy/internal/router" + "github.com/routatic/proxy/internal/token" + "github.com/routatic/proxy/internal/transformer" + "github.com/routatic/proxy/pkg/types" ) func boolPtr(b bool) *bool { return &b } @@ -466,3 +481,1094 @@ func TestSanitizeAnthropicBody_KeepsOtherFields(t *testing.T) { t.Error("max_tokens field was corrupted") } } + +func TestReplaceModelInRawBody_JSONBased(t *testing.T) { + raw := json.RawMessage(`{"model":"old-model","stream":true}`) + res := replaceModelInRawBody(raw, "new-model") + var m map[string]interface{} + if err := json.Unmarshal(res, &m); err != nil { + t.Fatal(err) + } + if got := m["model"]; got != "new-model" { + t.Errorf("got %q, want new-model", got) + } + if got := m["stream"]; got != true { + t.Errorf("got %v, want true", got) + } +} + +func TestReplaceModelInRawBody_HandlesWhitespace(t *testing.T) { + raw := json.RawMessage(`{ "model" : "old-model" , "stream": true}`) + res := replaceModelInRawBody(raw, "new-model") + var m map[string]interface{} + if err := json.Unmarshal(res, &m); err != nil { + t.Fatal(err) + } + if got := m["model"]; got != "new-model" { + t.Errorf("got %q, want new-model", got) + } +} + +func TestReplaceModelInRawBody_ReturnsOriginalWhenModelMissing(t *testing.T) { + raw := json.RawMessage(`{"stream":true}`) + res := replaceModelInRawBody(raw, "new-model") + if string(res) != string(raw) { + 
t.Errorf("got %s, want original", string(res)) + } +} + +func TestReplaceModelInRawBody_ReturnsOriginalOnInvalidJSON(t *testing.T) { + raw := json.RawMessage(`{invalid json}`) + res := replaceModelInRawBody(raw, "new-model") + if string(res) != string(raw) { + t.Errorf("got %s, want original", string(res)) + } +} + +func TestReplaceModelInRawBody_HandlesNestedObjects(t *testing.T) { + raw := json.RawMessage(`{"model":"old","nested":{"model":"don't touch me"}}`) + res := replaceModelInRawBody(raw, "new") + var m map[string]interface{} + if err := json.Unmarshal(res, &m); err != nil { + t.Fatal(err) + } + if got := m["model"]; got != "new" { + t.Errorf("top-level model = %q, want new", got) + } + nested := m["nested"].(map[string]interface{}) + if got := nested["model"]; got != "don't touch me" { + t.Errorf("nested model = %q, want 'don't touch me'", got) + } +} + +func TestHandleStreaming_GoAnthropicModel_SendsRawAnthropicBody(t *testing.T) { + var capturedBody []byte + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + capturedBody, err = io.ReadAll(r.Body) + if err != nil { + t.Logf("upstream read body error: %v", err) + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {}\n\n") + _, _ = fmt.Fprintf(w, "event: message_stop\ndata: {}\n\n") + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + })) + defer upstream.Close() + + handler := newStreamingTestHandler(t, upstream.URL) + + rawBody := json.RawMessage(`{ + "model": "claude-opus-4-8", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}], + "tools": [{ + "name": "Bash", + "description": "Run a command", + "input_schema": {"type": "object", "properties": {"cmd": {"type": "string"}}} + }] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal 
rawBody: %v", err) + } + + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "minimax-m3"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, &core.NormalizedRequest{}, chain, rawBody) + + if len(capturedBody) == 0 { + t.Fatal("upstream received no body") + } + + var captured map[string]interface{} + if err := json.Unmarshal(capturedBody, &captured); err != nil { + t.Fatalf("captured body is not valid JSON: %v\nbody: %s", err, capturedBody) + } + + if got, ok := captured["model"]; !ok || got != "minimax-m3" { + t.Fatalf("captured model = %v, want minimax-m3", got) + } + + toolsRaw, ok := captured["tools"] + if !ok { + t.Fatal("captured body missing tools field") + } + tools, ok := toolsRaw.([]interface{}) + if !ok || len(tools) == 0 { + t.Fatal("captured body tools is empty or not an array") + } + tool0, ok := tools[0].(map[string]interface{}) + if !ok { + t.Fatal("tool[0] is not an object") + } + if _, ok := tool0["function"]; ok { + t.Fatalf("captured tool has 'function' field (OpenAI format leak): %s", capturedBody) + } + if _, ok := tool0["input_schema"]; !ok { + t.Fatalf("captured tool missing 'input_schema' (Anthropic format): %s", capturedBody) + } + if got, ok := tool0["name"]; !ok || got != "Bash" { + t.Fatalf("captured tool name = %v, want Bash", got) + } +} + +func TestHandleStreaming_GoAnthropicModel_FallsThroughOnError(t *testing.T) { + callCount := int32(0) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {}\n\n") + _, _ = 
fmt.Fprintf(w, "event: message_stop\ndata: {}\n\n") + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + OpenCodeGo: config.OpenCodeGoConfig{ + AnthropicBaseURL: upstream.URL, + BaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + + handler := &MessagesHandler{ + client: ocClient, + logger: slog.Default(), + metrics: metrics.New(), + streamHandler: transformer.NewStreamHandler(), + requestTransformer: transformer.NewRequestTransformer(), + responseTransformer: transformer.NewResponseTransformer(), + } + + rawBody := json.RawMessage(`{ + "model": "claude-opus-4-8", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal rawBody: %v", err) + } + + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "minimax-m3"}, + {Provider: "opencode-go", ModelID: "qwen3.5-plus"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, &core.NormalizedRequest{}, chain, rawBody) + + finalCount := atomic.LoadInt32(&callCount) + if finalCount != 2 { + t.Fatalf("expected 2 upstream calls (1 fail + 1 success), got %d", finalCount) + } +} + +func newStreamingTestHandler(t *testing.T, upstreamURL string) *MessagesHandler { + t.Helper() + cfg := &config.Config{ + APIKey: "test-key", + OpenCodeGo: config.OpenCodeGoConfig{ + AnthropicBaseURL: upstreamURL, + BaseURL: upstreamURL, + TimeoutMs: 5000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + + return &MessagesHandler{ + client: ocClient, + 
logger: slog.Default(), + metrics: metrics.New(), + streamHandler: transformer.NewStreamHandler(), + requestTransformer: transformer.NewRequestTransformer(), + responseTransformer: transformer.NewResponseTransformer(), + } +} + +func TestHandleMessages_StreamingMinimaxM3_UsesAnthropicEndpoint(t *testing.T) { + var capturedBody []byte + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + capturedBody, err = io.ReadAll(r.Body) + if err != nil { + t.Logf("upstream read body error: %v", err) + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {}\n\n") + _, _ = fmt.Fprintf(w, "event: message_stop\ndata: {}\n\n") + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + Models: map[string]config.ModelConfig{ + "default": {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + "fast": {Provider: "opencode-go", ModelID: "qwen3.6-plus"}, + }, + Fallbacks: map[string][]config.ModelConfig{ + "default": {{Provider: "opencode-go", ModelID: "glm-5"}}, + "fast": {{Provider: "opencode-go", ModelID: "qwen3.5-plus"}}, + }, + ModelOverrides: map[string]config.ModelConfig{ + "minimax-m3": { + Provider: "opencode-go", + ModelID: "minimax-m3", + }, + }, + OpenCodeGo: config.OpenCodeGoConfig{ + AnthropicBaseURL: upstream.URL, + BaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + + ocClient := client.NewOpenCodeClient(atomicCfg) + modelRouter := router.NewModelRouter(atomicCfg) + tokenCounter, err := token.NewCounter() + if err != nil { + t.Fatalf("NewCounter: %v", err) + } + + handler := NewMessagesHandler( + ocClient, + nil, // providerRegistry + modelRouter, + nil, // fallbackHandler + tokenCounter, + metrics.New(), + ) + handler.logger = slog.Default() + + requestBody := `{ + "model": "minimax-m3", 
+ "stream": true, + "max_tokens": 256, + "messages": [{"role": "user", "content": "Say hello"}], + "tools": [{ + "name": "Bash", + "description": "Run a command", + "input_schema": {"type": "object", "properties": {"cmd": {"type": "string"}}} + }] + }` + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(requestBody)) + req.Header.Set("Content-Type", "application/json") + + handler.HandleMessages(recorder, req) + + if len(capturedBody) == 0 { + t.Fatal("upstream received no body") + } + + var captured map[string]interface{} + if err := json.Unmarshal(capturedBody, &captured); err != nil { + t.Fatalf("captured body is not valid JSON: %v\nbody: %s", err, capturedBody) + } + + if got, ok := captured["model"]; !ok || got != "minimax-m3" { + t.Fatalf("captured model = %v, want minimax-m3", got) + } + + toolsRaw, ok := captured["tools"] + if !ok { + t.Fatal("captured body missing tools field") + } + tools, ok := toolsRaw.([]interface{}) + if !ok || len(tools) == 0 { + t.Fatal("captured body tools is empty or not an array") + } + tool0, ok := tools[0].(map[string]interface{}) + if !ok { + t.Fatal("tool[0] is not an object") + } + if _, ok := tool0["function"]; ok { + t.Fatalf("captured tool has 'function' field (OpenAI format leak): %s", capturedBody) + } + if _, ok := tool0["input_schema"]; !ok { + t.Fatalf("captured tool missing 'input_schema' (Anthropic format): %s", capturedBody) + } +} + +func TestHandleNonStreaming_GoAnthropicModel_ReplacesModelInBody(t *testing.T) { + var capturedBody []byte + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + capturedBody, err = io.ReadAll(r.Body) + if err != nil { + t.Logf("upstream read body error: %v", err) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [{"type": 
"text", "text": "hello"}], + "model": "minimax-m3", + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 5} + }`)) + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + Models: map[string]config.ModelConfig{ + "default": {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + }, + Fallbacks: map[string][]config.ModelConfig{ + "default": {{Provider: "opencode-go", ModelID: "glm-5"}}, + }, + ModelOverrides: map[string]config.ModelConfig{ + "claude-haiku-4-5-20251001": { + Provider: "opencode-go", + ModelID: "minimax-m3", + }, + }, + OpenCodeGo: config.OpenCodeGoConfig{ + AnthropicBaseURL: upstream.URL, + BaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + modelRouter := router.NewModelRouter(atomicCfg) + tokenCounter, err := token.NewCounter() + if err != nil { + t.Fatalf("NewCounter: %v", err) + } + + handler := NewMessagesHandler( + ocClient, + nil, // providerRegistry + modelRouter, + router.NewFallbackHandler(slog.Default(), 3, 30*time.Second), + tokenCounter, + metrics.New(), + ) + handler.logger = slog.Default() + + requestBody := `{ + "model": "claude-haiku-4-5-20251001", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Say hello"}], + "tools": [{ + "name": "Bash", + "description": "Run a command", + "input_schema": {"type": "object", "properties": {"cmd": {"type": "string"}}} + }] + }` + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(requestBody)) + req.Header.Set("Content-Type", "application/json") + + handler.HandleMessages(recorder, req) + + if len(capturedBody) == 0 { + t.Fatal("upstream received no body") + } + + var captured map[string]interface{} + if err := json.Unmarshal(capturedBody, &captured); err != nil { + t.Fatalf("captured body is not valid JSON: %v\nbody: %s", err, capturedBody) + } + + if 
got, ok := captured["model"]; !ok || got != "minimax-m3" { + t.Fatalf("captured model = %v, want minimax-m3", got) + } + + toolsRaw, ok := captured["tools"] + if !ok { + t.Fatal("captured body missing tools field") + } + tools, ok := toolsRaw.([]interface{}) + if !ok || len(tools) == 0 { + t.Fatal("captured body tools is empty or not an array") + } + tool0, ok := tools[0].(map[string]interface{}) + if !ok { + t.Fatal("tool[0] is not an object") + } + if _, ok := tool0["function"]; ok { + t.Fatalf("captured tool has 'function' field (OpenAI format leak): %s", capturedBody) + } + if _, ok := tool0["input_schema"]; !ok { + t.Fatalf("captured tool missing 'input_schema' (Anthropic format): %s", capturedBody) + } +} + +func TestHandleNonStreaming_ZenAnthropicModel_ReplacesModelInBody(t *testing.T) { + var capturedBody []byte + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + capturedBody, err = io.ReadAll(r.Body) + if err != nil { + t.Logf("upstream read body error: %v", err) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hello"}], + "model": "claude-sonnet-4.5", + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 5} + }`)) + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + Models: map[string]config.ModelConfig{ + "default": {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + }, + Fallbacks: map[string][]config.ModelConfig{ + "default": {{Provider: "opencode-go", ModelID: "glm-5"}}, + }, + ModelOverrides: map[string]config.ModelConfig{ + "claude-haiku-4-5-20251001": { + Provider: "opencode-zen", + ModelID: "claude-sonnet-4.5", + }, + }, + OpenCodeGo: config.OpenCodeGoConfig{ + AnthropicBaseURL: upstream.URL, + BaseURL: upstream.URL, + TimeoutMs: 5000, + }, + OpenCodeZen: 
config.OpenCodeZenConfig{ + AnthropicBaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + modelRouter := router.NewModelRouter(atomicCfg) + tokenCounter, err := token.NewCounter() + if err != nil { + t.Fatalf("NewCounter: %v", err) + } + + handler := NewMessagesHandler( + ocClient, + nil, // providerRegistry + modelRouter, + router.NewFallbackHandler(slog.Default(), 3, 30*time.Second), + tokenCounter, + metrics.New(), + ) + handler.logger = slog.Default() + + requestBody := `{ + "model": "claude-haiku-4-5-20251001", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Say hello"}], + "tools": [{ + "name": "Bash", + "description": "Run a command", + "input_schema": {"type": "object", "properties": {"cmd": {"type": "string"}}} + }] + }` + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(requestBody)) + req.Header.Set("Content-Type", "application/json") + + handler.HandleMessages(recorder, req) + + if len(capturedBody) == 0 { + t.Fatal("upstream received no body") + } + + var captured map[string]interface{} + if err := json.Unmarshal(capturedBody, &captured); err != nil { + t.Fatalf("captured body is not valid JSON: %v\nbody: %s", err, capturedBody) + } + + if got, ok := captured["model"]; !ok || got != "claude-sonnet-4.5" { + t.Fatalf("captured model = %v, want claude-sonnet-4.5", got) + } + + toolsRaw, ok := captured["tools"] + if !ok { + t.Fatal("captured body missing tools field") + } + tools, ok := toolsRaw.([]interface{}) + if !ok || len(tools) == 0 { + t.Fatal("captured body tools is empty or not an array") + } + tool0, ok := tools[0].(map[string]interface{}) + if !ok { + t.Fatal("tool[0] is not an object") + } + if _, ok := tool0["function"]; ok { + t.Fatalf("captured tool has 'function' field (OpenAI format leak): %s", capturedBody) + } + if _, ok := 
tool0["input_schema"]; !ok { + t.Fatalf("captured tool missing 'input_schema' (Anthropic format): %s", capturedBody) + } +} + +func TestHandleStreaming_ConfigurableTimeout(t *testing.T) { + upstreamChan := make(chan struct{}) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + select { + case <-upstreamChan: + case <-time.After(5 * time.Second): + } + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {}\n\n") + })) + defer upstream.Close() + defer close(upstreamChan) + + cfg := &config.Config{ + APIKey: "test-key", + OpenCodeGo: config.OpenCodeGoConfig{ + BaseURL: upstream.URL, + TimeoutMs: 300000, + StreamingTimeoutMs: 100, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + + handler := &MessagesHandler{ + client: ocClient, + logger: slog.Default(), + metrics: metrics.New(), + streamHandler: transformer.NewStreamHandler(), + requestTransformer: transformer.NewRequestTransformer(), + responseTransformer: transformer.NewResponseTransformer(), + } + + rawBody := json.RawMessage(`{ + "model": "kimi-k2.6", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal rawBody: %v", err) + } + + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + + done := make(chan struct{}) + go func() { + defer close(done) + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, &core.NormalizedRequest{Stream: true}, chain, rawBody) + }() + + 
select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleStreaming did not return within 2s despite short streaming timeout") + } + + body := recorder.Body.String() + if !strings.Contains(body, "all streaming models failed") && !strings.Contains(body, "all upstream models failed") { + t.Errorf("unexpected output on streaming timeout: %s", body) + } +} + +func TestHandleStreaming_ClientContextCanceled_StopsFallback(t *testing.T) { + callCount := int32(0) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&callCount, 1) + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {}\n\n") + })) + defer upstream.Close() + + handler := newStreamingTestHandler(t, upstream.URL) + + rawBody := json.RawMessage(`{ + "model": "kimi-k2.6", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal rawBody: %v", err) + } + + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + {Provider: "opencode-go", ModelID: "glm-5"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + cancel() + + done := make(chan struct{}) + go func() { + defer close(done) + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, &core.NormalizedRequest{Stream: true}, chain, rawBody) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleStreaming did not return immediately on canceled client context") + } + + if atomic.LoadInt32(&callCount) != 0 { + t.Errorf("expected 0 upstream calls since client context was canceled, got %d", callCount) + } +} + +func 
TestHandleStreaming_ClientDisconnectsDuringStream_StopsFallback(t *testing.T) { + blockCh := make(chan struct{}) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {}\n\n") + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + <-blockCh + })) + defer upstream.Close() + defer close(blockCh) + + handler := newStreamingTestHandler(t, upstream.URL) + + rawBody := json.RawMessage(`{ + "model": "kimi-k2.6", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal rawBody: %v", err) + } + + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + {Provider: "opencode-go", ModelID: "glm-5"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + + done := make(chan struct{}) + go func() { + defer close(done) + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, &core.NormalizedRequest{Stream: true}, chain, rawBody) + }() + + time.Sleep(100 * time.Millisecond) + cancel() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleStreaming did not return after client disconnected") + } +} + +func TestHandleStreaming_PerModelTimeoutFallback(t *testing.T) { + callCount := int32(0) + upstreamBlock := make(chan struct{}) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&callCount, 1) + if count == 1 { + select { + case <-upstreamBlock: + case <-time.After(5 * time.Second): + } + return + } + w.Header().Set("Content-Type", 
"text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {}\n\n") + _, _ = fmt.Fprintf(w, "event: message_stop\ndata: {}\n\n") + })) + defer upstream.Close() + defer close(upstreamBlock) + + cfg := &config.Config{ + APIKey: "test-key", + OpenCodeGo: config.OpenCodeGoConfig{ + BaseURL: upstream.URL, + TimeoutMs: 300000, + StreamingTimeoutMs: 100, + }, + } + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + + handler := &MessagesHandler{ + client: ocClient, + logger: slog.Default(), + metrics: metrics.New(), + streamHandler: transformer.NewStreamHandler(), + requestTransformer: transformer.NewRequestTransformer(), + responseTransformer: transformer.NewResponseTransformer(), + } + + rawBody := json.RawMessage(`{ + "model": "kimi-k2.6", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal rawBody: %v", err) + } + + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + {Provider: "opencode-go", ModelID: "glm-5"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, handlerCancel := context.WithCancel(req.Context()) + defer handlerCancel() + + done := make(chan struct{}) + go func() { + defer close(done) + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, &core.NormalizedRequest{Stream: true}, chain, rawBody) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("handleStreaming did not complete within 5s") + } + + finalCount := atomic.LoadInt32(&callCount) + if finalCount != 2 { + t.Errorf("expected 2 upstream calls (1 timeout + 1 success), got %d", finalCount) + } +} + +func TestHandleNonStreaming_ParentContextCanceled_No502(t *testing.T) { + 
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hello"}], + "model": "kimi-k2.6", + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 5} + }`)) + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + Models: map[string]config.ModelConfig{ + "default": {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + }, + Fallbacks: map[string][]config.ModelConfig{ + "default": {{Provider: "opencode-go", ModelID: "glm-5"}}, + }, + OpenCodeGo: config.OpenCodeGoConfig{ + BaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + modelRouter := router.NewModelRouter(atomicCfg) + tokenCounter, err := token.NewCounter() + if err != nil { + t.Fatalf("NewCounter: %v", err) + } + + m := metrics.New() + handler := NewMessagesHandler( + ocClient, + nil, // providerRegistry + modelRouter, + router.NewFallbackHandler(slog.Default(), 3, 30*time.Second), + tokenCounter, + m, + ) + handler.logger = slog.Default() + + requestBody := `{ + "model": "claude-haiku-4-5-20251001", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Say hello"}] + }` + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(requestBody)) + req.Header.Set("Content-Type", "application/json") + + ctx, cancel := context.WithCancel(req.Context()) + cancel() + req = req.WithContext(ctx) + + handler.HandleMessages(recorder, req) + + if recorder.Code == http.StatusBadGateway { + t.Errorf("should not return 502 for canceled context, got status %d", recorder.Code) + } + + snap := m.GetSnapshot() + if snap.RequestsFailed > 0 { + 
t.Errorf("failure count should be 0 for canceled context, got %d", snap.RequestsFailed) + } + + body := recorder.Body.String() + if strings.Contains(body, "all models failed") { + t.Errorf("should not contain 'all models failed' for client cancellation, got: %s", body) + } +} + +func TestHandleNonStreaming_ParentDeadlineExceeded_No502(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hello"}], + "model": "kimi-k2.6", + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 5} + }`)) + })) + defer upstream.Close() + + cfg := &config.Config{ + APIKey: "test-key", + Models: map[string]config.ModelConfig{ + "default": {Provider: "opencode-go", ModelID: "kimi-k2.6"}, + }, + Fallbacks: map[string][]config.ModelConfig{ + "default": {{Provider: "opencode-go", ModelID: "glm-5"}}, + }, + OpenCodeGo: config.OpenCodeGoConfig{ + BaseURL: upstream.URL, + TimeoutMs: 5000, + }, + } + + atomicCfg := config.NewAtomicConfig(cfg, "/tmp/test-config.json") + ocClient := client.NewOpenCodeClient(atomicCfg) + modelRouter := router.NewModelRouter(atomicCfg) + tokenCounter, err := token.NewCounter() + if err != nil { + t.Fatalf("NewCounter: %v", err) + } + + m := metrics.New() + handler := NewMessagesHandler( + ocClient, + nil, // providerRegistry + modelRouter, + router.NewFallbackHandler(slog.Default(), 3, 30*time.Second), + tokenCounter, + m, + ) + handler.logger = slog.Default() + + requestBody := `{ + "model": "claude-haiku-4-5-20251001", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Say hello"}] + }` + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(requestBody)) + req.Header.Set("Content-Type", 
"application/json") + + ctx, cancel := context.WithDeadline(req.Context(), time.Now().Add(-1*time.Second)) + defer cancel() + req = req.WithContext(ctx) + + handler.HandleMessages(recorder, req) + + if recorder.Code == http.StatusBadGateway { + t.Errorf("should not return 502 for deadline exceeded, got status %d", recorder.Code) + } + snap := m.GetSnapshot() + if snap.RequestsFailed > 0 { + t.Errorf("failure count should be 0 for deadline exceeded, got %d", snap.RequestsFailed) + } + + body := recorder.Body.String() + if strings.Contains(body, "all models failed") { + t.Errorf("should not contain 'all models failed' for deadline exceeded, got: %s", body) + } +} + +func TestResponseWriter_ConcurrentWrites(t *testing.T) { + recorder := httptest.NewRecorder() + rw := &responseWriter{ResponseWriter: recorder} + + var wg sync.WaitGroup + const goroutines = 10 + const writesPerGoroutine = 100 + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < writesPerGoroutine; j++ { + rw.Write([]byte(fmt.Sprintf("goroutine-%d-write-%d\n", id, j))) + } + }(i) + } + wg.Wait() + + output := recorder.Body.String() + lines := strings.Split(strings.TrimSpace(output), "\n") + expectedLines := goroutines * writesPerGoroutine + if len(lines) != expectedLines { + t.Errorf("got %d lines, want %d (possible data loss from unsynchronized writes)", len(lines), expectedLines) + } +} + +func TestHandleStreaming_AnthropicRaw_NoKeepaliveInjection(t *testing.T) { + blockCh := make(chan struct{}) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + _, _ = fmt.Fprintf(w, "event: message_start\ndata: {\"type\":\"message_start\"}\n\n") + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + select { + case <-blockCh: + case <-time.After(10 * time.Second): + } + _, _ = 
fmt.Fprintf(w, "event: content_block_delta\ndata: {\"type\":\"content_block_delta\"}\n\n") + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + })) + defer upstream.Close() + + handler := newStreamingTestHandler(t, upstream.URL) + + rawBody := json.RawMessage(`{ + "model": "claude-opus-4-8", + "stream": true, + "max_tokens": 256, + "messages": [{"role":"user","content":"hello"}] + }`) + + var anthropicReq types.MessageRequest + if err := json.Unmarshal(rawBody, &anthropicReq); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + chain := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "minimax-m3"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + + done := make(chan struct{}) + go func() { + defer close(done) + handler.handleStreaming(recorder, req.WithContext(ctx), &anthropicReq, &core.NormalizedRequest{Stream: true}, chain, rawBody) + }() + + time.Sleep(1000 * time.Millisecond) + close(blockCh) + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("handleStreaming did not return after unblocking upstream") + } + + body := recorder.Body.String() + + if !strings.Contains(body, "message_start") { + t.Error("output missing message_start event") + } + if !strings.Contains(body, "content_block_delta") { + t.Error("output missing content_block_delta event") + } + + if strings.Contains(body, ":keepalive") { + t.Errorf("keepalive comment leaked into Anthropic raw stream output (concurrent write bug):\n%s", body) + } +} diff --git a/internal/provider/opencode_go.go b/internal/provider/opencode_go.go index 4161733..fe6f43c 100644 --- a/internal/provider/opencode_go.go +++ b/internal/provider/opencode_go.go @@ -63,10 +63,14 @@ func (p *OpenCodeGoProvider) WireFormat(modelID string) core.WireFormat { return core.WireFormatOpenAIChat } -// isAnthropicNativeGo returns true for Go provider models that require the 
-// Anthropic Messages endpoint rather than the OpenAI Chat Completions endpoint. func isAnthropicNativeGo(modelID string) bool { - return modelID == "qwen3.7-max" + switch modelID { + case "minimax-m2.5", "minimax-m2.7", "minimax-m3", + "qwen3.5-plus", "qwen3.6-plus", "qwen3.7-plus", "qwen3.7-max": + return true + default: + return false + } } // RoundTripName returns the model ID to use in the upstream request. diff --git a/internal/router/fallback.go b/internal/router/fallback.go index c0f3435..ea5ecb3 100644 --- a/internal/router/fallback.go +++ b/internal/router/fallback.go @@ -175,6 +175,13 @@ func (h *FallbackHandler) ExecuteWithFallback( totalModels := len(models) for i, model := range models { + if err := ctx.Err(); err != nil { + h.logger.Info("request context canceled, stopping fallback attempts", + "error", err, + ) + return nil, nil, err + } + cb := h.getCircuitBreaker(model.ModelID) // Skip models with open circuit breakers @@ -208,6 +215,14 @@ func (h *FallbackHandler) ExecuteWithFallback( }, body, nil } + if errCtx := ctx.Err(); errCtx != nil { + h.logger.Info("request context canceled after model attempt, stopping fallback", + "model", model.ModelID, + "error", errCtx, + ) + return nil, nil, errCtx + } + if IsRetryableError(err) { cb.RecordFailure() h.logger.Warn("model failed, trying fallback", diff --git a/internal/router/fallback_test.go b/internal/router/fallback_test.go index 54e2021..3ab1f72 100644 --- a/internal/router/fallback_test.go +++ b/internal/router/fallback_test.go @@ -4,7 +4,9 @@ import ( "context" "errors" "fmt" + "log/slog" "testing" + "time" "github.com/routatic/proxy/internal/config" ) @@ -153,3 +155,268 @@ func TestExecuteWithFallback_NonRetryableThenRetryable(t *testing.T) { t.Errorf("model-b circuit should be open after retryable error, got %v", cbB.State()) } } + +func TestExecuteWithFallback_StopsOnCanceledContext(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 3, 30*time.Second) + + ctx, 
cancel := context.WithCancel(context.Background()) + cancel() + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + {Provider: "opencode-go", ModelID: "model-b"}, + } + + callCount := 0 + _, _, err := handler.ExecuteWithFallback(ctx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + callCount++ + return []byte("ok"), nil + }, + ) + + if callCount != 0 { + t.Errorf("executor called %d times, want 0 (canceled context must stop immediately)", callCount) + } + if err == nil { + t.Error("expected error for canceled context, got nil") + } + + states := handler.GetCircuitStates() + if len(states) > 0 { + t.Errorf("expected no circuit breakers created, got %d", len(states)) + } +} + +func TestExecuteWithFallback_StopsOnCanceledAfterFirstModel(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 3, 30*time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + {Provider: "opencode-go", ModelID: "model-b"}, + } + + callCount := 0 + _, _, err := handler.ExecuteWithFallback(ctx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + callCount++ + if callCount == 1 { + cancel() + return nil, context.Canceled + } + return []byte("ok"), nil + }, + ) + + if callCount != 1 { + t.Errorf("executor called %d times, want 1 (should stop after parent context canceled)", callCount) + } + if err == nil { + t.Error("expected error for canceled context, got nil") + } + + states := handler.GetCircuitStates() + if _, ok := states["model-b"]; ok { + t.Error("model-b should not have a circuit breaker entry") + } +} + +func TestExecuteWithFallback_PerModelTimeoutFallback(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 3, 30*time.Second) + + parentCtx, parentCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer 
parentCancel() + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + {Provider: "opencode-go", ModelID: "model-b"}, + } + + callCount := 0 + result, body, err := handler.ExecuteWithFallback(parentCtx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + callCount++ + if callCount == 1 { + return nil, context.DeadlineExceeded + } + return []byte("success-" + model.ModelID), nil + }, + ) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if callCount != 2 { + t.Errorf("executor called %d times, want 2 (first timed out, second succeeds)", callCount) + } + if result.ModelID != "model-b" { + t.Errorf("result model = %s, want model-b", result.ModelID) + } + if string(body) != "success-model-b" { + t.Errorf("body = %s, want success-model-b", string(body)) + } +} + +func TestExecuteWithFallback_RealPerModelTimeout(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 3, 30*time.Second) + + parentCtx, parentCancel := context.WithCancel(context.Background()) + defer parentCancel() + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + {Provider: "opencode-go", ModelID: "model-b"}, + } + + callCount := 0 + result, body, err := handler.ExecuteWithFallback(parentCtx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + callCount++ + if callCount == 1 { + attemptCtx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + defer cancel() + <-attemptCtx.Done() + return nil, attemptCtx.Err() + } + return []byte("fallback-success"), nil + }, + ) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if callCount != 2 { + t.Errorf("executor called %d times, want 2", callCount) + } + if result.ModelID != "model-b" { + t.Errorf("result model = %s, want model-b", result.ModelID) + } + if string(body) != "fallback-success" { + t.Errorf("body = %s, want fallback-success", string(body)) + } +} + +func 
TestExecuteWithFallback_CircuitBreakerDoesNotCountClientCancellation(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 1, 30*time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + } + + callCount := 0 + _, _, err := handler.ExecuteWithFallback(ctx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + callCount++ + cancel() + return nil, context.Canceled + }, + ) + + if callCount != 1 { + t.Errorf("executor called %d times, want 1", callCount) + } + if err == nil { + t.Error("expected error for canceled context") + } + + states := handler.GetCircuitStates() + if state, ok := states["model-a"]; ok { + if state == "open" { + t.Error("model-a circuit breaker should NOT be open for client cancellation") + } + } +} + +func TestExecuteWithFallback_RealModelFailurePenalizesCircuitBreaker(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 1, 30*time.Second) + + ctx := context.Background() + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + } + + _, _, _ = handler.ExecuteWithFallback(ctx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + return nil, errors.New("upstream 500 internal server error") + }, + ) + + // model-a's circuit breaker should be open because of real failure + states := handler.GetCircuitStates() + state, ok := states["model-a"] + if !ok { + t.Fatal("model-a should have circuit breaker entry") + } + if state != "open" { + t.Errorf("model-a circuit breaker state = %s, want open", state) + } +} + +func TestExecuteWithFallback_ParentDeadlineExceededNotPenalized(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 1, 30*time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + time.Sleep(10 * time.Millisecond) // let 
parent timeout expire + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + } + + _, _, err := handler.ExecuteWithFallback(ctx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + return nil, nil + }, + ) + + if err == nil { + t.Error("expected error for deadline exceeded context") + } + + states := handler.GetCircuitStates() + if state, ok := states["model-a"]; ok && state == "open" { + t.Error("model-a circuit breaker should NOT be open for parent deadline exceeded") + } +} + +func TestExecuteWithFallback_AllModelsFailRecordsFailures(t *testing.T) { + logger := slog.Default() + handler := NewFallbackHandler(logger, 2, 30*time.Second) + + ctx := context.Background() + + models := []config.ModelConfig{ + {Provider: "opencode-go", ModelID: "model-a"}, + {Provider: "opencode-go", ModelID: "model-b"}, + } + + _, _, err := handler.ExecuteWithFallback(ctx, models, + func(ctx context.Context, model config.ModelConfig) ([]byte, error) { + return nil, errors.New("upstream error") + }, + ) + + if err == nil { + t.Error("expected error for all models failed") + } + + states := handler.GetCircuitStates() + if _, ok := states["model-a"]; !ok { + t.Error("model-a should have circuit breaker entry") + } + if _, ok := states["model-b"]; !ok { + t.Error("model-b should have circuit breaker entry") + } +} diff --git a/internal/transformer/request.go b/internal/transformer/request.go index e1749ff..9473151 100644 --- a/internal/transformer/request.go +++ b/internal/transformer/request.go @@ -589,6 +589,9 @@ func (t *RequestTransformer) transformTools(tools []types.Tool) []types.ToolDef var result []types.ToolDef for _, tool := range tools { + if strings.TrimSpace(tool.Name) == "" { + continue + } // InputSchema is already json.RawMessage, use it directly schema := tool.InputSchema switch { diff --git a/internal/transformer/request_test.go b/internal/transformer/request_test.go index 32f8941..4048984 100644 --- 
a/internal/transformer/request_test.go +++ b/internal/transformer/request_test.go @@ -1583,3 +1583,184 @@ func TestTransformTools_HandlesWhitespaceNullSchema(t *testing.T) { t.Fatalf("whitespace-null schema should fall back to default properties: %s", params) } } + +func TestTransformTools_SkipsEmptyName(t *testing.T) { + transformer := NewRequestTransformer() + tools := []types.Tool{ + {Name: "", Description: "empty name", InputSchema: json.RawMessage(`{"type":"object"}`)}, + {Name: "Bash", Description: "valid tool", InputSchema: json.RawMessage(`{"type":"object"}`)}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = %d, want %d (empty-name tool should be skipped)", got, want) + } + if got, want := result[0].Function.Name, "Bash"; got != want { + t.Fatalf("result[0].Name = %q, want %q", got, want) + } +} + +func TestTransformTools_SkipsWhitespaceOnlyName(t *testing.T) { + transformer := NewRequestTransformer() + tools := []types.Tool{ + {Name: " ", Description: "whitespace name", InputSchema: json.RawMessage(`{"type":"object"}`)}, + {Name: "Bash", Description: "valid tool", InputSchema: json.RawMessage(`{"type":"object"}`)}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = %d, want %d (whitespace-name tool should be skipped)", got, want) + } +} + +func TestTransformTools_FillsEmptySchema(t *testing.T) { + transformer := NewRequestTransformer() + tools := []types.Tool{ + {Name: "Bash", Description: "no schema", InputSchema: nil}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = %d, want %d", got, want) + } + + params := string(result[0].Function.Parameters) + if !strings.Contains(params, `"type":"object"`) { + t.Fatalf("parameters missing type=object: %s", params) + } + if !strings.Contains(params, `"additionalProperties":false`) { + 
t.Fatalf("parameters missing additionalProperties=false: %s", params) + } +} + +func TestTransformTools_FillsNullSchema(t *testing.T) { + transformer := NewRequestTransformer() + tools := []types.Tool{ + {Name: "Bash", Description: "null schema", InputSchema: json.RawMessage(`null`)}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = %d, want %d", got, want) + } + + params := string(result[0].Function.Parameters) + if !strings.Contains(params, `"type":"object"`) { + t.Fatalf("null schema should become type=object: %s", params) + } +} + +func TestTransformTools_FillsEmptyObjectSchema(t *testing.T) { + transformer := NewRequestTransformer() + tools := []types.Tool{ + {Name: "Bash", Description: "empty object schema", InputSchema: json.RawMessage(`{}`)}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = %d, want %d", got, want) + } + + params := string(result[0].Function.Parameters) + if !strings.Contains(params, `"type":"object"`) { + t.Fatalf("empty object schema should get type=object: %s", params) + } + if !strings.Contains(params, `"additionalProperties":false`) { + t.Fatalf("empty object schema should get additionalProperties=false: %s", params) + } +} + +func TestTransformTools_FillsMissingType(t *testing.T) { + transformer := NewRequestTransformer() + tools := []types.Tool{ + {Name: "Search", Description: "schema without type", InputSchema: json.RawMessage(`{"properties":{"query":{"type":"string"}}}`)}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = %d, want %d", got, want) + } + + params := string(result[0].Function.Parameters) + if !strings.Contains(params, `"type":"object"`) { + t.Fatalf("schema missing type should get type=object: %s", params) + } + if !strings.Contains(params, `"query"`) { + t.Fatalf("existing properties should 
be preserved: %s", params) + } +} + +func TestTransformTools_FillsMissingProperties(t *testing.T) { + transformer := NewRequestTransformer() + tools := []types.Tool{ + {Name: "NoOp", Description: "schema without properties", InputSchema: json.RawMessage(`{"type":"object"}`)}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = %d, want %d", got, want) + } + + params := string(result[0].Function.Parameters) + if !strings.Contains(params, `"properties"`) { + t.Fatalf("schema missing properties should get properties={}: %s", params) + } +} + +func TestTransformTools_RecoversFromInvalidJSON(t *testing.T) { + transformer := NewRequestTransformer() + tools := []types.Tool{ + {Name: "Bash", Description: "malformed JSON", InputSchema: json.RawMessage(`{invalid`)}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = %d, want %d (malformed schema should be replaced, not skipped)", got, want) + } + + params := string(result[0].Function.Parameters) + if !strings.Contains(params, `"type":"object"`) { + t.Fatalf("malformed schema should be replaced with valid schema: %s", params) + } +} + +func TestTransformTools_PreservesValidSchema(t *testing.T) { + transformer := NewRequestTransformer() + originalSchema := json.RawMessage(`{"type":"object","properties":{"cmd":{"type":"string","description":"The command"}},"required":["cmd"]}`) + tools := []types.Tool{ + {Name: "Bash", Description: "run a command", InputSchema: originalSchema}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = %d, want %d", got, want) + } + + params := string(result[0].Function.Parameters) + if !strings.Contains(params, `"cmd"`) { + t.Fatalf("valid schema properties should be preserved: %s", params) + } + if !strings.Contains(params, `"required"`) { + t.Fatalf("valid schema required should 
be preserved: %s", params) + } + if !strings.Contains(params, `"type":"string"`) { + t.Fatalf("valid schema nested type should be preserved: %s", params) + } +} + +func TestTransformTools_PreservesAdditionalPropertiesWhenSet(t *testing.T) { + transformer := NewRequestTransformer() + tools := []types.Tool{ + {Name: "Flexible", Description: "allows extra props", InputSchema: json.RawMessage(`{"type":"object","properties":{"a":{"type":"string"}},"additionalProperties":true}`)}, + } + + result := transformer.transformTools(tools) + if got, want := len(result), 1; got != want { + t.Fatalf("len(result) = %d, want %d", got, want) + } + + params := string(result[0].Function.Parameters) + if !strings.Contains(params, `"additionalProperties":true`) { + t.Fatalf("existing additionalProperties should be preserved: %s", params) + } +} diff --git a/internal/transformer/stream_test.go b/internal/transformer/stream_test.go index 60c496d..f0a78bf 100644 --- a/internal/transformer/stream_test.go +++ b/internal/transformer/stream_test.go @@ -1072,6 +1072,55 @@ func TestProxyStream_EOFFallbackStopReasonToolUse(t *testing.T) { } } +// TestProxyStream_ToolUseFirstContentBlock verifies that when the first +// assistant output is a direct tool call (no preceding text or reasoning), +// the tool_use block is emitted at index 0 per Anthropic SSE spec. 
+func TestProxyStream_ToolUseFirstContentBlock(t *testing.T) { + handler := NewStreamHandler() + w := newMockResponseWriter() + body := sseLines( + `{"choices":[{"delta":{"tool_calls":[{"index":0,"id":"toolu_abc","type":"function","function":{"name":"read_file","arguments":""}}]}}]}`, + `{"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"/tmp/x\"}"}}]}}]}`, + `{"choices":[{"delta":{},"finish_reason":"tool_use"}]}`, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := handler.ProxyStream(w, body, "kimi-k2.6", ctx, 0, cancel); err != nil { + t.Fatalf("ProxyStream error: %v", err) + } + + events := parseSSEEvents(t, w.buf.String()) + + // 0: message_start + // 1: content_block_start (index 0, type tool_use) — first content block + // 2: content_block_delta (index 0) + // 3: content_block_stop (index 0) + // 4: message_delta + // 5: message_stop + if len(events) != 6 { + t.Fatalf("expected 6 events, got %d: %+v", len(events), events) + } + + if events[1].Type != "content_block_start" { + t.Fatalf("event[1].Type = %q, want content_block_start", events[1].Type) + } + if events[1].ContentBlock == nil || events[1].ContentBlock.Type != "tool_use" { + t.Fatalf("event[1].ContentBlock = %+v, want tool_use", events[1].ContentBlock) + } + if events[1].Index == nil || *events[1].Index != 0 { + t.Fatalf("tool_use content_block_start index = %v, want 0", events[1].Index) + } + + if events[3].Type != "content_block_stop" || events[3].Index == nil || *events[3].Index != 0 { + t.Fatalf("tool_use content_block_stop index = %v, want 0", events[3].Index) + } + if events[4].Type != "message_delta" || events[4].Delta == nil || events[4].Delta.StopReason != "tool_use" { + t.Errorf("event[4] = %+v, want message_delta(tool_use)", events[4]) + } +} + // helpers func mustJSON(t *testing.T, v any) string { diff --git a/walkthrough.md b/walkthrough.md new file mode 100644 index 0000000..5760bcd --- /dev/null +++ 
b/walkthrough.md @@ -0,0 +1,46 @@ +# Walkthrough - PR #80 Integration and Timeout Config Bug Fix + +This integration completes 100% of the features, test cases, and documentation updates from PR #80 (`fix: stabilize Anthropic-native streaming, timeout handling, and fallback cancellation`), and addresses the issues surfaced during the RE-REVIEW pass. + +## Changes and Integrations Applied + +### 1. Fix for Timeout Config Overlap (BUG 1 & BUG 2) +- **Problem**: Previously, when a user configured `"streaming_timeout_ms": 600000` (total stream timeout) without setting `"stream_timeout_ms"` (idle timeout), the loader would silently assign `StreamTimeoutMs = TimeoutMs` (default 300000ms). As a result, the idle watchdog would fire at 300s and effectively neutralize the user's 600s total stream timeout whenever a stream stalled or went idle. +- **Fix**: Updated `applyDefaults` in [`loader.go`](internal/config/loader.go). When `StreamTimeoutMs` is unset, it now inherits the user-supplied `StreamingTimeoutMs` first (when present), and only falls back to `TimeoutMs` if neither is set. +- **Test coverage**: Added `TestDefaults_StreamingTimeoutFallback` to [`loader_test.go`](internal/config/loader_test.go) to exercise the real JSON file parsing path and assert the fallback behavior. + +### 2. Confirmation on MiniMax and Qwen-plus Routing (BUG 3) +- The expansion of `IsAnthropicModel` (in [`opencode.go`](internal/client/opencode.go)) and `isAnthropicNativeGo` (in [`opencode_go.go`](internal/provider/opencode_go.go)) to return `true` for `minimax-m2.5/2.7/3` and the `qwen-plus` family is an intentional bug fix. +- These models reject OpenAI-format streaming with tools on OpenCode Go with `400: invalid params, function name or parameters is empty`. Routing them through the Anthropic-native `/v1/messages` branch resolves the failure mode at its root. +- Test cases in `internal/client/opencode_test.go` were updated in lockstep to reflect this new behavior. + +### 3. 
Other Improvements and Integrations from PR #80 +- **Heartbeat Suppression**: Pauses the SSE keepalive ticker while copying a raw Anthropic stream via the `heartbeatPaused` flag. +- **Early Cancellation**: Cancels fallback attempts as soon as the parent context is canceled in `ExecuteWithFallback` ([`fallback.go`](internal/router/fallback.go)). +- **`transformTools` Hardening**: Skips tools with missing names, validates that the schema `type` is `object`, and adds a guard for null/empty `input_schema`. +- **Documentation**: Added the **Streaming Scenario Routing** section to [`CONFIGURATION.md`](CONFIGURATION.md) and [`README.md`](README.md). + +--- + +## Validation Results + +All test suites on the current workspace pass at 100%: + +```bash +go test ./... +``` + +Output: + +```text +ok github.com/routatic/proxy/internal/client 0.007s +ok github.com/routatic/proxy/internal/config 1.224s +ok github.com/routatic/proxy/internal/daemon (cached) +ok github.com/routatic/proxy/internal/handlers 1.503s +ok github.com/routatic/proxy/internal/router (cached) +ok github.com/routatic/proxy/internal/token (cached) +ok github.com/routatic/proxy/internal/transformer (cached) +ok github.com/routatic/proxy/pkg/types (cached) +``` + +All changes remain in the Unstaged state. No commits have been created. 
From 4f6f96ee6ae20c70b43d12e4968ff699d11586fa Mon Sep 17 00:00:00 2001 From: hungcuong9125 Date: Sat, 20 Jun 2026 21:10:26 +0700 Subject: [PATCH 5/8] fix: address lint issues in test files - Explicitly discard Write return value in concurrent test to satisfy Go vet - Remove redundant type assertion on NewCtxReader return value --- internal/handlers/messages_test.go | 2 +- internal/transformer/ctxio_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/handlers/messages_test.go b/internal/handlers/messages_test.go index ab3b40a..c3f089e 100644 --- a/internal/handlers/messages_test.go +++ b/internal/handlers/messages_test.go @@ -1484,7 +1484,7 @@ func TestResponseWriter_ConcurrentWrites(t *testing.T) { go func(id int) { defer wg.Done() for j := 0; j < writesPerGoroutine; j++ { - rw.Write([]byte(fmt.Sprintf("goroutine-%d-write-%d\n", id, j))) + _, _ = rw.Write([]byte(fmt.Sprintf("goroutine-%d-write-%d\n", id, j))) } }(i) } diff --git a/internal/transformer/ctxio_test.go b/internal/transformer/ctxio_test.go index b2b395b..a6b5d35 100644 --- a/internal/transformer/ctxio_test.go +++ b/internal/transformer/ctxio_test.go @@ -15,7 +15,7 @@ func TestNewCtxReader_PassesThroughUncanceled(t *testing.T) { in := strings.NewReader("hello world") r := NewCtxReader(ctx, in) - got, err := io.ReadAll(r.(io.Reader)) + got, err := io.ReadAll(r) if err != nil { t.Fatalf("ReadAll err = %v, want nil", err) } From d3de1d809c9b3c0f5a059d014a8b8ce4ba3f80d2 Mon Sep 17 00:00:00 2001 From: samuel tuyizere Date: Sat, 20 Jun 2026 16:31:59 +0200 Subject: [PATCH 6/8] fix: enhance error handling with APIError type and improve retry logic --- internal/client/opencode.go | 20 ++++++++--- internal/handlers/messages.go | 2 -- internal/router/fallback.go | 25 +++++++------ internal/router/fallback_test.go | 62 ++++++++++++++++---------------- internal/transformer/stream.go | 3 ++ 5 files changed, 64 insertions(+), 48 deletions(-) diff --git a/internal/client/opencode.go 
b/internal/client/opencode.go index d84616e..b152377 100644 --- a/internal/client/opencode.go +++ b/internal/client/opencode.go @@ -21,6 +21,18 @@ const ( ProviderOpenCodeZen = "opencode-zen" ) +// APIError represents an HTTP API error returned by an upstream provider. +// Callers should use errors.As to check for this type and inspect StatusCode +// for classification (4xx non-retryable, 5xx retryable, etc.). +type APIError struct { + StatusCode int + Body string +} + +func (e *APIError) Error() string { + return fmt.Sprintf("API error %d: %s", e.StatusCode, e.Body) +} + // OpenCodeClient handles communication with OpenCode Go and Zen APIs. type OpenCodeClient struct { atomic *config.AtomicConfig @@ -290,7 +302,7 @@ func (c *OpenCodeClient) ChatCompletion( if resp.StatusCode >= http.StatusBadRequest { bodyBytes, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) + return nil, &APIError{StatusCode: resp.StatusCode, Body: string(bodyBytes)} } return resp, nil @@ -381,7 +393,7 @@ func (c *OpenCodeClient) SendAnthropicRequest( if resp.StatusCode >= http.StatusBadRequest { bodyBytes, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) + return nil, &APIError{StatusCode: resp.StatusCode, Body: string(bodyBytes)} } return resp, nil @@ -417,7 +429,7 @@ func (c *OpenCodeClient) ResponsesCompletion( if resp.StatusCode >= http.StatusBadRequest { bodyBytes, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) + return nil, &APIError{StatusCode: resp.StatusCode, Body: string(bodyBytes)} } return resp, nil @@ -498,7 +510,7 @@ func (c *OpenCodeClient) GeminiCompletion( if resp.StatusCode >= http.StatusBadRequest { bodyBytes, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, 
string(bodyBytes)) + return nil, &APIError{StatusCode: resp.StatusCode, Body: string(bodyBytes)} } return resp, nil diff --git a/internal/handlers/messages.go b/internal/handlers/messages.go index 2e07dda..f936135 100644 --- a/internal/handlers/messages.go +++ b/internal/handlers/messages.go @@ -466,7 +466,6 @@ func (h *MessagesHandler) handleStreaming( atomic.StoreInt32(&heartbeatPaused, 0) } if errProxy != nil { - _ = streamBody.Close() if errProxy == transformer.ErrClientDisconnected { if clientCtx.Err() != nil { h.logger.Debug("client disconnected during stream") @@ -480,7 +479,6 @@ func (h *MessagesHandler) handleStreaming( continue } - _ = streamBody.Close() recordStreamSuccess(model) return } diff --git a/internal/router/fallback.go b/internal/router/fallback.go index ea5ecb3..b201cc9 100644 --- a/internal/router/fallback.go +++ b/internal/router/fallback.go @@ -3,12 +3,14 @@ package router import ( "context" + "errors" "fmt" "log/slog" "strings" "sync" "time" + "github.com/routatic/proxy/internal/client" "github.com/routatic/proxy/internal/config" ) @@ -265,18 +267,21 @@ func IsRetryableError(err error) bool { return false } - errStr := err.Error() - - // 4xx client errors are not retryable — the request format itself is invalid - // for that model, and retrying won't fix it. - if strings.Contains(errStr, "API error 4") { - return false + // APIError from the client carries the HTTP status code — use it directly + // instead of string matching, so error format changes upstream can't + // silently break the classification. + var apiErr *client.APIError + if errors.As(err, &apiErr) { + // 4xx client errors are not retryable — the request format itself is + // invalid for that model, and retrying won't fix it. This includes 429 + // (rate limit) so the circuit breaker doesn't open for rate limits. + return apiErr.StatusCode >= 500 } - // Retry on network errors, timeouts, rate limits (from non-4xx paths), - // and server errors (5xx). 
4xx client errors are already excluded by - // the "API error 4" check above — 429 is correctly non-retryable, so - // the circuit breaker doesn't open for rate limits. + // For non-API errors (network errors, timeouts, etc.), fall back to + // pattern matching on the error string. + errStr := err.Error() + retryable := []string{ "timeout", "connection refused", diff --git a/internal/router/fallback_test.go b/internal/router/fallback_test.go index 3ab1f72..417eb61 100644 --- a/internal/router/fallback_test.go +++ b/internal/router/fallback_test.go @@ -3,51 +3,49 @@ package router import ( "context" "errors" - "fmt" "log/slog" "testing" "time" + "github.com/routatic/proxy/internal/client" "github.com/routatic/proxy/internal/config" ) func TestIsRetryableError_ClientsErrorsNotRetryable(t *testing.T) { tests := []struct { - err string + err error want bool }{ // 4xx errors should NOT be retryable - {err: "API error 400: bad request", want: false}, - {err: "API error 401: unauthorized", want: false}, - {err: "API error 403: forbidden", want: false}, - {err: "API error 404: not found", want: false}, - {err: "API error 422: unprocessable", want: false}, - {err: "API error 429: rate limit", want: false}, - - // 5xx and network errors should be retryable (existing behavior) - {err: "API error 500: internal error", want: true}, - {err: "API error 502: bad gateway", want: true}, - {err: "API error 503: service unavailable", want: true}, - {err: "request timeout", want: true}, - {err: "connection refused", want: true}, - {err: "connection reset by peer", want: true}, - {err: "rate limit exceeded", want: true}, + {err: &client.APIError{StatusCode: 400, Body: "bad request"}, want: false}, + {err: &client.APIError{StatusCode: 401, Body: "unauthorized"}, want: false}, + {err: &client.APIError{StatusCode: 403, Body: "forbidden"}, want: false}, + {err: &client.APIError{StatusCode: 404, Body: "not found"}, want: false}, + {err: &client.APIError{StatusCode: 422, Body: 
"unprocessable"}, want: false}, + {err: &client.APIError{StatusCode: 429, Body: "rate limit"}, want: false}, + + // 5xx errors should be retryable + {err: &client.APIError{StatusCode: 500, Body: "internal error"}, want: true}, + {err: &client.APIError{StatusCode: 502, Body: "bad gateway"}, want: true}, + {err: &client.APIError{StatusCode: 503, Body: "service unavailable"}, want: true}, + + // Non-API errors — fall back to string matching + {err: errors.New("request timeout"), want: true}, + {err: errors.New("connection refused"), want: true}, + {err: errors.New("connection reset by peer"), want: true}, + {err: errors.New("rate limit exceeded"), want: true}, // Edge cases - {err: "", want: false}, - {err: "random error", want: false}, - {err: "API error 400", want: false}, - {err: "API error 500", want: true}, + {err: errors.New(""), want: false}, + {err: errors.New("random error"), want: false}, + {err: errors.New("API error 400"), want: false}, + {err: errors.New("API error 500"), want: true}, } for _, tt := range tests { - t.Run(tt.err, func(t *testing.T) { - var err error - if tt.err != "" { - err = errors.New(tt.err) - } - if got := IsRetryableError(err); got != tt.want { - t.Errorf("IsRetryableError(%q) = %v, want %v", tt.err, got, tt.want) + t.Run(tt.err.Error(), func(t *testing.T) { + if got := IsRetryableError(tt.err); got != tt.want { + t.Errorf("IsRetryableError(%q) = %v, want %v", tt.err.Error(), got, tt.want) } }) } @@ -68,7 +66,7 @@ func TestExecuteWithFallback_NonRetryableDoesNotOpenCircuit(t *testing.T) { func(ctx context.Context, model config.ModelConfig) ([]byte, error) { attempts++ // Non-retryable 400 error — should NOT open circuit breaker - return nil, fmt.Errorf("API error 400: bad request") + return nil, &client.APIError{StatusCode: 400, Body: "bad request"} }, ) @@ -101,7 +99,7 @@ func TestExecuteWithFallback_RetryableOpensCircuit(t *testing.T) { models, func(ctx context.Context, model config.ModelConfig) ([]byte, error) { // Retryable 500 
error — should open circuit breaker - return nil, fmt.Errorf("API error 500: internal error") + return nil, &client.APIError{StatusCode: 500, Body: "internal error"} }, ) @@ -132,10 +130,10 @@ func TestExecuteWithFallback_NonRetryableThenRetryable(t *testing.T) { callCount++ if callCount == 1 { // Non-retryable: model-a should NOT get circuit opened - return nil, fmt.Errorf("API error 400: bad request") + return nil, &client.APIError{StatusCode: 400, Body: "bad request"} } // Retryable: model-b should get circuit opened - return nil, fmt.Errorf("API error 500: internal error") + return nil, &client.APIError{StatusCode: 500, Body: "internal error"} }, ) diff --git a/internal/transformer/stream.go b/internal/transformer/stream.go index 3d71850..9f756ba 100644 --- a/internal/transformer/stream.go +++ b/internal/transformer/stream.go @@ -73,6 +73,7 @@ func (h *StreamHandler) ProxyStream( idleTimeout time.Duration, cancel context.CancelFunc, ) error { + defer func() { _ = openaiResp.Close() }() flusher, ok := w.(http.Flusher) if !ok { return fmt.Errorf("streaming not supported by response writer") @@ -690,6 +691,7 @@ func (h *StreamHandler) ProxyResponsesStream( idleTimeout time.Duration, cancel context.CancelFunc, ) error { + defer func() { _ = responsesResp.Close() }() flusher, ok := w.(http.Flusher) if !ok { return fmt.Errorf("streaming not supported by response writer") @@ -881,6 +883,7 @@ func (h *StreamHandler) ProxyGeminiStream( idleTimeout time.Duration, cancel context.CancelFunc, ) error { + defer func() { _ = geminiResp.Close() }() flusher, ok := w.(http.Flusher) if !ok { return fmt.Errorf("streaming not supported by response writer") From 90ede54c4122dfa19153777c2ef7e1ece749cbd9 Mon Sep 17 00:00:00 2001 From: samuel tuyizere Date: Sat, 20 Jun 2026 16:34:48 +0200 Subject: [PATCH 7/8] fix: optimize concurrent writes in TestResponseWriter by using Fprintf --- internal/handlers/messages_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/internal/handlers/messages_test.go b/internal/handlers/messages_test.go index c3f089e..74fff39 100644 --- a/internal/handlers/messages_test.go +++ b/internal/handlers/messages_test.go @@ -1484,7 +1484,7 @@ func TestResponseWriter_ConcurrentWrites(t *testing.T) { go func(id int) { defer wg.Done() for j := 0; j < writesPerGoroutine; j++ { - _, _ = rw.Write([]byte(fmt.Sprintf("goroutine-%d-write-%d\n", id, j))) + _, _ = fmt.Fprintf(rw, "goroutine-%d-write-%d\n", id, j) } }(i) } From e0899ef37ebd075fb18732fb212f7906c036bb78 Mon Sep 17 00:00:00 2001 From: samuel tuyizere Date: Sat, 20 Jun 2026 16:57:52 +0200 Subject: [PATCH 8/8] fix: improve error handling in Anthropic API calls with APIError type and streamline SSE error messages --- internal/handlers/messages.go | 8 +------- internal/provider/opencode_go.go | 7 ++++--- internal/provider/opencode_zen.go | 9 +++++---- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/internal/handlers/messages.go b/internal/handlers/messages.go index f936135..c39740d 100644 --- a/internal/handlers/messages.go +++ b/internal/handlers/messages.go @@ -426,7 +426,7 @@ func (h *MessagesHandler) handleStreaming( } h.logger.Warn(action+" streaming failed", "model", model.ModelID, "error", err) if rw.ssePayloadWritten { - h.sendStreamError(rw, fmt.Sprintf("all upstream models failed after SSE payload started: %v", err)) + h.sendStreamError(rw, "all upstream models failed after SSE payload started") h.metrics.RecordFailure() return false // abort — cannot fallback after SSE payload started } @@ -584,7 +584,6 @@ func (h *MessagesHandler) handleStreaming( streamReader := transformer.NewCtxReadCloser(attemptCtx, streamBody) if err := h.streamHandler.ProxyStream(rw, streamReader, model.ModelID, attemptCtx, idleTimeout, cancelAttempt); err != nil { - _ = streamBody.Close() if err == transformer.ErrClientDisconnected { if clientCtx.Err() != nil { h.logger.Debug("client disconnected during stream") @@ -598,7 +597,6 @@ func (h 
*MessagesHandler) handleStreaming( continue } - _ = streamBody.Close() recordStreamSuccess(model) return } @@ -642,11 +640,9 @@ func (h *MessagesHandler) handleResponsesStreaming( streamReader := transformer.NewCtxReadCloser(ctx, streamBody) if err := h.streamHandler.ProxyResponsesStream(w, streamReader, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { - _ = streamBody.Close() return err } - _ = streamBody.Close() return nil } @@ -676,11 +672,9 @@ func (h *MessagesHandler) handleGeminiStreaming( streamReader := transformer.NewCtxReadCloser(ctx, streamBody) if err := h.streamHandler.ProxyGeminiStream(w, streamReader, model.ModelID, clientCtx, idleTimeout, cancel); err != nil { - _ = streamBody.Close() return err } - _ = streamBody.Close() return nil } diff --git a/internal/provider/opencode_go.go b/internal/provider/opencode_go.go index fe6f43c..789f390 100644 --- a/internal/provider/opencode_go.go +++ b/internal/provider/opencode_go.go @@ -9,6 +9,7 @@ import ( "net/http" "time" + "github.com/routatic/proxy/internal/client" "github.com/routatic/proxy/internal/config" "github.com/routatic/proxy/internal/core" "github.com/routatic/proxy/internal/transformer" @@ -201,7 +202,7 @@ func (p *OpenCodeGoProvider) executeAnthropic(ctx context.Context, req *core.Nor if resp.StatusCode >= http.StatusBadRequest { bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) + return nil, &client.APIError{StatusCode: resp.StatusCode, Body: string(bodyBytes)} } body, err := io.ReadAll(resp.Body) @@ -244,7 +245,7 @@ func (p *OpenCodeGoProvider) streamAnthropic(ctx context.Context, req *core.Norm if resp.StatusCode >= http.StatusBadRequest { bodyBytes, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) + return nil, &client.APIError{StatusCode: resp.StatusCode, Body: string(bodyBytes)} } return resp.Body, nil @@ -276,7 +277,7 @@ func (p 
*OpenCodeGoProvider) doRequest(ctx context.Context, endpoint, apiKey str if resp.StatusCode >= http.StatusBadRequest { bodyBytes, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) + return nil, &client.APIError{StatusCode: resp.StatusCode, Body: string(bodyBytes)} } return resp, nil diff --git a/internal/provider/opencode_zen.go b/internal/provider/opencode_zen.go index 0664514..2030348 100644 --- a/internal/provider/opencode_zen.go +++ b/internal/provider/opencode_zen.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/routatic/proxy/internal/client" "github.com/routatic/proxy/internal/config" "github.com/routatic/proxy/internal/core" "github.com/routatic/proxy/internal/transformer" @@ -247,7 +248,7 @@ func (p *OpenCodeZenProvider) executeAnthropic(ctx context.Context, req *core.No if resp.StatusCode >= http.StatusBadRequest { bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) + return nil, &client.APIError{StatusCode: resp.StatusCode, Body: string(bodyBytes)} } body, err := io.ReadAll(resp.Body) @@ -290,7 +291,7 @@ func (p *OpenCodeZenProvider) streamAnthropic(ctx context.Context, req *core.Nor if resp.StatusCode >= http.StatusBadRequest { bodyBytes, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) + return nil, &client.APIError{StatusCode: resp.StatusCode, Body: string(bodyBytes)} } return resp.Body, nil @@ -436,7 +437,7 @@ func (p *OpenCodeZenProvider) doRequest(ctx context.Context, endpoint, apiKey st if resp.StatusCode >= http.StatusBadRequest { bodyBytes, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) + return nil, &client.APIError{StatusCode: resp.StatusCode, Body: string(bodyBytes)} } return resp, nil @@ -463,7 +464,7 @@ func (p 
*OpenCodeZenProvider) doJSONRequest(ctx context.Context, endpoint, apiKe if resp.StatusCode >= http.StatusBadRequest { bodyBytes, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) + return nil, &client.APIError{StatusCode: resp.StatusCode, Body: string(bodyBytes)} } return resp, nil