@@ -32,26 +32,21 @@ type MiddlewareNext = func(*http.Request) (*http.Response, error)
3232// Middleware is an HTTP middleware function compatible with SDK WithMiddleware options.
3333type Middleware = func (* http.Request , MiddlewareNext ) (* http.Response , error )
3434
35- // NewMiddleware returns a middleware function that dumps requests and responses to files.
36- // Files are written to the path returned by DumpPath.
35+ // NewBridgeMiddleware returns a middleware function that dumps requests and responses to files.
3736// If baseDir is empty, returns nil (no middleware).
38- func NewMiddleware (baseDir , provider , model string , interceptionID uuid.UUID , logger slog.Logger , clk quartz.Clock ) Middleware {
37+ func NewBridgeMiddleware (baseDir string , provider string , model string , interceptionID uuid.UUID , logger slog.Logger , clk quartz.Clock ) Middleware {
3938 if baseDir == "" {
4039 return nil
4140 }
4241
4342 d := & dumper {
44- baseDir : baseDir ,
45- provider : provider ,
46- model : model ,
47- interceptionID : interceptionID ,
48- clk : clk ,
49- logger : logger ,
43+ dumpPath : interceptDumpPath (baseDir , provider , model , interceptionID , clk ),
44+ logger : logger ,
5045 }
5146
5247 return func (req * http.Request , next MiddlewareNext ) (* http.Response , error ) {
5348 if err := d .dumpRequest (req ); err != nil {
54- logger .Named ("apidump" ).Warn (context . Background (), "failed to dump request" , slog .Error (err ))
49+ logger .Named ("apidump" ).Warn (req . Context (), "failed to dump request" , slog .Error (err ))
5550 }
5651
5752 // TODO: https://github.com/coder/aibridge/issues/129
@@ -61,24 +56,20 @@ func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, lo
6156 }
6257
6358 if err := d .dumpResponse (resp ); err != nil {
64- logger .Named ("apidump" ).Warn (context . Background (), "failed to dump response" , slog .Error (err ))
59+ logger .Named ("apidump" ).Warn (req . Context (), "failed to dump response" , slog .Error (err ))
6560 }
6661
6762 return resp , nil
6863 }
6964}
7065
7166type dumper struct {
72- baseDir string
73- provider string
74- model string
75- interceptionID uuid.UUID
76- clk quartz.Clock
77- logger slog.Logger
67+ dumpPath string
68+ logger slog.Logger
7869}
7970
8071func (d * dumper ) dumpRequest (req * http.Request ) error {
81- dumpPath := d .path ( SuffixRequest )
72+ dumpPath := d .dumpPath + SuffixRequest
8273 if err := os .MkdirAll (filepath .Dir (dumpPath ), 0o755 ); err != nil {
8374 return fmt .Errorf ("create dump dir: %w" , err )
8475 }
@@ -98,25 +89,44 @@ func (d *dumper) dumpRequest(req *http.Request) error {
9889
9990 // Build raw HTTP request format
10091 var buf bytes.Buffer
101- fmt .Fprintf (& buf , "%s %s %s\r \n " , req .Method , req .URL .RequestURI (), req .Proto )
102- d .writeRedactedHeaders (& buf , req .Header , sensitiveRequestHeaders , map [string ]string {
92+ _ , err := fmt .Fprintf (& buf , "%s %s %s\r \n " , req .Method , req .URL .RequestURI (), req .Proto )
93+ if err != nil {
94+ return fmt .Errorf ("write request uri: %w" , err )
95+ }
96+ err = d .writeRedactedHeaders (& buf , req .Header , sensitiveRequestHeaders , map [string ]string {
10397 "Content-Length" : fmt .Sprintf ("%d" , len (prettyBody )),
10498 })
99+ if err != nil {
100+ return fmt .Errorf ("write request headers: %w" , err )
101+ }
105102
106- fmt .Fprintf (& buf , "\r \n " )
103+ _ , err = fmt .Fprintf (& buf , "\r \n " )
104+ if err != nil {
105+ return fmt .Errorf ("write request header terminator: %w" , err )
106+ }
107107 buf .Write (prettyBody )
108+ buf .WriteByte ('\n' )
108109
109110 return os .WriteFile (dumpPath , buf .Bytes (), 0o644 )
110111}
111112
112113func (d * dumper ) dumpResponse (resp * http.Response ) error {
113- dumpPath := d .path ( SuffixResponse )
114+ dumpPath := d .dumpPath + SuffixResponse
114115
115116 // Build raw HTTP response headers
116117 var headerBuf bytes.Buffer
117- fmt .Fprintf (& headerBuf , "%s %s\r \n " , resp .Proto , resp .Status )
118- d .writeRedactedHeaders (& headerBuf , resp .Header , sensitiveResponseHeaders , nil )
119- fmt .Fprintf (& headerBuf , "\r \n " )
118+ _ , err := fmt .Fprintf (& headerBuf , "%s %s\r \n " , resp .Proto , resp .Status )
119+ if err != nil {
120+ return fmt .Errorf ("write response status: %w" , err )
121+ }
122+ err = d .writeRedactedHeaders (& headerBuf , resp .Header , sensitiveResponseHeaders , nil )
123+ if err != nil {
124+ return fmt .Errorf ("write response headers: %w" , err )
125+ }
126+ _ , err = fmt .Fprintf (& headerBuf , "\r \n " )
127+ if err != nil {
128+ return fmt .Errorf ("write response header terminator: %w" , err )
129+ }
120130
121131 // Wrap the response body to capture it as it streams
122132 if resp .Body != nil {
@@ -141,7 +151,7 @@ func (d *dumper) dumpResponse(resp *http.Response) error {
141151// for deterministic output.
142152// `sensitive` and `overrides` must both supply keys in canoncialized form.
143153// See [textproto.MIMEHeader].
144- func (d * dumper ) writeRedactedHeaders (w io.Writer , headers http.Header , sensitive map [string ]struct {}, overrides map [string ]string ) {
154+ func (d * dumper ) writeRedactedHeaders (w io.Writer , headers http.Header , sensitive map [string ]struct {}, overrides map [string ]string ) error {
145155 // Collect all header keys including overrides.
146156 headerKeys := make ([]string , 0 , len (headers )+ len (overrides ))
147157 seen := make (map [string ]struct {}, len (headers )+ len (overrides ))
@@ -163,7 +173,10 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv
163173 // If no values exist but we have an override, use that.
164174 if len (values ) == 0 {
165175 if override , ok := overrides [key ]; ok {
166- fmt .Fprintf (w , "%s: %s\r \n " , key , override )
176+ _ , err := fmt .Fprintf (w , "%s: %s\r \n " , key , override )
177+ if err != nil {
178+ return fmt .Errorf ("write response header override: %w" , err )
179+ }
167180 }
168181 continue
169182 }
@@ -175,16 +188,71 @@ func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitiv
175188 if isSensitive {
176189 value = redactHeaderValue (value )
177190 }
178- fmt .Fprintf (w , "%s: %s\r \n " , key , value )
191+ _ , err := fmt .Fprintf (w , "%s: %s\r \n " , key , value )
192+ if err != nil {
193+ return fmt .Errorf ("write response headers: %w" , err )
194+ }
179195 }
180196 }
197+ return nil
198+ }
199+
200+ // interceptDumpPath returns the base file path (without req/resp suffix) for an interception dump.
201+ func interceptDumpPath (baseDir string , provider string , model string , interceptionID uuid.UUID , clk quartz.Clock ) string {
202+ safeModel := strings .ReplaceAll (model , "/" , "-" )
203+ return filepath .Join (baseDir , provider , safeModel , fmt .Sprintf ("%d-%s" , clk .Now ().UTC ().UnixMilli (), interceptionID ))
204+ }
205+
206+ // passthroughDumpPath returns the base file path (without req/resp suffix) for a passthrough dump.
207+ func passthroughDumpPath (baseDir string , provider string , urlPath string , clk quartz.Clock ) string {
208+ safeURLPath := strings .ReplaceAll (strings .TrimPrefix (urlPath , "/" ), "/" , "-" )
209+ return filepath .Join (baseDir , provider , "passthrough" , fmt .Sprintf ("%d-%s-%s" , clk .Now ().UTC ().UnixMilli (), safeURLPath , uuid .NewString ()[:4 ]))
210+ }
211+
212+ // NewPassthroughMiddleware returns http.RoundTripper that dumps requests and responses to files.
213+ // If baseDir is empty, returns the original transport unchanged.
214+ // Used for logging in pass through routes.
215+ func NewPassthroughMiddleware (transport http.RoundTripper , baseDir string , provider string , logger slog.Logger , clk quartz.Clock ) http.RoundTripper {
216+ if baseDir == "" {
217+ return transport
218+ }
219+ return & dumpRoundTripper {
220+ inner : transport ,
221+ baseDir : baseDir ,
222+ provider : provider ,
223+ clk : clk ,
224+ logger : logger ,
225+ }
181226}
182227
183- // path returns the path to a request/response dump file for a given interception.
184- // suffix should be SuffixRequest or SuffixResponse.
185- func (d * dumper ) path (suffix string ) string {
186- safeModel := strings .ReplaceAll (d .model , "/" , "-" )
187- return filepath .Join (d .baseDir , d .provider , safeModel , fmt .Sprintf ("%d-%s%s" , d .clk .Now ().UTC ().UnixMilli (), d .interceptionID , suffix ))
228+ type dumpRoundTripper struct {
229+ inner http.RoundTripper
230+ baseDir string
231+ provider string
232+ clk quartz.Clock
233+ logger slog.Logger
234+ }
235+
236+ func (rt * dumpRoundTripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
237+ dumper := dumper {
238+ dumpPath : passthroughDumpPath (rt .baseDir , rt .provider , req .URL .Path , rt .clk ),
239+ logger : rt .logger ,
240+ }
241+
242+ if err := dumper .dumpRequest (req ); err != nil {
243+ dumper .logger .Named ("apidump" ).Warn (req .Context (), "failed to dump passthrough request" , slog .Error (err ))
244+ }
245+
246+ resp , err := rt .inner .RoundTrip (req )
247+ if err != nil {
248+ return resp , err
249+ }
250+
251+ if err := dumper .dumpResponse (resp ); err != nil {
252+ dumper .logger .Named ("apidump" ).Warn (req .Context (), "failed to dump passthrough response" , slog .Error (err ))
253+ }
254+
255+ return resp , nil
188256}
189257
190258// prettyPrintJSON returns indented JSON if body is valid JSON, otherwise returns body as-is.
@@ -194,12 +262,11 @@ func prettyPrintJSON(body []byte) []byte {
194262 if len (body ) == 0 {
195263 return body
196264 }
197- result := pretty .Pretty (body )
198- // pretty.Pretty returns a truncated/modified result for invalid JSON,
199- // so check if the result is valid JSON; if not, return the original.
200- if ! json .Valid (result ) {
201- return body
265+
266+ result := body
267+ if json .Valid (body ) {
268+ result = pretty .Pretty (body )
202269 }
203- // Trim trailing newline added by pretty.Pretty.
204- return bytes . TrimSuffix ( result , [] byte ( " \n " ))
270+
271+ return result
205272}
0 commit comments