From 32e23b1ebc39ac5b7b7fbc9c93a9842f8cc809a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 19 Feb 2026 15:58:10 +0000 Subject: [PATCH 1/5] feat: add upstream request/response logging for passthrough routes --- apidump_integration_test.go | 127 +++++++++++++++++++++++++++--- intercept/apidump/apidump.go | 116 ++++++++++++++++++++------- intercept/apidump/apidump_test.go | 102 ++++++++++++++++++++++++ intercept/apidump/headers_test.go | 8 +- internal/testutil/mockprovider.go | 1 + passthrough.go | 5 +- provider/anthropic.go | 4 + provider/copilot.go | 4 + provider/openai.go | 4 + provider/provider.go | 4 + 10 files changed, 329 insertions(+), 46 deletions(-) diff --git a/apidump_integration_test.go b/apidump_integration_test.go index 29db3f2..e064ad2 100644 --- a/apidump_integration_test.go +++ b/apidump_integration_test.go @@ -49,32 +49,28 @@ func TestAPIDump(t *testing.T) { cases := []struct { name string fixture []byte - providerName string providersFunc func(addr, dumpDir string) []aibridge.Provider createRequestFunc createRequestFunc }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntSimple, - providerName: config.ProviderAnthropic, + name: "anthropic", + fixture: fixtures.AntSimple, providersFunc: func(addr, dumpDir string) []aibridge.Provider { return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)} }, createRequestFunc: createAnthropicMessagesReq, }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatSimple, - providerName: config.ProviderOpenAI, + name: "openai_chat_completions", + fixture: fixtures.OaiChatSimple, providersFunc: func(addr, dumpDir string) []aibridge.Provider { return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} }, createRequestFunc: createOpenAIChatCompletionsReq, }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiResponsesBlockingSimple, - providerName: config.ProviderOpenAI, + name: "openai_responses", + fixture: fixtures.OaiResponsesBlockingSimple, providersFunc: func(addr, dumpDir string) []aibridge.Provider { return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} }, @@ -176,3 +172,114 @@ func TestAPIDump(t *testing.T) { }) } } + +func TestAPIDumpPassthrough(t *testing.T) { + t.Parallel() + + const responseBody = `{"object":"list","data":[{"id":"gpt-4","object":"model"}]}` + + cases := []struct { + name string + providerFunc func(addr string, dumpDir string) aibridge.Provider + requestPath string + }{ + { + name: "anthropic", + providerFunc: func(addr string, dumpDir string) aibridge.Provider { + return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil) + }, + requestPath: "/anthropic/v1/models", + }, + { + name: "openai", + providerFunc: func(addr string, dumpDir string) aibridge.Provider { + return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) + }, + requestPath: "/openai/v1/models", + }, + { + name: "copilot", + providerFunc: func(addr string, dumpDir string) aibridge.Provider { + return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir}) + }, + requestPath: "/copilot/models", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(responseBody)) + })) + t.Cleanup(upstream.Close) + + dumpDir := t.TempDir() + + recorderClient := &testutil.MockRecorder{} + prov := tc.providerFunc(upstream.URL, dumpDir) + provs := []aibridge.Provider{prov} + b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + require.NoError(t, err) + + bridgeSrv := httptest.NewUnstartedServer(b) + t.Cleanup(bridgeSrv.Close) + bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibcontext.AsActor(ctx, userID, nil) + } + bridgeSrv.Start() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, bridgeSrv.URL+tc.requestPath, nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Find dump files in the passthrough directory. + passthroughDir := filepath.Join(dumpDir, tc.name, "passthrough") + var reqDumpFile, respDumpFile string + err = filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + if strings.HasSuffix(path, apidump.SuffixRequest) { + reqDumpFile = path + } else if strings.HasSuffix(path, apidump.SuffixResponse) { + respDumpFile = path + } + return nil + }) + require.NoError(t, err, "walking failed: %v", err) + require.NotEmpty(t, reqDumpFile, "request dump file should exist") + require.NotEmpty(t, respDumpFile, "response dump file should exist") + + // Verify request dump. + reqDumpData, err := os.ReadFile(reqDumpFile) + require.NoError(t, err) + dumpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqDumpData))) + require.NoError(t, err) + require.Equal(t, http.MethodGet, dumpReq.Method) + + // Verify response dump. + respDumpData, err := os.ReadFile(respDumpFile) + require.NoError(t, err) + dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil) + require.NoError(t, err) + require.Equal(t, http.StatusOK, dumpResp.StatusCode) + dumpRespBody, err := io.ReadAll(dumpResp.Body) + require.NoError(t, err) + require.JSONEq(t, responseBody, string(dumpRespBody)) + }) + } +} diff --git a/intercept/apidump/apidump.go b/intercept/apidump/apidump.go index 2e8355f..a78858e 100644 --- a/intercept/apidump/apidump.go +++ b/intercept/apidump/apidump.go @@ -41,17 +41,13 @@ func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, lo } d := &dumper{ - baseDir: baseDir, - provider: provider, - model: model, - interceptionID: interceptionID, - clk: clk, - logger: logger, + dumpPath: interceptDumpPath(baseDir, provider, model, interceptionID, clk), + logger: logger, } return func(req *http.Request, next MiddlewareNext) (*http.Response, error) { if err := d.dumpRequest(req); err != nil { - logger.Named("apidump").Warn(context.Background(), "failed to dump request", slog.Error(err)) + logger.Named("apidump").Warn(req.Context(), "failed to dump request", slog.Error(err)) } // TODO: https://github.com/coder/aibridge/issues/129 @@ -61,7 +57,7 @@ func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, lo } if err := d.dumpResponse(resp); err != nil { - logger.Named("apidump").Warn(context.Background(), "failed to dump response", slog.Error(err)) + logger.Named("apidump").Warn(req.Context(), "failed to dump response", slog.Error(err)) } return resp, nil @@ -69,16 +65,12 @@ func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, lo } type dumper struct { - baseDir string - provider string - model string - interceptionID uuid.UUID - clk quartz.Clock - logger slog.Logger + dumpPath string + logger slog.Logger } func (d *dumper) dumpRequest(req *http.Request) error { - dumpPath := d.path(SuffixRequest) + dumpPath := d.dumpPath + SuffixRequest if err := os.MkdirAll(filepath.Dir(dumpPath), 0o755); err != nil { return fmt.Errorf("create dump dir: %w", err) } @@ -98,25 +90,40 @@ func (d *dumper) dumpRequest(req *http.Request) error { // Build raw HTTP request format var buf bytes.Buffer - fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto) + _, err := fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto) + if err != nil { + return fmt.Errorf("write request uri: %w", err) + } d.writeRedactedHeaders(&buf, req.Header, sensitiveRequestHeaders, map[string]string{ "Content-Length": fmt.Sprintf("%d", len(prettyBody)), }) - fmt.Fprintf(&buf, "\r\n") + _, err = fmt.Fprintf(&buf, "\r\n") + if err != nil { + return fmt.Errorf("write request body: %w", err) + } buf.Write(prettyBody) return os.WriteFile(dumpPath, buf.Bytes(), 0o644) } func (d *dumper) dumpResponse(resp *http.Response) error { - dumpPath := d.path(SuffixResponse) + dumpPath := d.dumpPath + SuffixResponse // Build raw HTTP response headers var headerBuf bytes.Buffer - fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status) - d.writeRedactedHeaders(&headerBuf, resp.Header, sensitiveResponseHeaders, nil) - fmt.Fprintf(&headerBuf, "\r\n") + _, err := fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status) + if err != nil { + return fmt.Errorf("write response status: %w", err) + } + err = d.writeRedactedHeaders(&headerBuf, resp.Header, sensitiveResponseHeaders, nil) + if err != nil { + return err + } + _, err = fmt.Fprintf(&headerBuf, "\r\n") + if err != nil { + return fmt.Errorf("write response body: %w", err) + } // Wrap the response body to capture it as it streams if resp.Body != nil { @@ -141,7 +148,7 @@ func (d *dumper) dumpResponse(resp *http.Response) error { // for deterministic output. // `sensitive` and `overrides` must both supply keys in canoncialized form. // See [textproto.MIMEHeader]. -func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) { +func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) error { // Collect all header keys including overrides. headerKeys := make([]string, 0, len(headers)+len(overrides)) seen := make(map[string]struct{}, len(headers)+len(overrides)) @@ -163,7 +170,10 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv // If no values exist but we have an override, use that. if len(values) == 0 { if override, ok := overrides[key]; ok { - fmt.Fprintf(w, "%s: %s\r\n", key, override) + _, err := fmt.Fprintf(w, "%s: %s\r\n", key, override) + if err != nil { + return fmt.Errorf("write response header override: %w", err) + } } continue } @@ -175,16 +185,64 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv if isSensitive { value = redactHeaderValue(value) } - fmt.Fprintf(w, "%s: %s\r\n", key, value) + _, err := fmt.Fprintf(w, "%s: %s\r\n", key, value) + if err != nil { + return fmt.Errorf("write response headers: %w", err) + } } } + return nil +} + +// interceptDumpPath returns the base file path (without suffix) for an interception dump. +func interceptDumpPath(baseDir string, provider string, model string, interceptionID uuid.UUID, clk quartz.Clock) string { + safeModel := strings.ReplaceAll(model, "/", "-") + return filepath.Join(baseDir, provider, safeModel, fmt.Sprintf("%d-%s", clk.Now().UTC().UnixMilli(), interceptionID)) +} + +// passthroughDumpPath returns the base file path (without suffix) for a passthrough dump. +// A random UUID is generated for the filename. "passthrough" is used as the directory name +// in place of the model. +func passthroughDumpPath(baseDir string, provider string, clk quartz.Clock) string { + return filepath.Join(baseDir, provider, "passthrough", fmt.Sprintf("%d-%s", clk.Now().UTC().UnixMilli(), uuid.New())) +} + +// NewRoundTripperMiddleware returns http.RoundTripper that dumps requests and responses to files. +// If baseDir is empty, returns the original transport unchanged. +// Used for logging passed through requests. +func NewRoundTripperMiddleware(transport http.RoundTripper, baseDir string, provider string, logger slog.Logger, clk quartz.Clock) http.RoundTripper { + if baseDir == "" { + return transport + } + return &dumpRoundTripper{ + inner: transport, + dumper: dumper{ + dumpPath: passthroughDumpPath(baseDir, provider, clk), + logger: logger, + }, + } +} + +type dumpRoundTripper struct { + inner http.RoundTripper + dumper dumper } -// path returns the path to a request/response dump file for a given interception. -// suffix should be SuffixRequest or SuffixResponse. -func (d *dumper) path(suffix string) string { - safeModel := strings.ReplaceAll(d.model, "/", "-") - return filepath.Join(d.baseDir, d.provider, safeModel, fmt.Sprintf("%d-%s%s", d.clk.Now().UTC().UnixMilli(), d.interceptionID, suffix)) +func (rt *dumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if err := rt.dumper.dumpRequest(req); err != nil { + rt.dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough request", slog.Error(err)) + } + + resp, err := rt.inner.RoundTrip(req) + if err != nil { + return resp, err + } + + if err := rt.dumper.dumpResponse(resp); err != nil { + rt.dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough response", slog.Error(err)) + } + + return resp, nil } // prettyPrintJSON returns indented JSON if body is valid JSON, otherwise returns body as-is. diff --git a/intercept/apidump/apidump_test.go b/intercept/apidump/apidump_test.go index a7b4c5b..89e0a27 100644 --- a/intercept/apidump/apidump_test.go +++ b/intercept/apidump/apidump_test.go @@ -311,3 +311,105 @@ func TestMiddleware_AllSensitiveRequestHeaders(t *testing.T) { require.Contains(t, content, "Proxy-Authorization:") require.Contains(t, content, "X-Amz-Security-Token:") } + +func TestRoundTripperMiddleware(t *testing.T) { + t.Parallel() + + t.Run("empty_base_dir_returns_original_transport", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + inner := http.DefaultTransport + rt := NewRoundTripperMiddleware(inner, "", "openai", logger, quartz.NewMock(t)) + require.Equal(t, inner, rt) + }) + + t.Run("returns_error_from_inner_round_trip", func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + clk := quartz.NewMock(t) + + innerErr := io.ErrUnexpectedEOF + inner := &mockRoundTripper{ + roundTrip: func(_ *http.Request) (*http.Response, error) { + return nil, innerErr + }, + } + + rt := NewRoundTripperMiddleware(inner, tmpDir, "openai", logger, clk) + + req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/models", nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.ErrorIs(t, err, innerErr) + require.Nil(t, resp) + }) + + t.Run("dumps_request_and_response", func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + clk := quartz.NewMock(t) + + inner := &mockRoundTripper{ + roundTrip: func(req *http.Request) (*http.Response, error) { + // Verify body is still readable after dump + body, err := io.ReadAll(req.Body) + require.NoError(t, err) + require.Equal(t, `{"request": true}`, string(body)) + + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Proto: "HTTP/1.1", + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewReader([]byte(`{"response": true}`))), + }, nil + }, + } + + rt := NewRoundTripperMiddleware(inner, tmpDir, "openai", logger, clk) + + req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/models", bytes.NewReader([]byte(`{"request": true}`))) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer sk-secret-key-12345") + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + + // Must read and close response body to trigger the streaming dump + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + // Verify files are in passthrough directory + passthroughDir := filepath.Join(tmpDir, "openai", "passthrough") + reqDumpPath := findDumpFile(t, passthroughDir, SuffixRequest) + reqContent, err := os.ReadFile(reqDumpPath) + require.NoError(t, err) + + require.Contains(t, string(reqContent), "POST") + require.Contains(t, string(reqContent), `"request": true`) + // Sensitive header should be redacted + require.NotContains(t, string(reqContent), "sk-secret-key-12345") + require.Contains(t, string(reqContent), "Authorization:") + + respDumpPath := findDumpFile(t, passthroughDir, SuffixResponse) + respContent, err := os.ReadFile(respDumpPath) + require.NoError(t, err) + + require.Contains(t, string(respContent), "200 OK") + require.Contains(t, string(respContent), `"response": true`) + }) +} + +type mockRoundTripper struct { + roundTrip func(*http.Request) (*http.Response, error) +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return m.roundTrip(req) +} diff --git a/intercept/apidump/headers_test.go b/intercept/apidump/headers_test.go index c9286ab..181eae2 100644 --- a/intercept/apidump/headers_test.go +++ b/intercept/apidump/headers_test.go @@ -101,12 +101,8 @@ func TestWriteRedactedHeaders(t *testing.T) { t.Parallel() d := &dumper{ - baseDir: "/tmp", - provider: "test", - model: "test", - interceptionID: uuid.New(), - clk: quartz.NewMock(t), - logger: slog.Make(), + dumpPath: interceptDumpPath("/tmp", "test", "test", uuid.New(), quartz.NewMock(t)), + logger: slog.Make(), } tests := []struct { diff --git a/internal/testutil/mockprovider.go b/internal/testutil/mockprovider.go index 6e268c6..a21ac6a 100644 --- a/internal/testutil/mockprovider.go +++ b/internal/testutil/mockprovider.go @@ -25,6 +25,7 @@ func (m *MockProvider) PassthroughRoutes() []string { return m. func (m *MockProvider) AuthHeader() string { return "Authorization" } func (m *MockProvider) InjectAuthHeader(h *http.Header) {} func (m *MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil } +func (m *MockProvider) APIDumpDir() string { return "" } func (m *MockProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) { if m.InterceptorFunc != nil { return m.InterceptorFunc(w, r, tracer) diff --git a/passthrough.go b/passthrough.go index 0dcef9c..66da45a 100644 --- a/passthrough.go +++ b/passthrough.go @@ -8,9 +8,11 @@ import ( "time" "cdr.dev/slog/v3" + "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/tracing" + "github.com/coder/quartz" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" @@ -102,7 +104,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met } // Transport tuned for streaming (no response header timeout). - proxy.Transport = &http.Transport{ + t := &http.Transport{ Proxy: http.ProxyFromEnvironment, ForceAttemptHTTP2: true, MaxIdleConns: 100, @@ -110,6 +112,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } + proxy.Transport = apidump.NewRoundTripperMiddleware(t, provider.APIDumpDir(), provider.Name(), logger, quartz.NewReal()) proxy.ServeHTTP(w, r) } diff --git a/provider/anthropic.go b/provider/anthropic.go index be12583..5195cec 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -133,3 +133,7 @@ func (p *Anthropic) InjectAuthHeader(headers *http.Header) { func (p *Anthropic) CircuitBreakerConfig() *config.CircuitBreaker { return p.cfg.CircuitBreaker } + +func (p *Anthropic) APIDumpDir() string { + return p.cfg.APIDumpDir +} diff --git a/provider/copilot.go b/provider/copilot.go index 34fab49..9b128ca 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -109,6 +109,10 @@ func (p *Copilot) CircuitBreakerConfig() *config.CircuitBreaker { return p.circuitBreaker } +func (p *Copilot) APIDumpDir() string { + return p.cfg.APIDumpDir +} + func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) { _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") defer tracing.EndSpanErr(span, &outErr) diff --git a/provider/openai.go b/provider/openai.go index 9f0bf70..730fc68 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -152,3 +152,7 @@ func (p *OpenAI) InjectAuthHeader(headers *http.Header) { func (p *OpenAI) CircuitBreakerConfig() *config.CircuitBreaker { return p.circuitBreaker } + +func (p *OpenAI) APIDumpDir() string { + return p.cfg.APIDumpDir +} diff --git a/provider/provider.go b/provider/provider.go index 0dbe352..f2a70f1 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -75,4 +75,8 @@ type Provider interface { // CircuitBreakerConfig returns the circuit breaker configuration for the provider. CircuitBreakerConfig() *config.CircuitBreaker + + // APIDumpDir returns the directory path for dumping API requests and responses. + // Empty string is returned when API dumping is not enabled. + APIDumpDir() string } From 991d4a90c70f1f5b6224e86ba46b349a8e7d55de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 20 Feb 2026 11:23:58 +0000 Subject: [PATCH 2/5] drive by fix: prettyPrintJSON non json edge case + new line --- intercept/apidump/apidump.go | 17 ++++++++++------- intercept/apidump/apidump_test.go | 17 +++++++++++++++-- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/intercept/apidump/apidump.go b/intercept/apidump/apidump.go index a78858e..533fec2 100644 --- a/intercept/apidump/apidump.go +++ b/intercept/apidump/apidump.go @@ -252,12 +252,15 @@ func prettyPrintJSON(body []byte) []byte { if len(body) == 0 { return body } - result := pretty.Pretty(body) - // pretty.Pretty returns a truncated/modified result for invalid JSON, - // so check if the result is valid JSON; if not, return the original. - if !json.Valid(result) { - return body + + result := body + if json.Valid(body) { + result = pretty.Pretty(body) + } + + // Add trailing newline if missing. + if !bytes.HasSuffix(result, []byte("\n")) { + result = append(result, []byte("\n")...) } - // Trim trailing newline added by pretty.Pretty. - return bytes.TrimSuffix(result, []byte("\n")) + return result } diff --git a/intercept/apidump/apidump_test.go b/intercept/apidump/apidump_test.go index 89e0a27..4e2acbd 100644 --- a/intercept/apidump/apidump_test.go +++ b/intercept/apidump/apidump_test.go @@ -233,12 +233,25 @@ func TestPrettyPrintJSON(t *testing.T) { { name: "valid JSON", input: []byte(`{"key":"value"}`), - expected: "{\n \"key\": \"value\"\n}", + expected: "{\n \"key\": \"value\"\n}\n", }, { name: "invalid JSON returns as-is", input: []byte("not json"), - expected: "not json", + expected: "not json\n", + }, + // see: https://github.com/tidwall/pretty/blob/9090695766b652478676cc3e55bc3187056b1ff0/pretty.go#L117 + // for input starting with "t" it would change it to "true", eg. "t_rest_of_the_string_is_discarded" -> "true" + // similar for inputs startrting with "f" and "n" + { + name: "invalid JSON edge case t", + input: []byte("test"), + expected: "test\n", + }, + { + name: "invalid JSON edge case f", + input: []byte("f"), + expected: "f\n", }, } From 1a03d348e9489fff85ef641343b9821e61fb989f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 20 Feb 2026 12:08:07 +0000 Subject: [PATCH 3/5] added missing error handling --- intercept/apidump/apidump.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/intercept/apidump/apidump.go b/intercept/apidump/apidump.go index 533fec2..fa925a5 100644 --- a/intercept/apidump/apidump.go +++ b/intercept/apidump/apidump.go @@ -94,9 +94,12 @@ func (d *dumper) dumpRequest(req *http.Request) error { if err != nil { return fmt.Errorf("write request uri: %w", err) } - d.writeRedactedHeaders(&buf, req.Header, sensitiveRequestHeaders, map[string]string{ + err = d.writeRedactedHeaders(&buf, req.Header, sensitiveRequestHeaders, map[string]string{ "Content-Length": fmt.Sprintf("%d", len(prettyBody)), }) + if err != nil { + return fmt.Errorf("write request headers: %w", err) + } _, err = fmt.Fprintf(&buf, "\r\n") if err != nil { @@ -118,7 +121,7 @@ func (d *dumper) dumpResponse(resp *http.Response) error { } err = d.writeRedactedHeaders(&headerBuf, resp.Header, sensitiveResponseHeaders, nil) if err != nil { - return err + return fmt.Errorf("write response headers: %w", err) } _, err = fmt.Fprintf(&headerBuf, "\r\n") if err != nil { From 7b4701993c482581bf50442d76e5cad27538144b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 20 Feb 2026 12:47:35 +0000 Subject: [PATCH 4/5] Made NewRoundTripperMiddleware reusable --- intercept/apidump/apidump.go | 30 +++++++----- intercept/apidump/apidump_test.go | 77 +++++++++++++++++++++++-------- 2 files changed, 78 insertions(+), 29 deletions(-) diff --git a/intercept/apidump/apidump.go b/intercept/apidump/apidump.go index fa925a5..3ba7f9c 100644 --- a/intercept/apidump/apidump.go +++ b/intercept/apidump/apidump.go @@ -218,22 +218,30 @@ func NewRoundTripperMiddleware(transport http.RoundTripper, baseDir string, prov return transport } return &dumpRoundTripper{ - inner: transport, - dumper: dumper{ - dumpPath: passthroughDumpPath(baseDir, provider, clk), - logger: logger, - }, + inner: transport, + baseDir: baseDir, + provider: provider, + clk: clk, + logger: logger, } } type dumpRoundTripper struct { - inner http.RoundTripper - dumper dumper + inner http.RoundTripper + baseDir string + provider string + clk quartz.Clock + logger slog.Logger } func (rt *dumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if err := rt.dumper.dumpRequest(req); err != nil { - rt.dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough request", slog.Error(err)) + dumper := dumper{ + dumpPath: passthroughDumpPath(rt.baseDir, rt.provider, rt.clk), + logger: rt.logger, + } + + if err := dumper.dumpRequest(req); err != nil { + dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough request", slog.Error(err)) } resp, err := rt.inner.RoundTrip(req) @@ -241,8 +249,8 @@ func (rt *dumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) return resp, err } - if err := rt.dumper.dumpResponse(resp); err != nil { - rt.dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough response", slog.Error(err)) + if err := dumper.dumpResponse(resp); err != nil { + dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough response", slog.Error(err)) } return resp, nil diff --git a/intercept/apidump/apidump_test.go b/intercept/apidump/apidump_test.go index 4e2acbd..03de0ac 100644 --- a/intercept/apidump/apidump_test.go +++ b/intercept/apidump/apidump_test.go @@ -2,10 +2,12 @@ package apidump import ( "bytes" + "fmt" "io" "net/http" "os" "path/filepath" + "sort" "testing" "cdr.dev/slog/v3" @@ -367,55 +369,94 @@ func TestRoundTripperMiddleware(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) clk := quartz.NewMock(t) + req1Body := `first request` + req2Body := `{"request": 2}` + req2BodyPretty := "{\n \"request\": 2\n}\n" + + callCount := 0 inner := &mockRoundTripper{ roundTrip: func(req *http.Request) (*http.Response, error) { // Verify body is still readable after dump body, err := io.ReadAll(req.Body) require.NoError(t, err) - require.Equal(t, `{"request": true}`, string(body)) + callCount++ + if callCount == 1 { + require.Equal(t, req1Body, string(body)) + } else { + require.Equal(t, req2Body, string(body)) + } return &http.Response{ StatusCode: http.StatusOK, Status: "200 OK", Proto: "HTTP/1.1", Header: http.Header{"Content-Type": []string{"application/json"}}, - Body: io.NopCloser(bytes.NewReader([]byte(`{"response": true}`))), + Body: io.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(`{"call": %d}"`, callCount)))), }, nil }, } rt := NewRoundTripperMiddleware(inner, tmpDir, "openai", logger, clk) - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/models", bytes.NewReader([]byte(`{"request": true}`))) + req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/models", bytes.NewReader([]byte(req1Body))) require.NoError(t, err) req.Header.Set("Authorization", "Bearer sk-secret-key-12345") - resp, err := rt.RoundTrip(req) require.NoError(t, err) - - // Must read and close response body to trigger the streaming dump _, err = io.ReadAll(resp.Body) require.NoError(t, err) require.NoError(t, resp.Body.Close()) - // Verify files are in passthrough directory + // Second request should create new req/resp files + req2, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/models", bytes.NewReader([]byte(req2Body))) + require.NoError(t, err) + resp2, err := rt.RoundTrip(req2) + require.NoError(t, err) + _, err = io.ReadAll(resp2.Body) + require.NoError(t, err) + require.NoError(t, resp2.Body.Close()) + + // Validate request files contents passthroughDir := filepath.Join(tmpDir, "openai", "passthrough") - reqDumpPath := findDumpFile(t, passthroughDir, SuffixRequest) - reqContent, err := os.ReadFile(reqDumpPath) + reqPattern := filepath.Join(passthroughDir, "*"+SuffixRequest) + reqMatches, err := filepath.Glob(reqPattern) require.NoError(t, err) + require.Len(t, reqMatches, 2, "expected exactly two %s files in %s", SuffixRequest, passthroughDir) - require.Contains(t, string(reqContent), "POST") - require.Contains(t, string(reqContent), `"request": true`) - // Sensitive header should be redacted - require.NotContains(t, string(reqContent), "sk-secret-key-12345") - require.Contains(t, string(reqContent), "Authorization:") + reqContents := make([]string, 0, len(reqMatches)) + for _, reqDumpPath := range reqMatches { + reqContent, readErr := os.ReadFile(reqDumpPath) + require.NoError(t, readErr) + reqContents = append(reqContents, string(reqContent)) + } - respDumpPath := findDumpFile(t, passthroughDir, SuffixResponse) - respContent, err := os.ReadFile(respDumpPath) + sort.Strings(reqContents) + require.Contains(t, reqContents[0], req1Body+"\n") + require.Contains(t, reqContents[1], req2BodyPretty) + // Sensitive header should be redacted + require.NotContains(t, reqContents[0], "sk-secret-key-12345") + require.NotContains(t, reqContents[1], "sk-secret-key-12345") + require.Contains(t, reqContents[0], "Authorization:") + require.NotContains(t, reqContents[1], "Authorization:") + + // Validate response files contents + respPattern := filepath.Join(passthroughDir, "*"+SuffixResponse) + respMatches, err := filepath.Glob(respPattern) require.NoError(t, err) + require.Len(t, respMatches, 2, "expected exactly two %s files in %s", SuffixResponse, passthroughDir) + + respContents := make([]string, 0, len(respMatches)) + for _, respDumpPath := range respMatches { + respContent, readErr := os.ReadFile(respDumpPath) + require.NoError(t, readErr) + respContents = append(respContents, string(respContent)) + } - require.Contains(t, string(respContent), "200 OK") - require.Contains(t, string(respContent), `"response": true`) + sort.Strings(respContents) + require.Contains(t, respContents[0], "200 OK") + require.Contains(t, respContents[0], `{"call": 1}"`) + require.Contains(t, respContents[1], "200 OK") + require.Contains(t, respContents[1], `{"call": 2}"`) }) } From 153f439b5e1e9495d2ea9f8f43b84304d938345e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Tue, 3 Mar 2026 13:25:53 +0000 Subject: [PATCH 5/5] review 1: pass though dump file name includes url path + renaming --- apidump_integration_test.go | 24 +++++-- intercept/apidump/apidump.go | 33 ++++----- intercept/apidump/apidump_test.go | 104 +++++++++++++--------------- intercept/apidump/streaming_test.go | 4 +- intercept/chatcompletions/base.go | 2 +- intercept/messages/base.go | 2 +- intercept/responses/base.go | 2 +- passthrough.go | 2 +- 8 files changed, 86 insertions(+), 87 deletions(-) diff --git a/apidump_integration_test.go b/apidump_integration_test.go index e064ad2..fd17ead 100644 --- a/apidump_integration_test.go +++ b/apidump_integration_test.go @@ -179,30 +179,34 @@ func TestAPIDumpPassthrough(t *testing.T) { const responseBody = `{"object":"list","data":[{"id":"gpt-4","object":"model"}]}` cases := []struct { - name string - providerFunc func(addr string, dumpDir string) aibridge.Provider - requestPath string + name string + providerFunc func(addr string, dumpDir string) aibridge.Provider + requestPath string + expectDumpName string }{ { name: "anthropic", providerFunc: func(addr string, dumpDir string) aibridge.Provider { return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil) }, - requestPath: "/anthropic/v1/models", + requestPath: "/anthropic/v1/models", + expectDumpName: "-v1-models-", }, { name: "openai", providerFunc: func(addr string, dumpDir string) aibridge.Provider { return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) }, - requestPath: "/openai/v1/models", + requestPath: "/openai/v1/models", + expectDumpName: "-models-", }, { name: "copilot", providerFunc: func(addr string, dumpDir string) aibridge.Provider { return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir}) }, - requestPath: "/copilot/models", + requestPath: "/copilot/models", + expectDumpName: "-models-", }, } @@ -261,8 +265,16 @@ func TestAPIDumpPassthrough(t *testing.T) { return nil }) require.NoError(t, err, "walking failed: %v", err) + require.NotEmpty(t, reqDumpFile, "request dump file should exist") + require.FileExists(t, reqDumpFile) + require.Contains(t, reqDumpFile, "/passthrough/") + require.Contains(t, reqDumpFile, tc.expectDumpName) + require.NotEmpty(t, respDumpFile, "response dump file should exist") + require.FileExists(t, respDumpFile) + require.Contains(t, respDumpFile, "/passthrough/") + require.Contains(t, respDumpFile, tc.expectDumpName) // Verify request dump. reqDumpData, err := os.ReadFile(reqDumpFile) diff --git a/intercept/apidump/apidump.go b/intercept/apidump/apidump.go index 3ba7f9c..e8e6d89 100644 --- a/intercept/apidump/apidump.go +++ b/intercept/apidump/apidump.go @@ -32,10 +32,9 @@ type MiddlewareNext = func(*http.Request) (*http.Response, error) // Middleware is an HTTP middleware function compatible with SDK WithMiddleware options. type Middleware = func(*http.Request, MiddlewareNext) (*http.Response, error) -// NewMiddleware returns a middleware function that dumps requests and responses to files. -// Files are written to the path returned by DumpPath. +// NewBridgeMiddleware returns a middleware function that dumps requests and responses to files. // If baseDir is empty, returns nil (no middleware). -func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, logger slog.Logger, clk quartz.Clock) Middleware { +func NewBridgeMiddleware(baseDir string, provider string, model string, interceptionID uuid.UUID, logger slog.Logger, clk quartz.Clock) Middleware { if baseDir == "" { return nil } @@ -103,9 +102,10 @@ func (d *dumper) dumpRequest(req *http.Request) error { _, err = fmt.Fprintf(&buf, "\r\n") if err != nil { - return fmt.Errorf("write request body: %w", err) + return fmt.Errorf("write request header terminator: %w", err) } buf.Write(prettyBody) + buf.WriteByte('\n') return os.WriteFile(dumpPath, buf.Bytes(), 0o644) } @@ -125,7 +125,7 @@ func (d *dumper) dumpResponse(resp *http.Response) error { } _, err = fmt.Fprintf(&headerBuf, "\r\n") if err != nil { - return fmt.Errorf("write response body: %w", err) + return fmt.Errorf("write response header terminator: %w", err) } // Wrap the response body to capture it as it streams @@ -197,23 +197,22 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv return nil } -// interceptDumpPath returns the base file path (without suffix) for an interception dump. +// interceptDumpPath returns the base file path (without req/resp suffix) for an interception dump. func interceptDumpPath(baseDir string, provider string, model string, interceptionID uuid.UUID, clk quartz.Clock) string { safeModel := strings.ReplaceAll(model, "/", "-") return filepath.Join(baseDir, provider, safeModel, fmt.Sprintf("%d-%s", clk.Now().UTC().UnixMilli(), interceptionID)) } -// passthroughDumpPath returns the base file path (without suffix) for a passthrough dump. -// A random UUID is generated for the filename. "passthrough" is used as the directory name -// in place of the model. -func passthroughDumpPath(baseDir string, provider string, clk quartz.Clock) string { - return filepath.Join(baseDir, provider, "passthrough", fmt.Sprintf("%d-%s", clk.Now().UTC().UnixMilli(), uuid.New())) +// passthroughDumpPath returns the base file path (without req/resp suffix) for a passthrough dump. +func passthroughDumpPath(baseDir string, provider string, urlPath string, clk quartz.Clock) string { + safeURLPath := strings.ReplaceAll(strings.TrimPrefix(urlPath, "/"), "/", "-") + return filepath.Join(baseDir, provider, "passthrough", fmt.Sprintf("%d-%s-%s", clk.Now().UTC().UnixMilli(), safeURLPath, uuid.NewString()[:4])) } -// NewRoundTripperMiddleware returns http.RoundTripper that dumps requests and responses to files. +// NewPassthroughMiddleware returns http.RoundTripper that dumps requests and responses to files. // If baseDir is empty, returns the original transport unchanged. -// Used for logging passed through requests. -func NewRoundTripperMiddleware(transport http.RoundTripper, baseDir string, provider string, logger slog.Logger, clk quartz.Clock) http.RoundTripper { +// Used for logging in pass through routes. +func NewPassthroughMiddleware(transport http.RoundTripper, baseDir string, provider string, logger slog.Logger, clk quartz.Clock) http.RoundTripper { if baseDir == "" { return transport } @@ -236,7 +235,7 @@ type dumpRoundTripper struct { func (rt *dumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { dumper := dumper{ - dumpPath: passthroughDumpPath(rt.baseDir, rt.provider, rt.clk), + dumpPath: passthroughDumpPath(rt.baseDir, rt.provider, req.URL.Path, rt.clk), logger: rt.logger, } @@ -269,9 +268,5 @@ func prettyPrintJSON(body []byte) []byte { result = pretty.Pretty(body) } - // Add trailing newline if missing. - if !bytes.HasSuffix(result, []byte("\n")) { - result = append(result, []byte("\n")...) - } return result } diff --git a/intercept/apidump/apidump_test.go b/intercept/apidump/apidump_test.go index 03de0ac..1fa3d7f 100644 --- a/intercept/apidump/apidump_test.go +++ b/intercept/apidump/apidump_test.go @@ -7,7 +7,7 @@ import ( "net/http" "os" "path/filepath" - "sort" + "strings" "testing" "cdr.dev/slog/v3" @@ -27,7 +27,7 @@ func findDumpFile(t *testing.T, dir, suffix string) string { return matches[0] } -func TestMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { +func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { t.Parallel() tmpDir := t.TempDir() @@ -35,7 +35,7 @@ func TestMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { clk := quartz.NewMock(t) interceptionID := uuid.New() - middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`))) @@ -84,7 +84,7 @@ func TestMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { require.Contains(t, content, "User-Agent: test-client") } -func TestMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) { +func TestBridgedMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) { t.Parallel() tmpDir := t.TempDir() @@ -92,7 +92,7 @@ func TestMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) { clk := quartz.NewMock(t) interceptionID := uuid.New() - middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) @@ -145,15 +145,15 @@ func TestMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) { require.Contains(t, content, "X-Request-Id: req-123") } -func TestMiddleware_EmptyBaseDir_ReturnsNil(t *testing.T) { +func TestBridgedMiddleware_EmptyBaseDir_ReturnsNil(t *testing.T) { t.Parallel() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - middleware := NewMiddleware("", "openai", "gpt-4", uuid.New(), logger, quartz.NewMock(t)) + middleware := NewBridgeMiddleware("", "openai", "gpt-4", uuid.New(), logger, quartz.NewMock(t)) require.Nil(t, middleware) } -func TestMiddleware_PreservesRequestBody(t *testing.T) { +func TestBridgedMiddleware_PreservesRequestBody(t *testing.T) { t.Parallel() tmpDir := t.TempDir() @@ -161,7 +161,7 @@ func TestMiddleware_PreservesRequestBody(t *testing.T) { clk := quartz.NewMock(t) interceptionID := uuid.New() - middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) originalBody := `{"messages": [{"role": "user", "content": "hello"}]}` @@ -186,7 +186,7 @@ func TestMiddleware_PreservesRequestBody(t *testing.T) { require.Equal(t, originalBody, string(capturedBody)) } -func TestMiddleware_ModelWithSlash(t *testing.T) { +func TestBridgedMiddleware_ModelWithSlash(t *testing.T) { t.Parallel() tmpDir := t.TempDir() @@ -195,7 +195,7 @@ func TestMiddleware_ModelWithSlash(t *testing.T) { interceptionID := uuid.New() // Model with slash should have it replaced with dash - middleware := NewMiddleware(tmpDir, "google", "gemini/1.5-pro", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "google", "gemini/1.5-pro", interceptionID, logger, clk) require.NotNil(t, middleware) req, err := http.NewRequest(http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`))) @@ -240,7 +240,7 @@ func TestPrettyPrintJSON(t *testing.T) { { name: "invalid JSON returns as-is", input: []byte("not json"), - expected: "not json\n", + expected: "not json", }, // see: https://github.com/tidwall/pretty/blob/9090695766b652478676cc3e55bc3187056b1ff0/pretty.go#L117 // for input starting with "t" it would change it to "true", eg. "t_rest_of_the_string_is_discarded" -> "true" @@ -248,12 +248,12 @@ func TestPrettyPrintJSON(t *testing.T) { { name: "invalid JSON edge case t", input: []byte("test"), - expected: "test\n", + expected: "test", }, { name: "invalid JSON edge case f", input: []byte("f"), - expected: "f\n", + expected: "f", }, } @@ -266,7 +266,7 @@ func TestPrettyPrintJSON(t *testing.T) { } } -func TestMiddleware_AllSensitiveRequestHeaders(t *testing.T) { +func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) { t.Parallel() tmpDir := t.TempDir() @@ -274,7 +274,7 @@ func TestMiddleware_AllSensitiveRequestHeaders(t *testing.T) { clk := quartz.NewMock(t) interceptionID := uuid.New() - middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) @@ -327,14 +327,14 @@ func TestMiddleware_AllSensitiveRequestHeaders(t *testing.T) { require.Contains(t, content, "X-Amz-Security-Token:") } -func TestRoundTripperMiddleware(t *testing.T) { +func TestPassthroughMiddleware(t *testing.T) { t.Parallel() t.Run("empty_base_dir_returns_original_transport", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) inner := http.DefaultTransport - rt := NewRoundTripperMiddleware(inner, "", "openai", logger, quartz.NewMock(t)) + rt := NewPassthroughMiddleware(inner, "", "openai", logger, quartz.NewMock(t)) require.Equal(t, inner, rt) }) @@ -352,7 +352,7 @@ func TestRoundTripperMiddleware(t *testing.T) { }, } - rt := NewRoundTripperMiddleware(inner, tmpDir, "openai", logger, clk) + rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk) req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/models", nil) require.NoError(t, err) @@ -396,9 +396,9 @@ func TestRoundTripperMiddleware(t *testing.T) { }, } - rt := NewRoundTripperMiddleware(inner, tmpDir, "openai", logger, clk) + rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk) - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/models", bytes.NewReader([]byte(req1Body))) + req, err := http.NewRequest(http.MethodPost, "/v1/models", bytes.NewReader([]byte(req1Body))) require.NoError(t, err) req.Header.Set("Authorization", "Bearer sk-secret-key-12345") resp, err := rt.RoundTrip(req) @@ -408,7 +408,7 @@ func TestRoundTripperMiddleware(t *testing.T) { require.NoError(t, resp.Body.Close()) // Second request should create new req/resp files - req2, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/models", bytes.NewReader([]byte(req2Body))) + req2, err := http.NewRequest(http.MethodPost, "/v1/conversations", bytes.NewReader([]byte(req2Body))) require.NoError(t, err) resp2, err := rt.RoundTrip(req2) require.NoError(t, err) @@ -418,45 +418,25 @@ func TestRoundTripperMiddleware(t *testing.T) { // Validate request files contents passthroughDir := filepath.Join(tmpDir, "openai", "passthrough") - reqPattern := filepath.Join(passthroughDir, "*"+SuffixRequest) - reqMatches, err := filepath.Glob(reqPattern) - require.NoError(t, err) - require.Len(t, reqMatches, 2, "expected exactly two %s files in %s", SuffixRequest, passthroughDir) - - reqContents := make([]string, 0, len(reqMatches)) - for _, reqDumpPath := range reqMatches { - reqContent, readErr := os.ReadFile(reqDumpPath) - require.NoError(t, readErr) - reqContents = append(reqContents, string(reqContent)) - } + req1Dump := readDumpFileContent(t, filepath.Join(passthroughDir, "*-v1-models-*"+SuffixRequest)) + req2Dump := readDumpFileContent(t, filepath.Join(passthroughDir, "*-v1-conversations-*"+SuffixRequest)) - sort.Strings(reqContents) - require.Contains(t, reqContents[0], req1Body+"\n") - require.Contains(t, reqContents[1], req2BodyPretty) + require.Contains(t, req1Dump, req1Body+"\n") + require.Contains(t, req2Dump, req2BodyPretty) // Sensitive header should be redacted - require.NotContains(t, reqContents[0], "sk-secret-key-12345") - require.NotContains(t, reqContents[1], "sk-secret-key-12345") - require.Contains(t, reqContents[0], "Authorization:") - require.NotContains(t, reqContents[1], "Authorization:") + require.NotContains(t, req1Dump, "sk-secret-key-12345") + require.NotContains(t, req2Dump, "sk-secret-key-12345") + require.Contains(t, req1Dump, "Authorization:") + require.NotContains(t, req2Dump, "Authorization:") // Validate response files contents - respPattern := filepath.Join(passthroughDir, "*"+SuffixResponse) - respMatches, err := filepath.Glob(respPattern) - require.NoError(t, err) - require.Len(t, respMatches, 2, "expected exactly two %s files in %s", SuffixResponse, passthroughDir) - - respContents := make([]string, 0, len(respMatches)) - for _, respDumpPath := range respMatches { - respContent, readErr := os.ReadFile(respDumpPath) - require.NoError(t, readErr) - respContents = append(respContents, string(respContent)) - } + resp1Dump := readDumpFileContent(t, filepath.Join(passthroughDir, "*-v1-models-*"+SuffixResponse)) + resp2Dump := readDumpFileContent(t, filepath.Join(passthroughDir, "*-v1-conversations-*"+SuffixResponse)) - sort.Strings(respContents) - require.Contains(t, respContents[0], "200 OK") - require.Contains(t, respContents[0], `{"call": 1}"`) - require.Contains(t, respContents[1], "200 OK") - require.Contains(t, respContents[1], `{"call": 2}"`) + require.Contains(t, resp1Dump, "200 OK") + require.Contains(t, resp1Dump, `{"call": 1}"`) + require.Contains(t, resp2Dump, "200 OK") + require.Contains(t, resp2Dump, `{"call": 2}"`) }) } @@ -467,3 +447,15 @@ type mockRoundTripper struct { func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return m.roundTrip(req) } + +// readDumpFileContent reads the content of the dump file matching the pattern. +// Expects exactly one file to match the pattern. +func readDumpFileContent(t *testing.T, pattern string) string { + t.Helper() + matches, err := filepath.Glob(pattern) + require.NoError(t, err) + require.Len(t, matches, 1, "expected exactly one match got: %v %s", len(matches), strings.Join(matches, ", "), pattern) + reqContent, readErr := os.ReadFile(matches[0]) + require.NoError(t, readErr) + return string(reqContent) +} diff --git a/intercept/apidump/streaming_test.go b/intercept/apidump/streaming_test.go index 3d76555..653a726 100644 --- a/intercept/apidump/streaming_test.go +++ b/intercept/apidump/streaming_test.go @@ -24,7 +24,7 @@ func TestMiddleware_StreamingResponse(t *testing.T) { clk := quartz.NewMock(t) interceptionID := uuid.New() - middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) @@ -100,7 +100,7 @@ func TestMiddleware_PreservesResponseBody(t *testing.T) { clk := quartz.NewMock(t) interceptionID := uuid.New() - middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index aed8f88..7a755e0 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -46,7 +46,7 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService } // Add API dump middleware if configured - if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 387591d..6a5f512 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -182,7 +182,7 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) // Add API dump middleware if configured - if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, aibconfig.ProviderAnthropic, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, aibconfig.ProviderAnthropic, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } diff --git a/intercept/responses/base.go b/intercept/responses/base.go index b531a71..8b7c3de 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -59,7 +59,7 @@ func (i *responsesInterceptionBase) newResponsesService() responses.ResponseServ } // Add API dump middleware if configured - if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } diff --git a/passthrough.go b/passthrough.go index 66da45a..c6b59ed 100644 --- a/passthrough.go +++ b/passthrough.go @@ -112,7 +112,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } - proxy.Transport = apidump.NewRoundTripperMiddleware(t, provider.APIDumpDir(), provider.Name(), logger, quartz.NewReal()) + proxy.Transport = apidump.NewPassthroughMiddleware(t, provider.APIDumpDir(), provider.Name(), logger, quartz.NewReal()) proxy.ServeHTTP(w, r) }