diff --git a/apidump_integration_test.go b/apidump_integration_test.go index 29db3f2..fd17ead 100644 --- a/apidump_integration_test.go +++ b/apidump_integration_test.go @@ -49,32 +49,28 @@ func TestAPIDump(t *testing.T) { cases := []struct { name string fixture []byte - providerName string providersFunc func(addr, dumpDir string) []aibridge.Provider createRequestFunc createRequestFunc }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntSimple, - providerName: config.ProviderAnthropic, + name: "anthropic", + fixture: fixtures.AntSimple, providersFunc: func(addr, dumpDir string) []aibridge.Provider { return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)} }, createRequestFunc: createAnthropicMessagesReq, }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatSimple, - providerName: config.ProviderOpenAI, + name: "openai_chat_completions", + fixture: fixtures.OaiChatSimple, providersFunc: func(addr, dumpDir string) []aibridge.Provider { return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} }, createRequestFunc: createOpenAIChatCompletionsReq, }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiResponsesBlockingSimple, - providerName: config.ProviderOpenAI, + name: "openai_responses", + fixture: fixtures.OaiResponsesBlockingSimple, providersFunc: func(addr, dumpDir string) []aibridge.Provider { return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} }, @@ -176,3 +172,126 @@ func TestAPIDump(t *testing.T) { }) } } + +func TestAPIDumpPassthrough(t *testing.T) { + t.Parallel() + + const responseBody = `{"object":"list","data":[{"id":"gpt-4","object":"model"}]}` + + cases := []struct { + name string + providerFunc func(addr string, dumpDir string) aibridge.Provider + requestPath string + expectDumpName string + }{ + { + name: "anthropic", + providerFunc: func(addr string, dumpDir string) aibridge.Provider { + return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil) + }, + requestPath: "/anthropic/v1/models", + expectDumpName: "-v1-models-", + }, + { + name: "openai", + providerFunc: func(addr string, dumpDir string) aibridge.Provider { + return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) + }, + requestPath: "/openai/v1/models", + expectDumpName: "-models-", + }, + { + name: "copilot", + providerFunc: func(addr string, dumpDir string) aibridge.Provider { + return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir}) + }, + requestPath: "/copilot/models", + expectDumpName: "-models-", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(responseBody)) + })) + t.Cleanup(upstream.Close) + + dumpDir := t.TempDir() + + recorderClient := &testutil.MockRecorder{} + prov := tc.providerFunc(upstream.URL, dumpDir) + provs := []aibridge.Provider{prov} + b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + require.NoError(t, err) + + bridgeSrv := httptest.NewUnstartedServer(b) + t.Cleanup(bridgeSrv.Close) + bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibcontext.AsActor(ctx, userID, nil) + } + bridgeSrv.Start() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, bridgeSrv.URL+tc.requestPath, nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Find dump files in the passthrough directory. + passthroughDir := filepath.Join(dumpDir, tc.name, "passthrough") + var reqDumpFile, respDumpFile string + err = filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + if strings.HasSuffix(path, apidump.SuffixRequest) { + reqDumpFile = path + } else if strings.HasSuffix(path, apidump.SuffixResponse) { + respDumpFile = path + } + return nil + }) + require.NoError(t, err, "walking failed: %v", err) + + require.NotEmpty(t, reqDumpFile, "request dump file should exist") + require.FileExists(t, reqDumpFile) + require.Contains(t, reqDumpFile, "/passthrough/") + require.Contains(t, reqDumpFile, tc.expectDumpName) + + require.NotEmpty(t, respDumpFile, "response dump file should exist") + require.FileExists(t, respDumpFile) + require.Contains(t, respDumpFile, "/passthrough/") + require.Contains(t, respDumpFile, tc.expectDumpName) + + // Verify request dump. + reqDumpData, err := os.ReadFile(reqDumpFile) + require.NoError(t, err) + dumpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqDumpData))) + require.NoError(t, err) + require.Equal(t, http.MethodGet, dumpReq.Method) + + // Verify response dump. + respDumpData, err := os.ReadFile(respDumpFile) + require.NoError(t, err) + dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil) + require.NoError(t, err) + require.Equal(t, http.StatusOK, dumpResp.StatusCode) + dumpRespBody, err := io.ReadAll(dumpResp.Body) + require.NoError(t, err) + require.JSONEq(t, responseBody, string(dumpRespBody)) + }) + } +} diff --git a/intercept/apidump/apidump.go b/intercept/apidump/apidump.go index 2e8355f..e8e6d89 100644 --- a/intercept/apidump/apidump.go +++ b/intercept/apidump/apidump.go @@ -32,26 +32,21 @@ type MiddlewareNext = func(*http.Request) (*http.Response, error) // Middleware is an HTTP middleware function compatible with SDK WithMiddleware options. type Middleware = func(*http.Request, MiddlewareNext) (*http.Response, error) -// NewMiddleware returns a middleware function that dumps requests and responses to files. -// Files are written to the path returned by DumpPath. +// NewBridgeMiddleware returns a middleware function that dumps requests and responses to files. // If baseDir is empty, returns nil (no middleware). -func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, logger slog.Logger, clk quartz.Clock) Middleware { +func NewBridgeMiddleware(baseDir string, provider string, model string, interceptionID uuid.UUID, logger slog.Logger, clk quartz.Clock) Middleware { if baseDir == "" { return nil } d := &dumper{ - baseDir: baseDir, - provider: provider, - model: model, - interceptionID: interceptionID, - clk: clk, - logger: logger, + dumpPath: interceptDumpPath(baseDir, provider, model, interceptionID, clk), + logger: logger, } return func(req *http.Request, next MiddlewareNext) (*http.Response, error) { if err := d.dumpRequest(req); err != nil { - logger.Named("apidump").Warn(context.Background(), "failed to dump request", slog.Error(err)) + logger.Named("apidump").Warn(req.Context(), "failed to dump request", slog.Error(err)) } // TODO: https://github.com/coder/aibridge/issues/129 @@ -61,7 +56,7 @@ func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, lo } if err := d.dumpResponse(resp); err != nil { - logger.Named("apidump").Warn(context.Background(), "failed to dump response", slog.Error(err)) + logger.Named("apidump").Warn(req.Context(), "failed to dump response", slog.Error(err)) } return resp, nil @@ -69,16 +64,12 @@ func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, lo } type dumper struct { - baseDir string - provider string - model string - interceptionID uuid.UUID - clk quartz.Clock - logger slog.Logger + dumpPath string + logger slog.Logger } func (d *dumper) dumpRequest(req *http.Request) error { - dumpPath := d.path(SuffixRequest) + dumpPath := d.dumpPath + SuffixRequest if err := os.MkdirAll(filepath.Dir(dumpPath), 0o755); err != nil { return fmt.Errorf("create dump dir: %w", err) } @@ -98,25 +89,44 @@ func (d *dumper) dumpRequest(req *http.Request) error { // Build raw HTTP request format var buf bytes.Buffer - fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto) - d.writeRedactedHeaders(&buf, req.Header, sensitiveRequestHeaders, map[string]string{ + _, err := fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto) + if err != nil { + return fmt.Errorf("write request uri: %w", err) + } + err = d.writeRedactedHeaders(&buf, req.Header, sensitiveRequestHeaders, map[string]string{ "Content-Length": fmt.Sprintf("%d", len(prettyBody)), }) + if err != nil { + return fmt.Errorf("write request headers: %w", err) + } - fmt.Fprintf(&buf, "\r\n") + _, err = fmt.Fprintf(&buf, "\r\n") + if err != nil { + return fmt.Errorf("write request header terminator: %w", err) + } buf.Write(prettyBody) + buf.WriteByte('\n') return os.WriteFile(dumpPath, buf.Bytes(), 0o644) } func (d *dumper) dumpResponse(resp *http.Response) error { - dumpPath := d.path(SuffixResponse) + dumpPath := d.dumpPath + SuffixResponse // Build raw HTTP response headers var headerBuf bytes.Buffer - fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status) - d.writeRedactedHeaders(&headerBuf, resp.Header, sensitiveResponseHeaders, nil) - fmt.Fprintf(&headerBuf, "\r\n") + _, err := fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status) + if err != nil { + return fmt.Errorf("write response status: %w", err) + } + err = d.writeRedactedHeaders(&headerBuf, resp.Header, sensitiveResponseHeaders, nil) + if err != nil { + return fmt.Errorf("write response headers: %w", err) + } + _, err = fmt.Fprintf(&headerBuf, "\r\n") + if err != nil { + return fmt.Errorf("write response header terminator: %w", err) + } // Wrap the response body to capture it as it streams if resp.Body != nil { @@ -141,7 +151,7 @@ func (d *dumper) dumpResponse(resp *http.Response) error { // for deterministic output. // `sensitive` and `overrides` must both supply keys in canoncialized form. // See [textproto.MIMEHeader]. -func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) { +func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) error { // Collect all header keys including overrides. headerKeys := make([]string, 0, len(headers)+len(overrides)) seen := make(map[string]struct{}, len(headers)+len(overrides)) @@ -163,7 +173,10 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv // If no values exist but we have an override, use that. if len(values) == 0 { if override, ok := overrides[key]; ok { - fmt.Fprintf(w, "%s: %s\r\n", key, override) + _, err := fmt.Fprintf(w, "%s: %s\r\n", key, override) + if err != nil { + return fmt.Errorf("write response header override: %w", err) + } } continue } @@ -175,16 +188,71 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv if isSensitive { value = redactHeaderValue(value) } - fmt.Fprintf(w, "%s: %s\r\n", key, value) + _, err := fmt.Fprintf(w, "%s: %s\r\n", key, value) + if err != nil { + return fmt.Errorf("write response headers: %w", err) + } } } + return nil +} + +// interceptDumpPath returns the base file path (without req/resp suffix) for an interception dump. +func interceptDumpPath(baseDir string, provider string, model string, interceptionID uuid.UUID, clk quartz.Clock) string { + safeModel := strings.ReplaceAll(model, "/", "-") + return filepath.Join(baseDir, provider, safeModel, fmt.Sprintf("%d-%s", clk.Now().UTC().UnixMilli(), interceptionID)) +} + +// passthroughDumpPath returns the base file path (without req/resp suffix) for a passthrough dump. +func passthroughDumpPath(baseDir string, provider string, urlPath string, clk quartz.Clock) string { + safeURLPath := strings.ReplaceAll(strings.TrimPrefix(urlPath, "/"), "/", "-") + return filepath.Join(baseDir, provider, "passthrough", fmt.Sprintf("%d-%s-%s", clk.Now().UTC().UnixMilli(), safeURLPath, uuid.NewString()[:4])) +} + +// NewPassthroughMiddleware returns http.RoundTripper that dumps requests and responses to files. +// If baseDir is empty, returns the original transport unchanged. +// Used for logging in pass through routes. +func NewPassthroughMiddleware(transport http.RoundTripper, baseDir string, provider string, logger slog.Logger, clk quartz.Clock) http.RoundTripper { + if baseDir == "" { + return transport + } + return &dumpRoundTripper{ + inner: transport, + baseDir: baseDir, + provider: provider, + clk: clk, + logger: logger, + } } -// path returns the path to a request/response dump file for a given interception. -// suffix should be SuffixRequest or SuffixResponse. -func (d *dumper) path(suffix string) string { - safeModel := strings.ReplaceAll(d.model, "/", "-") - return filepath.Join(d.baseDir, d.provider, safeModel, fmt.Sprintf("%d-%s%s", d.clk.Now().UTC().UnixMilli(), d.interceptionID, suffix)) +type dumpRoundTripper struct { + inner http.RoundTripper + baseDir string + provider string + clk quartz.Clock + logger slog.Logger +} + +func (rt *dumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + dumper := dumper{ + dumpPath: passthroughDumpPath(rt.baseDir, rt.provider, req.URL.Path, rt.clk), + logger: rt.logger, + } + + if err := dumper.dumpRequest(req); err != nil { + dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough request", slog.Error(err)) + } + + resp, err := rt.inner.RoundTrip(req) + if err != nil { + return resp, err + } + + if err := dumper.dumpResponse(resp); err != nil { + dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough response", slog.Error(err)) + } + + return resp, nil } // prettyPrintJSON returns indented JSON if body is valid JSON, otherwise returns body as-is. @@ -194,12 +262,11 @@ func prettyPrintJSON(body []byte) []byte { if len(body) == 0 { return body } - result := pretty.Pretty(body) - // pretty.Pretty returns a truncated/modified result for invalid JSON, - // so check if the result is valid JSON; if not, return the original. - if !json.Valid(result) { - return body + + result := body + if json.Valid(body) { + result = pretty.Pretty(body) } - // Trim trailing newline added by pretty.Pretty. - return bytes.TrimSuffix(result, []byte("\n")) + + return result } diff --git a/intercept/apidump/apidump_test.go b/intercept/apidump/apidump_test.go index a7b4c5b..1fa3d7f 100644 --- a/intercept/apidump/apidump_test.go +++ b/intercept/apidump/apidump_test.go @@ -2,10 +2,12 @@ package apidump import ( "bytes" + "fmt" "io" "net/http" "os" "path/filepath" + "strings" "testing" "cdr.dev/slog/v3" @@ -25,7 +27,7 @@ func findDumpFile(t *testing.T, dir, suffix string) string { return matches[0] } -func TestMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { +func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { t.Parallel() tmpDir := t.TempDir() @@ -33,7 +35,7 @@ func TestMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { clk := quartz.NewMock(t) interceptionID := uuid.New() - middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`))) @@ -82,7 +84,7 @@ func TestMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { require.Contains(t, content, "User-Agent: test-client") } -func TestMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) { +func TestBridgedMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) { t.Parallel() tmpDir := t.TempDir() @@ -90,7 +92,7 @@ func TestMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) { clk := quartz.NewMock(t) interceptionID := uuid.New() - middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) @@ -143,15 +145,15 @@ func TestMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) { require.Contains(t, content, "X-Request-Id: req-123") } -func TestMiddleware_EmptyBaseDir_ReturnsNil(t *testing.T) { +func TestBridgedMiddleware_EmptyBaseDir_ReturnsNil(t *testing.T) { t.Parallel() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - middleware := NewMiddleware("", "openai", "gpt-4", uuid.New(), logger, quartz.NewMock(t)) + middleware := NewBridgeMiddleware("", "openai", "gpt-4", uuid.New(), logger, quartz.NewMock(t)) require.Nil(t, middleware) } -func TestMiddleware_PreservesRequestBody(t *testing.T) { +func TestBridgedMiddleware_PreservesRequestBody(t *testing.T) { t.Parallel() tmpDir := t.TempDir() @@ -159,7 +161,7 @@ func TestMiddleware_PreservesRequestBody(t *testing.T) { clk := quartz.NewMock(t) interceptionID := uuid.New() - middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) originalBody := `{"messages": [{"role": "user", "content": "hello"}]}` @@ -184,7 +186,7 @@ func TestMiddleware_PreservesRequestBody(t *testing.T) { require.Equal(t, originalBody, string(capturedBody)) } -func TestMiddleware_ModelWithSlash(t *testing.T) { +func TestBridgedMiddleware_ModelWithSlash(t *testing.T) { t.Parallel() tmpDir := t.TempDir() @@ -193,7 +195,7 @@ func TestMiddleware_ModelWithSlash(t *testing.T) { interceptionID := uuid.New() // Model with slash should have it replaced with dash - middleware := NewMiddleware(tmpDir, "google", "gemini/1.5-pro", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "google", "gemini/1.5-pro", interceptionID, logger, clk) require.NotNil(t, middleware) req, err := http.NewRequest(http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`))) @@ -233,13 +235,26 @@ func TestPrettyPrintJSON(t *testing.T) { { name: "valid JSON", input: []byte(`{"key":"value"}`), - expected: "{\n \"key\": \"value\"\n}", + expected: "{\n \"key\": \"value\"\n}\n", }, { name: "invalid JSON returns as-is", input: []byte("not json"), expected: "not json", }, + // see: https://github.com/tidwall/pretty/blob/9090695766b652478676cc3e55bc3187056b1ff0/pretty.go#L117 + // for input starting with "t" it would change it to "true", eg. "t_rest_of_the_string_is_discarded" -> "true" + // similar for inputs startrting with "f" and "n" + { + name: "invalid JSON edge case t", + input: []byte("test"), + expected: "test", + }, + { + name: "invalid JSON edge case f", + input: []byte("f"), + expected: "f", + }, } for _, tc := range tests { @@ -251,7 +266,7 @@ func TestPrettyPrintJSON(t *testing.T) { } } -func TestMiddleware_AllSensitiveRequestHeaders(t *testing.T) { +func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) { t.Parallel() tmpDir := t.TempDir() @@ -259,7 +274,7 @@ func TestMiddleware_AllSensitiveRequestHeaders(t *testing.T) { clk := quartz.NewMock(t) interceptionID := uuid.New() - middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) @@ -311,3 +326,136 @@ func TestMiddleware_AllSensitiveRequestHeaders(t *testing.T) { require.Contains(t, content, "Proxy-Authorization:") require.Contains(t, content, "X-Amz-Security-Token:") } + +func TestPassthroughMiddleware(t *testing.T) { + t.Parallel() + + t.Run("empty_base_dir_returns_original_transport", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + inner := http.DefaultTransport + rt := NewPassthroughMiddleware(inner, "", "openai", logger, quartz.NewMock(t)) + require.Equal(t, inner, rt) + }) + + t.Run("returns_error_from_inner_round_trip", func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + clk := quartz.NewMock(t) + + innerErr := io.ErrUnexpectedEOF + inner := &mockRoundTripper{ + roundTrip: func(_ *http.Request) (*http.Response, error) { + return nil, innerErr + }, + } + + rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk) + + req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/models", nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.ErrorIs(t, err, innerErr) + require.Nil(t, resp) + }) + + t.Run("dumps_request_and_response", func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + clk := quartz.NewMock(t) + + req1Body := `first request` + req2Body := `{"request": 2}` + req2BodyPretty := "{\n \"request\": 2\n}\n" + + callCount := 0 + inner := &mockRoundTripper{ + roundTrip: func(req *http.Request) (*http.Response, error) { + // Verify body is still readable after dump + body, err := io.ReadAll(req.Body) + require.NoError(t, err) + callCount++ + if callCount == 1 { + require.Equal(t, req1Body, string(body)) + } else { + require.Equal(t, req2Body, string(body)) + } + + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Proto: "HTTP/1.1", + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(`{"call": %d}"`, callCount)))), + }, nil + }, + } + + rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk) + + req, err := http.NewRequest(http.MethodPost, "/v1/models", bytes.NewReader([]byte(req1Body))) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer sk-secret-key-12345") + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + // Second request should create new req/resp files + req2, err := http.NewRequest(http.MethodPost, "/v1/conversations", bytes.NewReader([]byte(req2Body))) + require.NoError(t, err) + resp2, err := rt.RoundTrip(req2) + require.NoError(t, err) + _, err = io.ReadAll(resp2.Body) + require.NoError(t, err) + require.NoError(t, resp2.Body.Close()) + + // Validate request files contents + passthroughDir := filepath.Join(tmpDir, "openai", "passthrough") + req1Dump := readDumpFileContent(t, filepath.Join(passthroughDir, "*-v1-models-*"+SuffixRequest)) + req2Dump := readDumpFileContent(t, filepath.Join(passthroughDir, "*-v1-conversations-*"+SuffixRequest)) + + require.Contains(t, req1Dump, req1Body+"\n") + require.Contains(t, req2Dump, req2BodyPretty) + // Sensitive header should be redacted + require.NotContains(t, req1Dump, "sk-secret-key-12345") + require.NotContains(t, req2Dump, "sk-secret-key-12345") + require.Contains(t, req1Dump, "Authorization:") + require.NotContains(t, req2Dump, "Authorization:") + + // Validate response files contents + resp1Dump := readDumpFileContent(t, filepath.Join(passthroughDir, "*-v1-models-*"+SuffixResponse)) + resp2Dump := readDumpFileContent(t, filepath.Join(passthroughDir, "*-v1-conversations-*"+SuffixResponse)) + + require.Contains(t, resp1Dump, "200 OK") + require.Contains(t, resp1Dump, `{"call": 1}"`) + require.Contains(t, resp2Dump, "200 OK") + require.Contains(t, resp2Dump, `{"call": 2}"`) + }) +} + +type mockRoundTripper struct { + roundTrip func(*http.Request) (*http.Response, error) +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return m.roundTrip(req) +} + +// readDumpFileContent reads the content of the dump file matching the pattern. +// Expects exactly one file to match the pattern. +func readDumpFileContent(t *testing.T, pattern string) string { + t.Helper() + matches, err := filepath.Glob(pattern) + require.NoError(t, err) + require.Len(t, matches, 1, "expected exactly one match got: %v %s", len(matches), strings.Join(matches, ", "), pattern) + reqContent, readErr := os.ReadFile(matches[0]) + require.NoError(t, readErr) + return string(reqContent) +} diff --git a/intercept/apidump/headers_test.go b/intercept/apidump/headers_test.go index c9286ab..181eae2 100644 --- a/intercept/apidump/headers_test.go +++ b/intercept/apidump/headers_test.go @@ -101,12 +101,8 @@ func TestWriteRedactedHeaders(t *testing.T) { t.Parallel() d := &dumper{ - baseDir: "/tmp", - provider: "test", - model: "test", - interceptionID: uuid.New(), - clk: quartz.NewMock(t), - logger: slog.Make(), + dumpPath: interceptDumpPath("/tmp", "test", "test", uuid.New(), quartz.NewMock(t)), + logger: slog.Make(), } tests := []struct { diff --git a/intercept/apidump/streaming_test.go b/intercept/apidump/streaming_test.go index 3d76555..653a726 100644 --- a/intercept/apidump/streaming_test.go +++ b/intercept/apidump/streaming_test.go @@ -24,7 +24,7 @@ func TestMiddleware_StreamingResponse(t *testing.T) { clk := quartz.NewMock(t) interceptionID := uuid.New() - middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) @@ -100,7 +100,7 @@ func TestMiddleware_PreservesResponseBody(t *testing.T) { clk := quartz.NewMock(t) interceptionID := uuid.New() - middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index aed8f88..7a755e0 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -46,7 +46,7 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService } // Add API dump middleware if configured - if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 387591d..6a5f512 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -182,7 +182,7 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) // Add API dump middleware if configured - if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, aibconfig.ProviderAnthropic, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, aibconfig.ProviderAnthropic, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } diff --git a/intercept/responses/base.go b/intercept/responses/base.go index b531a71..8b7c3de 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -59,7 +59,7 @@ func (i *responsesInterceptionBase) newResponsesService() responses.ResponseServ } // Add API dump middleware if configured - if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } diff --git a/internal/testutil/mockprovider.go b/internal/testutil/mockprovider.go index 6e268c6..a21ac6a 100644 --- a/internal/testutil/mockprovider.go +++ b/internal/testutil/mockprovider.go @@ -25,6 +25,7 @@ func (m *MockProvider) PassthroughRoutes() []string { return m. func (m *MockProvider) AuthHeader() string { return "Authorization" } func (m *MockProvider) InjectAuthHeader(h *http.Header) {} func (m *MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil } +func (m *MockProvider) APIDumpDir() string { return "" } func (m *MockProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) { if m.InterceptorFunc != nil { return m.InterceptorFunc(w, r, tracer) diff --git a/passthrough.go b/passthrough.go index 0dcef9c..c6b59ed 100644 --- a/passthrough.go +++ b/passthrough.go @@ -8,9 +8,11 @@ import ( "time" "cdr.dev/slog/v3" + "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/tracing" + "github.com/coder/quartz" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" @@ -102,7 +104,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met } // Transport tuned for streaming (no response header timeout). - proxy.Transport = &http.Transport{ + t := &http.Transport{ Proxy: http.ProxyFromEnvironment, ForceAttemptHTTP2: true, MaxIdleConns: 100, @@ -110,6 +112,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } + proxy.Transport = apidump.NewPassthroughMiddleware(t, provider.APIDumpDir(), provider.Name(), logger, quartz.NewReal()) proxy.ServeHTTP(w, r) } diff --git a/provider/anthropic.go b/provider/anthropic.go index be12583..5195cec 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -133,3 +133,7 @@ func (p *Anthropic) InjectAuthHeader(headers *http.Header) { func (p *Anthropic) CircuitBreakerConfig() *config.CircuitBreaker { return p.cfg.CircuitBreaker } + +func (p *Anthropic) APIDumpDir() string { + return p.cfg.APIDumpDir +} diff --git a/provider/copilot.go b/provider/copilot.go index 34fab49..9b128ca 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -109,6 +109,10 @@ func (p *Copilot) CircuitBreakerConfig() *config.CircuitBreaker { return p.circuitBreaker } +func (p *Copilot) APIDumpDir() string { + return p.cfg.APIDumpDir +} + func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) { _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") defer tracing.EndSpanErr(span, &outErr) diff --git a/provider/openai.go b/provider/openai.go index 9f0bf70..730fc68 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -152,3 +152,7 @@ func (p *OpenAI) InjectAuthHeader(headers *http.Header) { func (p *OpenAI) CircuitBreakerConfig() *config.CircuitBreaker { return p.circuitBreaker } + +func (p *OpenAI) APIDumpDir() string { + return p.cfg.APIDumpDir +} diff --git a/provider/provider.go b/provider/provider.go index 0dbe352..f2a70f1 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -75,4 +75,8 @@ type Provider interface { // CircuitBreakerConfig returns the circuit breaker configuration for the provider. CircuitBreakerConfig() *config.CircuitBreaker + + // APIDumpDir returns the directory path for dumping API requests and responses. + // Empty string is returned when API dumping is not enabled. + APIDumpDir() string }