Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 129 additions & 10 deletions apidump_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,32 +49,28 @@ func TestAPIDump(t *testing.T) {
cases := []struct {
name string
fixture []byte
providerName string
providersFunc func(addr, dumpDir string) []aibridge.Provider
createRequestFunc createRequestFunc
}{
{
name: config.ProviderAnthropic,
fixture: fixtures.AntSimple,
providerName: config.ProviderAnthropic,
name: "anthropic",
fixture: fixtures.AntSimple,
providersFunc: func(addr, dumpDir string) []aibridge.Provider {
return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)}
},
createRequestFunc: createAnthropicMessagesReq,
},
{
name: config.ProviderOpenAI,
fixture: fixtures.OaiChatSimple,
providerName: config.ProviderOpenAI,
name: "openai_chat_completions",
fixture: fixtures.OaiChatSimple,
providersFunc: func(addr, dumpDir string) []aibridge.Provider {
return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))}
},
createRequestFunc: createOpenAIChatCompletionsReq,
},
{
name: config.ProviderOpenAI,
fixture: fixtures.OaiResponsesBlockingSimple,
providerName: config.ProviderOpenAI,
name: "openai_responses",
fixture: fixtures.OaiResponsesBlockingSimple,
providersFunc: func(addr, dumpDir string) []aibridge.Provider {
return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))}
},
Expand Down Expand Up @@ -176,3 +172,126 @@ func TestAPIDump(t *testing.T) {
})
}
}

func TestAPIDumpPassthrough(t *testing.T) {
t.Parallel()

const responseBody = `{"object":"list","data":[{"id":"gpt-4","object":"model"}]}`

cases := []struct {
name string
providerFunc func(addr string, dumpDir string) aibridge.Provider
requestPath string
expectDumpName string
}{
{
name: "anthropic",
providerFunc: func(addr string, dumpDir string) aibridge.Provider {
return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)
},
requestPath: "/anthropic/v1/models",
expectDumpName: "-v1-models-",
},
{
name: "openai",
providerFunc: func(addr string, dumpDir string) aibridge.Provider {
return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))
},
requestPath: "/openai/v1/models",
expectDumpName: "-models-",
},
{
name: "copilot",
providerFunc: func(addr string, dumpDir string) aibridge.Provider {
return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir})
},
requestPath: "/copilot/models",
expectDumpName: "-models-",
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(responseBody))
}))
t.Cleanup(upstream.Close)

dumpDir := t.TempDir()

recorderClient := &testutil.MockRecorder{}
prov := tc.providerFunc(upstream.URL, dumpDir)
provs := []aibridge.Provider{prov}
b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
require.NoError(t, err)

bridgeSrv := httptest.NewUnstartedServer(b)
t.Cleanup(bridgeSrv.Close)
bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
return aibcontext.AsActor(ctx, userID, nil)
}
bridgeSrv.Start()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, bridgeSrv.URL+tc.requestPath, nil)
require.NoError(t, err)

resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Find dump files in the passthrough directory.
passthroughDir := filepath.Join(dumpDir, tc.name, "passthrough")
var reqDumpFile, respDumpFile string
err = filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
if strings.HasSuffix(path, apidump.SuffixRequest) {
reqDumpFile = path
} else if strings.HasSuffix(path, apidump.SuffixResponse) {
respDumpFile = path
}
return nil
})
require.NoError(t, err, "walking failed: %v", err)

require.NotEmpty(t, reqDumpFile, "request dump file should exist")
require.FileExists(t, reqDumpFile)
require.Contains(t, reqDumpFile, "/passthrough/")
require.Contains(t, reqDumpFile, tc.expectDumpName)

require.NotEmpty(t, respDumpFile, "response dump file should exist")
require.FileExists(t, respDumpFile)
require.Contains(t, respDumpFile, "/passthrough/")
require.Contains(t, respDumpFile, tc.expectDumpName)

// Verify request dump.
reqDumpData, err := os.ReadFile(reqDumpFile)
require.NoError(t, err)
dumpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqDumpData)))
require.NoError(t, err)
require.Equal(t, http.MethodGet, dumpReq.Method)

// Verify response dump.
respDumpData, err := os.ReadFile(respDumpFile)
require.NoError(t, err)
dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil)
require.NoError(t, err)
require.Equal(t, http.StatusOK, dumpResp.StatusCode)
dumpRespBody, err := io.ReadAll(dumpResp.Body)
require.NoError(t, err)
require.JSONEq(t, responseBody, string(dumpRespBody))
})
}
}
147 changes: 107 additions & 40 deletions intercept/apidump/apidump.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,21 @@ type MiddlewareNext = func(*http.Request) (*http.Response, error)
// Middleware is an HTTP middleware function compatible with SDK WithMiddleware options.
type Middleware = func(*http.Request, MiddlewareNext) (*http.Response, error)

// NewMiddleware returns a middleware function that dumps requests and responses to files.
// Files are written to the path returned by DumpPath.
// NewBridgeMiddleware returns a middleware function that dumps requests and responses to files.
// If baseDir is empty, returns nil (no middleware).
func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, logger slog.Logger, clk quartz.Clock) Middleware {
func NewBridgeMiddleware(baseDir string, provider string, model string, interceptionID uuid.UUID, logger slog.Logger, clk quartz.Clock) Middleware {
if baseDir == "" {
return nil
}

d := &dumper{
baseDir: baseDir,
provider: provider,
model: model,
interceptionID: interceptionID,
clk: clk,
logger: logger,
dumpPath: interceptDumpPath(baseDir, provider, model, interceptionID, clk),
logger: logger,
}

return func(req *http.Request, next MiddlewareNext) (*http.Response, error) {
if err := d.dumpRequest(req); err != nil {
logger.Named("apidump").Warn(context.Background(), "failed to dump request", slog.Error(err))
logger.Named("apidump").Warn(req.Context(), "failed to dump request", slog.Error(err))
}

// TODO: https://github.com/coder/aibridge/issues/129
Expand All @@ -61,24 +56,20 @@ func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, lo
}

if err := d.dumpResponse(resp); err != nil {
logger.Named("apidump").Warn(context.Background(), "failed to dump response", slog.Error(err))
logger.Named("apidump").Warn(req.Context(), "failed to dump response", slog.Error(err))
}

return resp, nil
}
}

type dumper struct {
baseDir string
provider string
model string
interceptionID uuid.UUID
clk quartz.Clock
logger slog.Logger
dumpPath string
logger slog.Logger
}

func (d *dumper) dumpRequest(req *http.Request) error {
dumpPath := d.path(SuffixRequest)
dumpPath := d.dumpPath + SuffixRequest
if err := os.MkdirAll(filepath.Dir(dumpPath), 0o755); err != nil {
return fmt.Errorf("create dump dir: %w", err)
}
Expand All @@ -98,25 +89,44 @@ func (d *dumper) dumpRequest(req *http.Request) error {

// Build raw HTTP request format
var buf bytes.Buffer
fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto)
d.writeRedactedHeaders(&buf, req.Header, sensitiveRequestHeaders, map[string]string{
_, err := fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto)
if err != nil {
return fmt.Errorf("write request uri: %w", err)
}
err = d.writeRedactedHeaders(&buf, req.Header, sensitiveRequestHeaders, map[string]string{
"Content-Length": fmt.Sprintf("%d", len(prettyBody)),
})
if err != nil {
return fmt.Errorf("write request headers: %w", err)
}

fmt.Fprintf(&buf, "\r\n")
_, err = fmt.Fprintf(&buf, "\r\n")
if err != nil {
return fmt.Errorf("write request header terminator: %w", err)
}
buf.Write(prettyBody)
buf.WriteByte('\n')

return os.WriteFile(dumpPath, buf.Bytes(), 0o644)
}

func (d *dumper) dumpResponse(resp *http.Response) error {
dumpPath := d.path(SuffixResponse)
dumpPath := d.dumpPath + SuffixResponse

// Build raw HTTP response headers
var headerBuf bytes.Buffer
fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status)
d.writeRedactedHeaders(&headerBuf, resp.Header, sensitiveResponseHeaders, nil)
fmt.Fprintf(&headerBuf, "\r\n")
_, err := fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status)
if err != nil {
return fmt.Errorf("write response status: %w", err)
}
err = d.writeRedactedHeaders(&headerBuf, resp.Header, sensitiveResponseHeaders, nil)
if err != nil {
return fmt.Errorf("write response headers: %w", err)
}
_, err = fmt.Fprintf(&headerBuf, "\r\n")
if err != nil {
return fmt.Errorf("write response header terminator: %w", err)
}

// Wrap the response body to capture it as it streams
if resp.Body != nil {
Expand All @@ -141,7 +151,7 @@ func (d *dumper) dumpResponse(resp *http.Response) error {
// for deterministic output.
// `sensitive` and `overrides` must both supply keys in canoncialized form.
// See [textproto.MIMEHeader].
func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) {
func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) error {
// Collect all header keys including overrides.
headerKeys := make([]string, 0, len(headers)+len(overrides))
seen := make(map[string]struct{}, len(headers)+len(overrides))
Expand All @@ -163,7 +173,10 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv
// If no values exist but we have an override, use that.
if len(values) == 0 {
if override, ok := overrides[key]; ok {
fmt.Fprintf(w, "%s: %s\r\n", key, override)
_, err := fmt.Fprintf(w, "%s: %s\r\n", key, override)
if err != nil {
return fmt.Errorf("write response header override: %w", err)
Copy link
Contributor

@ssncferreira ssncferreira Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is used for both request/responses, right? Consider removing response from here, to avoid confusion. Same below 👀

}
}
continue
}
Expand All @@ -175,16 +188,71 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv
if isSensitive {
value = redactHeaderValue(value)
}
fmt.Fprintf(w, "%s: %s\r\n", key, value)
_, err := fmt.Fprintf(w, "%s: %s\r\n", key, value)
if err != nil {
return fmt.Errorf("write response headers: %w", err)
}
}
}
return nil
}

// interceptDumpPath returns the base file path (without req/resp suffix) for an interception dump.
func interceptDumpPath(baseDir string, provider string, model string, interceptionID uuid.UUID, clk quartz.Clock) string {
safeModel := strings.ReplaceAll(model, "/", "-")
return filepath.Join(baseDir, provider, safeModel, fmt.Sprintf("%d-%s", clk.Now().UTC().UnixMilli(), interceptionID))
}

// passthroughDumpPath returns the base file path (without req/resp suffix) for a passthrough dump.
func passthroughDumpPath(baseDir string, provider string, urlPath string, clk quartz.Clock) string {
safeURLPath := strings.ReplaceAll(strings.TrimPrefix(urlPath, "/"), "/", "-")
return filepath.Join(baseDir, provider, "passthrough", fmt.Sprintf("%d-%s-%s", clk.Now().UTC().UnixMilli(), safeURLPath, uuid.NewString()[:4]))
}

// NewPassthroughMiddleware returns http.RoundTripper that dumps requests and responses to files.
// If baseDir is empty, returns the original transport unchanged.
// Used for logging in pass through routes.
func NewPassthroughMiddleware(transport http.RoundTripper, baseDir string, provider string, logger slog.Logger, clk quartz.Clock) http.RoundTripper {
if baseDir == "" {
return transport
}
return &dumpRoundTripper{
inner: transport,
baseDir: baseDir,
provider: provider,
clk: clk,
logger: logger,
}
}

// path returns the path to a request/response dump file for a given interception.
// suffix should be SuffixRequest or SuffixResponse.
func (d *dumper) path(suffix string) string {
safeModel := strings.ReplaceAll(d.model, "/", "-")
return filepath.Join(d.baseDir, d.provider, safeModel, fmt.Sprintf("%d-%s%s", d.clk.Now().UTC().UnixMilli(), d.interceptionID, suffix))
type dumpRoundTripper struct {
inner http.RoundTripper
baseDir string
provider string
clk quartz.Clock
logger slog.Logger
}

func (rt *dumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
dumper := dumper{
dumpPath: passthroughDumpPath(rt.baseDir, rt.provider, req.URL.Path, rt.clk),
logger: rt.logger,
}

if err := dumper.dumpRequest(req); err != nil {
dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough request", slog.Error(err))
}

resp, err := rt.inner.RoundTrip(req)
if err != nil {
return resp, err
}

if err := dumper.dumpResponse(resp); err != nil {
dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough response", slog.Error(err))
}

return resp, nil
}

// prettyPrintJSON returns indented JSON if body is valid JSON, otherwise returns body as-is.
Expand All @@ -194,12 +262,11 @@ func prettyPrintJSON(body []byte) []byte {
if len(body) == 0 {
return body
}
result := pretty.Pretty(body)
// pretty.Pretty returns a truncated/modified result for invalid JSON,
// so check if the result is valid JSON; if not, return the original.
if !json.Valid(result) {
return body

result := body
if json.Valid(body) {
result = pretty.Pretty(body)
}
Comment on lines +266 to +269
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did this change? Since result is body, if json.Valid returns false we'll be appending a newline to it, couldn't this impact the caller's slice?

// Trim trailing newline added by pretty.Pretty.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was there a reson for removing trailing new line?
Viewing dumps in terminal without it is annoying.

return bytes.TrimSuffix(result, []byte("\n"))

return result
}
Loading