Skip to content

Commit b84b1c3

Browse files
authored
feat: add upstream request/response logging for passthrough routes (#186)
Adds `apidump.NewPassthroughMiddleware` that logs upstream request/response for passthrough routes.
1 parent 498272d commit b84b1c3

14 files changed

Lines changed: 425 additions & 75 deletions

File tree

apidump_integration_test.go

Lines changed: 129 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,32 +49,28 @@ func TestAPIDump(t *testing.T) {
4949
cases := []struct {
5050
name string
5151
fixture []byte
52-
providerName string
5352
providersFunc func(addr, dumpDir string) []aibridge.Provider
5453
createRequestFunc createRequestFunc
5554
}{
5655
{
57-
name: config.ProviderAnthropic,
58-
fixture: fixtures.AntSimple,
59-
providerName: config.ProviderAnthropic,
56+
name: "anthropic",
57+
fixture: fixtures.AntSimple,
6058
providersFunc: func(addr, dumpDir string) []aibridge.Provider {
6159
return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)}
6260
},
6361
createRequestFunc: createAnthropicMessagesReq,
6462
},
6563
{
66-
name: config.ProviderOpenAI,
67-
fixture: fixtures.OaiChatSimple,
68-
providerName: config.ProviderOpenAI,
64+
name: "openai_chat_completions",
65+
fixture: fixtures.OaiChatSimple,
6966
providersFunc: func(addr, dumpDir string) []aibridge.Provider {
7067
return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))}
7168
},
7269
createRequestFunc: createOpenAIChatCompletionsReq,
7370
},
7471
{
75-
name: config.ProviderOpenAI,
76-
fixture: fixtures.OaiResponsesBlockingSimple,
77-
providerName: config.ProviderOpenAI,
72+
name: "openai_responses",
73+
fixture: fixtures.OaiResponsesBlockingSimple,
7874
providersFunc: func(addr, dumpDir string) []aibridge.Provider {
7975
return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))}
8076
},
@@ -176,3 +172,126 @@ func TestAPIDump(t *testing.T) {
176172
})
177173
}
178174
}
175+
176+
func TestAPIDumpPassthrough(t *testing.T) {
177+
t.Parallel()
178+
179+
const responseBody = `{"object":"list","data":[{"id":"gpt-4","object":"model"}]}`
180+
181+
cases := []struct {
182+
name string
183+
providerFunc func(addr string, dumpDir string) aibridge.Provider
184+
requestPath string
185+
expectDumpName string
186+
}{
187+
{
188+
name: "anthropic",
189+
providerFunc: func(addr string, dumpDir string) aibridge.Provider {
190+
return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)
191+
},
192+
requestPath: "/anthropic/v1/models",
193+
expectDumpName: "-v1-models-",
194+
},
195+
{
196+
name: "openai",
197+
providerFunc: func(addr string, dumpDir string) aibridge.Provider {
198+
return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))
199+
},
200+
requestPath: "/openai/v1/models",
201+
expectDumpName: "-models-",
202+
},
203+
{
204+
name: "copilot",
205+
providerFunc: func(addr string, dumpDir string) aibridge.Provider {
206+
return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir})
207+
},
208+
requestPath: "/copilot/models",
209+
expectDumpName: "-models-",
210+
},
211+
}
212+
213+
for _, tc := range cases {
214+
t.Run(tc.name, func(t *testing.T) {
215+
t.Parallel()
216+
217+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
218+
219+
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
220+
t.Cleanup(cancel)
221+
222+
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
223+
w.Header().Set("Content-Type", "application/json")
224+
w.Write([]byte(responseBody))
225+
}))
226+
t.Cleanup(upstream.Close)
227+
228+
dumpDir := t.TempDir()
229+
230+
recorderClient := &testutil.MockRecorder{}
231+
prov := tc.providerFunc(upstream.URL, dumpDir)
232+
provs := []aibridge.Provider{prov}
233+
b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
234+
require.NoError(t, err)
235+
236+
bridgeSrv := httptest.NewUnstartedServer(b)
237+
t.Cleanup(bridgeSrv.Close)
238+
bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
239+
return aibcontext.AsActor(ctx, userID, nil)
240+
}
241+
bridgeSrv.Start()
242+
243+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, bridgeSrv.URL+tc.requestPath, nil)
244+
require.NoError(t, err)
245+
246+
resp, err := http.DefaultClient.Do(req)
247+
require.NoError(t, err)
248+
defer resp.Body.Close()
249+
250+
// Find dump files in the passthrough directory.
251+
passthroughDir := filepath.Join(dumpDir, tc.name, "passthrough")
252+
var reqDumpFile, respDumpFile string
253+
err = filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error {
254+
if err != nil {
255+
return err
256+
}
257+
if info.IsDir() {
258+
return nil
259+
}
260+
if strings.HasSuffix(path, apidump.SuffixRequest) {
261+
reqDumpFile = path
262+
} else if strings.HasSuffix(path, apidump.SuffixResponse) {
263+
respDumpFile = path
264+
}
265+
return nil
266+
})
267+
require.NoError(t, err, "walking failed: %v", err)
268+
269+
require.NotEmpty(t, reqDumpFile, "request dump file should exist")
270+
require.FileExists(t, reqDumpFile)
271+
require.Contains(t, reqDumpFile, "/passthrough/")
272+
require.Contains(t, reqDumpFile, tc.expectDumpName)
273+
274+
require.NotEmpty(t, respDumpFile, "response dump file should exist")
275+
require.FileExists(t, respDumpFile)
276+
require.Contains(t, respDumpFile, "/passthrough/")
277+
require.Contains(t, respDumpFile, tc.expectDumpName)
278+
279+
// Verify request dump.
280+
reqDumpData, err := os.ReadFile(reqDumpFile)
281+
require.NoError(t, err)
282+
dumpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqDumpData)))
283+
require.NoError(t, err)
284+
require.Equal(t, http.MethodGet, dumpReq.Method)
285+
286+
// Verify response dump.
287+
respDumpData, err := os.ReadFile(respDumpFile)
288+
require.NoError(t, err)
289+
dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil)
290+
require.NoError(t, err)
291+
require.Equal(t, http.StatusOK, dumpResp.StatusCode)
292+
dumpRespBody, err := io.ReadAll(dumpResp.Body)
293+
require.NoError(t, err)
294+
require.JSONEq(t, responseBody, string(dumpRespBody))
295+
})
296+
}
297+
}

intercept/apidump/apidump.go

Lines changed: 107 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,21 @@ type MiddlewareNext = func(*http.Request) (*http.Response, error)
3232
// Middleware is an HTTP middleware function compatible with SDK WithMiddleware options.
3333
type Middleware = func(*http.Request, MiddlewareNext) (*http.Response, error)
3434

35-
// NewMiddleware returns a middleware function that dumps requests and responses to files.
36-
// Files are written to the path returned by DumpPath.
35+
// NewBridgeMiddleware returns a middleware function that dumps requests and responses to files.
3736
// If baseDir is empty, returns nil (no middleware).
38-
func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, logger slog.Logger, clk quartz.Clock) Middleware {
37+
func NewBridgeMiddleware(baseDir string, provider string, model string, interceptionID uuid.UUID, logger slog.Logger, clk quartz.Clock) Middleware {
3938
if baseDir == "" {
4039
return nil
4140
}
4241

4342
d := &dumper{
44-
baseDir: baseDir,
45-
provider: provider,
46-
model: model,
47-
interceptionID: interceptionID,
48-
clk: clk,
49-
logger: logger,
43+
dumpPath: interceptDumpPath(baseDir, provider, model, interceptionID, clk),
44+
logger: logger,
5045
}
5146

5247
return func(req *http.Request, next MiddlewareNext) (*http.Response, error) {
5348
if err := d.dumpRequest(req); err != nil {
54-
logger.Named("apidump").Warn(context.Background(), "failed to dump request", slog.Error(err))
49+
logger.Named("apidump").Warn(req.Context(), "failed to dump request", slog.Error(err))
5550
}
5651

5752
// TODO: https://github.com/coder/aibridge/issues/129
@@ -61,24 +56,20 @@ func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, lo
6156
}
6257

6358
if err := d.dumpResponse(resp); err != nil {
64-
logger.Named("apidump").Warn(context.Background(), "failed to dump response", slog.Error(err))
59+
logger.Named("apidump").Warn(req.Context(), "failed to dump response", slog.Error(err))
6560
}
6661

6762
return resp, nil
6863
}
6964
}
7065

7166
type dumper struct {
72-
baseDir string
73-
provider string
74-
model string
75-
interceptionID uuid.UUID
76-
clk quartz.Clock
77-
logger slog.Logger
67+
dumpPath string
68+
logger slog.Logger
7869
}
7970

8071
func (d *dumper) dumpRequest(req *http.Request) error {
81-
dumpPath := d.path(SuffixRequest)
72+
dumpPath := d.dumpPath + SuffixRequest
8273
if err := os.MkdirAll(filepath.Dir(dumpPath), 0o755); err != nil {
8374
return fmt.Errorf("create dump dir: %w", err)
8475
}
@@ -98,25 +89,44 @@ func (d *dumper) dumpRequest(req *http.Request) error {
9889

9990
// Build raw HTTP request format
10091
var buf bytes.Buffer
101-
fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto)
102-
d.writeRedactedHeaders(&buf, req.Header, sensitiveRequestHeaders, map[string]string{
92+
_, err := fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto)
93+
if err != nil {
94+
return fmt.Errorf("write request uri: %w", err)
95+
}
96+
err = d.writeRedactedHeaders(&buf, req.Header, sensitiveRequestHeaders, map[string]string{
10397
"Content-Length": fmt.Sprintf("%d", len(prettyBody)),
10498
})
99+
if err != nil {
100+
return fmt.Errorf("write request headers: %w", err)
101+
}
105102

106-
fmt.Fprintf(&buf, "\r\n")
103+
_, err = fmt.Fprintf(&buf, "\r\n")
104+
if err != nil {
105+
return fmt.Errorf("write request header terminator: %w", err)
106+
}
107107
buf.Write(prettyBody)
108+
buf.WriteByte('\n')
108109

109110
return os.WriteFile(dumpPath, buf.Bytes(), 0o644)
110111
}
111112

112113
func (d *dumper) dumpResponse(resp *http.Response) error {
113-
dumpPath := d.path(SuffixResponse)
114+
dumpPath := d.dumpPath + SuffixResponse
114115

115116
// Build raw HTTP response headers
116117
var headerBuf bytes.Buffer
117-
fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status)
118-
d.writeRedactedHeaders(&headerBuf, resp.Header, sensitiveResponseHeaders, nil)
119-
fmt.Fprintf(&headerBuf, "\r\n")
118+
_, err := fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status)
119+
if err != nil {
120+
return fmt.Errorf("write response status: %w", err)
121+
}
122+
err = d.writeRedactedHeaders(&headerBuf, resp.Header, sensitiveResponseHeaders, nil)
123+
if err != nil {
124+
return fmt.Errorf("write response headers: %w", err)
125+
}
126+
_, err = fmt.Fprintf(&headerBuf, "\r\n")
127+
if err != nil {
128+
return fmt.Errorf("write response header terminator: %w", err)
129+
}
120130

121131
// Wrap the response body to capture it as it streams
122132
if resp.Body != nil {
@@ -141,7 +151,7 @@ func (d *dumper) dumpResponse(resp *http.Response) error {
141151
// for deterministic output.
142152
// `sensitive` and `overrides` must both supply keys in canoncialized form.
143153
// See [textproto.MIMEHeader].
144-
func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) {
154+
func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) error {
145155
// Collect all header keys including overrides.
146156
headerKeys := make([]string, 0, len(headers)+len(overrides))
147157
seen := make(map[string]struct{}, len(headers)+len(overrides))
@@ -163,7 +173,10 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv
163173
// If no values exist but we have an override, use that.
164174
if len(values) == 0 {
165175
if override, ok := overrides[key]; ok {
166-
fmt.Fprintf(w, "%s: %s\r\n", key, override)
176+
_, err := fmt.Fprintf(w, "%s: %s\r\n", key, override)
177+
if err != nil {
178+
return fmt.Errorf("write response header override: %w", err)
179+
}
167180
}
168181
continue
169182
}
@@ -175,16 +188,71 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv
175188
if isSensitive {
176189
value = redactHeaderValue(value)
177190
}
178-
fmt.Fprintf(w, "%s: %s\r\n", key, value)
191+
_, err := fmt.Fprintf(w, "%s: %s\r\n", key, value)
192+
if err != nil {
193+
return fmt.Errorf("write response headers: %w", err)
194+
}
179195
}
180196
}
197+
return nil
198+
}
199+
200+
// interceptDumpPath returns the base file path (without req/resp suffix) for an interception dump.
201+
func interceptDumpPath(baseDir string, provider string, model string, interceptionID uuid.UUID, clk quartz.Clock) string {
202+
safeModel := strings.ReplaceAll(model, "/", "-")
203+
return filepath.Join(baseDir, provider, safeModel, fmt.Sprintf("%d-%s", clk.Now().UTC().UnixMilli(), interceptionID))
204+
}
205+
206+
// passthroughDumpPath returns the base file path (without req/resp suffix) for a passthrough dump.
207+
func passthroughDumpPath(baseDir string, provider string, urlPath string, clk quartz.Clock) string {
208+
safeURLPath := strings.ReplaceAll(strings.TrimPrefix(urlPath, "/"), "/", "-")
209+
return filepath.Join(baseDir, provider, "passthrough", fmt.Sprintf("%d-%s-%s", clk.Now().UTC().UnixMilli(), safeURLPath, uuid.NewString()[:4]))
210+
}
211+
212+
// NewPassthroughMiddleware returns http.RoundTripper that dumps requests and responses to files.
213+
// If baseDir is empty, returns the original transport unchanged.
214+
// Used for logging in pass through routes.
215+
func NewPassthroughMiddleware(transport http.RoundTripper, baseDir string, provider string, logger slog.Logger, clk quartz.Clock) http.RoundTripper {
216+
if baseDir == "" {
217+
return transport
218+
}
219+
return &dumpRoundTripper{
220+
inner: transport,
221+
baseDir: baseDir,
222+
provider: provider,
223+
clk: clk,
224+
logger: logger,
225+
}
181226
}
182227

183-
// path returns the path to a request/response dump file for a given interception.
184-
// suffix should be SuffixRequest or SuffixResponse.
185-
func (d *dumper) path(suffix string) string {
186-
safeModel := strings.ReplaceAll(d.model, "/", "-")
187-
return filepath.Join(d.baseDir, d.provider, safeModel, fmt.Sprintf("%d-%s%s", d.clk.Now().UTC().UnixMilli(), d.interceptionID, suffix))
228+
type dumpRoundTripper struct {
229+
inner http.RoundTripper
230+
baseDir string
231+
provider string
232+
clk quartz.Clock
233+
logger slog.Logger
234+
}
235+
236+
func (rt *dumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
237+
dumper := dumper{
238+
dumpPath: passthroughDumpPath(rt.baseDir, rt.provider, req.URL.Path, rt.clk),
239+
logger: rt.logger,
240+
}
241+
242+
if err := dumper.dumpRequest(req); err != nil {
243+
dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough request", slog.Error(err))
244+
}
245+
246+
resp, err := rt.inner.RoundTrip(req)
247+
if err != nil {
248+
return resp, err
249+
}
250+
251+
if err := dumper.dumpResponse(resp); err != nil {
252+
dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough response", slog.Error(err))
253+
}
254+
255+
return resp, nil
188256
}
189257

190258
// prettyPrintJSON returns indented JSON if body is valid JSON, otherwise returns body as-is.
@@ -194,12 +262,11 @@ func prettyPrintJSON(body []byte) []byte {
194262
if len(body) == 0 {
195263
return body
196264
}
197-
result := pretty.Pretty(body)
198-
// pretty.Pretty returns a truncated/modified result for invalid JSON,
199-
// so check if the result is valid JSON; if not, return the original.
200-
if !json.Valid(result) {
201-
return body
265+
266+
result := body
267+
if json.Valid(body) {
268+
result = pretty.Pretty(body)
202269
}
203-
// Trim trailing newline added by pretty.Pretty.
204-
return bytes.TrimSuffix(result, []byte("\n"))
270+
271+
return result
205272
}

0 commit comments

Comments
 (0)