Skip to content

Commit 36705e3

Browse files
committed
Made NewRoundTripperMiddleware reusable
1 parent 92ffe95 commit 36705e3

2 files changed

Lines changed: 78 additions & 29 deletions

File tree

intercept/apidump/apidump.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -218,31 +218,39 @@ func NewRoundTripperMiddleware(transport http.RoundTripper, baseDir string, prov
218218
return transport
219219
}
220220
return &dumpRoundTripper{
221-
inner: transport,
222-
dumper: dumper{
223-
dumpPath: passthroughDumpPath(baseDir, provider, clk),
224-
logger: logger,
225-
},
221+
inner: transport,
222+
baseDir: baseDir,
223+
provider: provider,
224+
clk: clk,
225+
logger: logger,
226226
}
227227
}
228228

229229
type dumpRoundTripper struct {
230-
inner http.RoundTripper
231-
dumper dumper
230+
inner http.RoundTripper
231+
baseDir string
232+
provider string
233+
clk quartz.Clock
234+
logger slog.Logger
232235
}
233236

234237
func (rt *dumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
235-
if err := rt.dumper.dumpRequest(req); err != nil {
236-
rt.dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough request", slog.Error(err))
238+
dumper := dumper{
239+
dumpPath: passthroughDumpPath(rt.baseDir, rt.provider, rt.clk),
240+
logger: rt.logger,
241+
}
242+
243+
if err := dumper.dumpRequest(req); err != nil {
244+
dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough request", slog.Error(err))
237245
}
238246

239247
resp, err := rt.inner.RoundTrip(req)
240248
if err != nil {
241249
return resp, err
242250
}
243251

244-
if err := rt.dumper.dumpResponse(resp); err != nil {
245-
rt.dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough response", slog.Error(err))
252+
if err := dumper.dumpResponse(resp); err != nil {
253+
dumper.logger.Named("apidump").Warn(req.Context(), "failed to dump passthrough response", slog.Error(err))
246254
}
247255

248256
return resp, nil

intercept/apidump/apidump_test.go

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ package apidump
22

33
import (
44
"bytes"
5+
"fmt"
56
"io"
67
"net/http"
78
"os"
89
"path/filepath"
10+
"sort"
911
"testing"
1012

1113
"cdr.dev/slog/v3"
@@ -367,55 +369,94 @@ func TestRoundTripperMiddleware(t *testing.T) {
367369
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
368370
clk := quartz.NewMock(t)
369371

372+
req1Body := `first request`
373+
req2Body := `{"request": 2}`
374+
req2BodyPretty := "{\n \"request\": 2\n}\n"
375+
376+
callCount := 0
370377
inner := &mockRoundTripper{
371378
roundTrip: func(req *http.Request) (*http.Response, error) {
372379
// Verify body is still readable after dump
373380
body, err := io.ReadAll(req.Body)
374381
require.NoError(t, err)
375-
require.Equal(t, `{"request": true}`, string(body))
382+
callCount++
383+
if callCount == 1 {
384+
require.Equal(t, req1Body, string(body))
385+
} else {
386+
require.Equal(t, req2Body, string(body))
387+
}
376388

377389
return &http.Response{
378390
StatusCode: http.StatusOK,
379391
Status: "200 OK",
380392
Proto: "HTTP/1.1",
381393
Header: http.Header{"Content-Type": []string{"application/json"}},
382-
Body: io.NopCloser(bytes.NewReader([]byte(`{"response": true}`))),
394+
Body: io.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(`{"call": %d}"`, callCount)))),
383395
}, nil
384396
},
385397
}
386398

387399
rt := NewRoundTripperMiddleware(inner, tmpDir, "openai", logger, clk)
388400

389-
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/models", bytes.NewReader([]byte(`{"request": true}`)))
401+
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/models", bytes.NewReader([]byte(req1Body)))
390402
require.NoError(t, err)
391403
req.Header.Set("Authorization", "Bearer sk-secret-key-12345")
392-
393404
resp, err := rt.RoundTrip(req)
394405
require.NoError(t, err)
395-
396-
// Must read and close response body to trigger the streaming dump
397406
_, err = io.ReadAll(resp.Body)
398407
require.NoError(t, err)
399408
require.NoError(t, resp.Body.Close())
400409

401-
// Verify files are in passthrough directory
410+
// Second request should create new req/resp files
411+
req2, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/models", bytes.NewReader([]byte(req2Body)))
412+
require.NoError(t, err)
413+
resp2, err := rt.RoundTrip(req2)
414+
require.NoError(t, err)
415+
_, err = io.ReadAll(resp2.Body)
416+
require.NoError(t, err)
417+
require.NoError(t, resp2.Body.Close())
418+
419+
// Validate request files contents
402420
passthroughDir := filepath.Join(tmpDir, "openai", "passthrough")
403-
reqDumpPath := findDumpFile(t, passthroughDir, SuffixRequest)
404-
reqContent, err := os.ReadFile(reqDumpPath)
421+
reqPattern := filepath.Join(passthroughDir, "*"+SuffixRequest)
422+
reqMatches, err := filepath.Glob(reqPattern)
405423
require.NoError(t, err)
424+
require.Len(t, reqMatches, 2, "expected exactly two %s files in %s", SuffixRequest, passthroughDir)
406425

407-
require.Contains(t, string(reqContent), "POST")
408-
require.Contains(t, string(reqContent), `"request": true`)
409-
// Sensitive header should be redacted
410-
require.NotContains(t, string(reqContent), "sk-secret-key-12345")
411-
require.Contains(t, string(reqContent), "Authorization:")
426+
reqContents := make([]string, 0, len(reqMatches))
427+
for _, reqDumpPath := range reqMatches {
428+
reqContent, readErr := os.ReadFile(reqDumpPath)
429+
require.NoError(t, readErr)
430+
reqContents = append(reqContents, string(reqContent))
431+
}
412432

413-
respDumpPath := findDumpFile(t, passthroughDir, SuffixResponse)
414-
respContent, err := os.ReadFile(respDumpPath)
433+
sort.Strings(reqContents)
434+
require.Contains(t, reqContents[0], req1Body+"\n")
435+
require.Contains(t, reqContents[1], req2BodyPretty)
436+
// Sensitive header should be redacted
437+
require.NotContains(t, reqContents[0], "sk-secret-key-12345")
438+
require.NotContains(t, reqContents[1], "sk-secret-key-12345")
439+
require.Contains(t, reqContents[0], "Authorization:")
440+
require.NotContains(t, reqContents[1], "Authorization:")
441+
442+
// Validate response files contents
443+
respPattern := filepath.Join(passthroughDir, "*"+SuffixResponse)
444+
respMatches, err := filepath.Glob(respPattern)
415445
require.NoError(t, err)
446+
require.Len(t, respMatches, 2, "expected exactly two %s files in %s", SuffixResponse, passthroughDir)
447+
448+
respContents := make([]string, 0, len(respMatches))
449+
for _, respDumpPath := range respMatches {
450+
respContent, readErr := os.ReadFile(respDumpPath)
451+
require.NoError(t, readErr)
452+
respContents = append(respContents, string(respContent))
453+
}
416454

417-
require.Contains(t, string(respContent), "200 OK")
418-
require.Contains(t, string(respContent), `"response": true`)
455+
sort.Strings(respContents)
456+
require.Contains(t, respContents[0], "200 OK")
457+
require.Contains(t, respContents[0], `{"call": 1}"`)
458+
require.Contains(t, respContents[1], "200 OK")
459+
require.Contains(t, respContents[1], `{"call": 2}"`)
419460
})
420461
}
421462

0 commit comments

Comments
 (0)