Skip to content

Commit 1b1dd0f

Browse files
cih9088bartventer
andauthored
fix: return error immediately on roundtrip error (#29)
* Return immediately when upstream RoundTrip returns an error (avoid nil-response dereference). * Simplify validation handler API by removing error parameter. * Update tests for the new handler contract. --------- Co-authored-by: Bart Venter <72999113+bartventer@users.noreply.github.com>
1 parent 3cfaeec commit 1b1dd0f

File tree

5 files changed

+32
-59
lines changed

5 files changed

+32
-59
lines changed

internal/mocks.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,16 +183,15 @@ func (m *MockCacheInvalidator) InvalidateCache(
183183
var _ ValidationResponseHandler = (*MockValidationResponseHandler)(nil)
184184

185185
type MockValidationResponseHandler struct {
186-
HandleValidationResponseFunc func(ctx RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error)
186+
HandleValidationResponseFunc func(ctx RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error)
187187
}
188188

189189
func (m *MockValidationResponseHandler) HandleValidationResponse(
190190
ctx RevalidationContext,
191191
req *http.Request,
192192
resp *http.Response,
193-
err error,
194193
) (*http.Response, error) {
195-
return m.HandleValidationResponseFunc(ctx, req, resp, err)
194+
return m.HandleValidationResponseFunc(ctx, req, resp)
196195
}
197196

198197
var _ VaryMatcher = (*MockVaryMatcher)(nil)

internal/validationresponsehandler.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ type ValidationResponseHandler interface {
2424
ctx RevalidationContext,
2525
req *http.Request,
2626
resp *http.Response,
27-
err error,
2827
) (*http.Response, error)
2928
}
3029

@@ -75,9 +74,8 @@ func (r *validationResponseHandler) HandleValidationResponse(
7574
ctx RevalidationContext,
7675
req *http.Request,
7776
resp *http.Response,
78-
err error,
7977
) (*http.Response, error) {
80-
if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified {
78+
if req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified {
8179
// RFC 9111 §4.3.3 Handling Validation Responses (304 Not Modified)
8280
// RFC 9111 §4.3.4 Freshening Stored Responses upon Validation
8381
mergeResponseHeaders(ctx.Stored.Data, resp.Header)
@@ -99,7 +97,7 @@ func (r *validationResponseHandler) HandleValidationResponse(
9997
ccResp CCResponseDirectives
10098
ccRespOnce bool
10199
)
102-
if (err != nil || isStaleErrorAllowed(resp.StatusCode)) && req.Method == http.MethodGet {
100+
if isStaleErrorAllowed(resp.StatusCode) && req.Method == http.MethodGet {
103101
ccResp = ParseCCResponseDirectives(resp.Header)
104102
ccRespOnce = true
105103
if r.siep.CanStaleOnError(ctx.Freshness, ccResp) {
@@ -112,10 +110,6 @@ func (r *validationResponseHandler) HandleValidationResponse(
112110
}
113111
}
114112

115-
if err != nil {
116-
return nil, err
117-
}
118-
119113
if !ccRespOnce {
120114
ccResp = ParseCCResponseDirectives(resp.Header)
121115
}

internal/validationresponsehandler_test.go

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
package internal
1616

1717
import (
18-
"errors"
1918
"log/slog"
2019
"net/http"
2120
"net/url"
@@ -48,9 +47,8 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) {
4847
}
4948

5049
type args struct {
51-
req *http.Request
52-
resp *http.Response
53-
inputErr error
50+
req *http.Request
51+
resp *http.Response
5452
}
5553

5654
tests := []struct {
@@ -107,31 +105,7 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) {
107105
},
108106
},
109107
{
110-
name: "GET with error, stale allowed",
111-
handler: &validationResponseHandler{
112-
l: noopLogger,
113-
siep: &MockStaleIfErrorPolicy{
114-
CanStaleOnErrorFunc: func(*Freshness, ...StaleIfErrorer) bool { return true },
115-
},
116-
clock: &MockClock{NowResult: base},
117-
},
118-
setup: func(tt *testing.T, handler *validationResponseHandler) args {
119-
return args{
120-
req: &http.Request{Method: http.MethodGet},
121-
resp: &http.Response{
122-
StatusCode: http.StatusInternalServerError,
123-
Header: http.Header{"Cache-Control": {"stale-if-error=60"}},
124-
},
125-
inputErr: errors.New("network error"),
126-
}
127-
},
128-
assert: func(tt *testing.T, got *http.Response, err error) {
129-
testutil.RequireNoError(tt, err)
130-
testutil.AssertEqual(tt, http.StatusOK, got.StatusCode)
131-
},
132-
},
133-
{
134-
name: "GET with error, stale not allowed",
108+
name: "GET with error status, stale not allowed",
135109
handler: &validationResponseHandler{
136110
l: noopLogger,
137111
siep: &MockStaleIfErrorPolicy{
@@ -140,18 +114,20 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) {
140114
clock: &MockClock{NowResult: base},
141115
},
142116
setup: func(tt *testing.T, handler *validationResponseHandler) args {
117+
handler.ce = CacheabilityEvaluatorFunc(
118+
func(*http.Response, CCRequestDirectives, CCResponseDirectives) bool { return false },
119+
)
143120
return args{
144121
req: &http.Request{Method: http.MethodGet},
145122
resp: &http.Response{
146123
StatusCode: http.StatusInternalServerError,
147124
Header: http.Header{},
148125
},
149-
inputErr: errors.New("network error"),
150126
}
151127
},
152128
assert: func(tt *testing.T, got *http.Response, err error) {
153-
testutil.RequireError(tt, err)
154-
testutil.AssertNil(tt, got)
129+
testutil.RequireNoError(tt, err)
130+
testutil.AssertEqual(tt, "BYPASS", got.Header.Get(CacheStatusHeader))
155131
},
156132
},
157133
{
@@ -210,7 +186,7 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) {
210186
for _, tt := range tests {
211187
t.Run(tt.name, func(t *testing.T) {
212188
a := tt.setup(t, tt.handler)
213-
got, err := tt.handler.HandleValidationResponse(ctx, a.req, a.resp, a.inputErr)
189+
got, err := tt.handler.HandleValidationResponse(ctx, a.req, a.resp)
214190
tt.assert(t, got, err)
215191
})
216192
}

roundtripper.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ func (r *transport) handleCacheMiss(
288288
return resp, nil
289289
}
290290

291+
//nolint:cyclop // The complexity of this function is justified by the need to handle multiple caching scenarios according to RFC 9111.
291292
func (r *transport) handleCacheHit(
292293
req *http.Request,
293294
stored *internal.Response,
@@ -340,6 +341,9 @@ func (r *transport) handleCacheHit(
340341
revalidate:
341342
req = withConditionalHeaders(req, stored.Data.Header)
342343
resp, start, end, err := r.roundTripTimed(req)
344+
if err != nil {
345+
return nil, err
346+
}
343347
ctx := internal.RevalidationContext{
344348
URLKey: urlKey,
345349
Start: start,
@@ -350,7 +354,7 @@ revalidate:
350354
RefIndex: refIndex,
351355
Freshness: freshness,
352356
}
353-
return r.vrh.HandleValidationResponse(ctx, req, resp, err)
357+
return r.vrh.HandleValidationResponse(ctx, req, resp)
354358
}
355359

356360
func (r *transport) serveFromCache(
@@ -441,7 +445,7 @@ func (r *transport) backgroundRevalidate(
441445
Freshness: freshness,
442446
}
443447
//nolint:bodyclose // The response is not used, so we don't need to close it.
444-
_, err = r.vrh.HandleValidationResponse(revalCtx, req, resp, nil)
448+
_, err = r.vrh.HandleValidationResponse(revalCtx, req, resp)
445449
errc <- err
446450
}()
447451

roundtripper_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ func mockTransport(fields func(rt *transport)) *transport {
6565
ci: &internal.MockCacheInvalidator{},
6666
rs: &internal.MockResponseStorer{},
6767
vrh: &internal.MockValidationResponseHandler{
68-
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
69-
return resp, err
68+
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
69+
return resp, nil
7070
},
7171
},
7272
clock: &internal.MockClock{NowResult: time.Now()},
@@ -225,9 +225,9 @@ func Test_transport_CacheHit_MustRevalidate_Stale(t *testing.T) {
225225
},
226226
}
227227
rt.vrh = &internal.MockValidationResponseHandler{
228-
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
228+
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
229229
mockVHCalled = true
230-
return resp, err
230+
return resp, nil
231231
},
232232
}
233233
})
@@ -260,9 +260,9 @@ func Test_transport_CacheHit_NoCacheUnqualified(t *testing.T) {
260260
},
261261
}
262262
rt.vrh = &internal.MockValidationResponseHandler{
263-
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
263+
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
264264
mockVHCalled = true
265-
return resp, err
265+
return resp, nil
266266
},
267267
}
268268
})
@@ -540,10 +540,10 @@ func Test_transport_RevalidationPath(t *testing.T) {
540540
},
541541
}
542542
rt.vrh = &internal.MockValidationResponseHandler{
543-
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
543+
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
544544
mockVHCalled = true
545545
internal.CacheStatusRevalidated.ApplyTo(resp.Header)
546-
return resp, err
546+
return resp, nil
547547
},
548548
}
549549
})
@@ -597,9 +597,9 @@ func Test_transport_SWR_NormalPath(t *testing.T) {
597597
rt.clock = &internal.MockClock{NowResult: base.Add(5 * time.Second), SinceResult: 0}
598598
rt.siep = &internal.MockStaleIfErrorPolicy{}
599599
rt.vrh = &internal.MockValidationResponseHandler{
600-
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
600+
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
601601
revalidateCalled <- struct{}{} // Signal that revalidation was called
602-
return resp, err
602+
return resp, nil
603603
},
604604
}
605605
rt.swrTimeout = DefaultSWRTimeout
@@ -670,7 +670,7 @@ func Test_transport_SWR_NormalPathAndError(t *testing.T) {
670670
rt.clock = &internal.MockClock{NowResult: base.Add(5 * time.Second), SinceResult: 0}
671671
rt.swrTimeout = swrTimeout
672672
rt.vrh = &internal.MockValidationResponseHandler{
673-
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
673+
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
674674
defer func() { revalidateCalled <- struct{}{} }() // Signal that revalidation was called
675675
return nil, errors.New("revalidation error")
676676
},
@@ -737,9 +737,9 @@ func Test_transport_SWR_Timeout(t *testing.T) {
737737
rt.clock = &internal.MockClock{NowResult: base.Add(5 * time.Second), SinceResult: 0}
738738
rt.swrTimeout = swrTimeout
739739
rt.vrh = &internal.MockValidationResponseHandler{
740-
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
740+
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
741741
revalidateCalled <- struct{}{} // Signal that revalidation was called
742-
return resp, err
742+
return resp, nil
743743
},
744744
}
745745
rt.upstream = &internal.MockRoundTripper{

0 commit comments

Comments
 (0)