From cdf5b3a808550b3f79346b04e4cfba3b598d34f4 Mon Sep 17 00:00:00 2001 From: Benjamin Bengfort Date: Tue, 13 Jan 2026 08:53:43 -0600 Subject: [PATCH 1/3] Implement Transport --- cache.go | 3 + export_test.go | 5 ++ httpcache.go | 182 ++++++++++++++++++++++++++++++++++++++++++++-- httpcache_test.go | 100 +++++++++++++++++++++++++ 4 files changed, 285 insertions(+), 5 deletions(-) diff --git a/cache.go b/cache.go index 7c1c414..5f07ee2 100644 --- a/cache.go +++ b/cache.go @@ -58,8 +58,11 @@ func cacheKey(req *http.Request) string { // headers in their canonical form. This allows you to differentiate cache entries // based on header values such as Authorization or custom headers. func cacheKeyWithHeaders(req *http.Request, headers []string) string { + // Create base cachekey key := cacheKey(req) + // Important - return the base key if no headers specified so that we can use this + // method without checking the length of headers separately. if len(headers) == 0 { return key } diff --git a/export_test.go b/export_test.go index e57e73c..020d3e0 100644 --- a/export_test.go +++ b/export_test.go @@ -1,9 +1,14 @@ package httpcache +// Exports private functions but only when testing so they are not part of the public +// API but tests can be run in the httpcache_test package (works because of the _test.go +// suffix appended to this filename). 
var ( CacheKey = cacheKey CacheKeyWithHeaders = cacheKeyWithHeaders CacheKeyWithVary = cacheKeyWithVary Normalize = normalize CachedResponseWithKey = cachedResponse + AllHeaderCSVs = allHeaderCSVs + IsUnsafeMethod = isUnsafeMethod ) diff --git a/httpcache.go b/httpcache.go index 00faa77..38489d7 100644 --- a/httpcache.go +++ b/httpcache.go @@ -3,8 +3,167 @@ package httpcache import ( "net/http" "strings" + "time" ) +//=========================================================================== +// Transport +//=========================================================================== + +// Transport is an implementation of http.RoundTripper that will return values from a +// cache here possible (avoiding a network request) and will additionally add +// request cache headers (etag/if-modified-since) to repeated requests allowing servers +// to return cache-oriented responses such as 304 Not Modified. +type Transport struct { + Transport http.RoundTripper + Cache Cache + + // If true, responses returned from the cache will include an X-From-Cache header. + MarkCachedResponses bool + + // If true, server errors (5xx status codes) will be served from the cache. By + // default, this is false, forcing a new request on server errors. + ServerErrorsFromCache bool + + // Timeout for async requests triggered by stale-while-revalidate responses. + // If zero, no timeout is applied to async revalidation requests. + AsyncRevalidateTimeout time.Duration + + // IsPublicCache enables public cache mode (default is false for private cache). + // When true, the cache will NOT store responses with the Cache-Control: private + // directive. When false (default), the cache CAN store private responses. + // Set to true only if using httpcache as a shared/public cache (e.g. a reverse + // proxy server or CDN). + IsPublicCache bool + + // Disables RFC 9111 compliant Vary header separation (default: false). 
+ // When true, responses with Vary headers are ignored and each variant overwrites + // the previous one in the cache. This is useful to reduce cache storage space since + // by default each variant is stored separately based on the Vary header values. + DisableVaryHeaderSeparation bool + + // Used to configure non-standard caching behavior based on the response. If set, + // this function is used to determine whether a non-200 response should be cached. + // This enables caching of responses like 404 Not Found or other status codes. If + // nil, only 200 OK responses are cached by default. + // + // The function receives the http.Response and should return true to cache it. + // NOTE: this only bypasses the status code check; Cache-Control headers are still + // respected. + ShouldCache func(resp *http.Response) bool + + // Specify additional request headers to include in the cache key generation + // allowing storage of separate cache entries based on these headers. For example, + // specifying the Authorization header allows user-specific caching or the + // Accept-Language header allows caching based on language preferences. + // Header names are case-insensitive and will be converted to their canonical form. + // NOTE: this is different from the Vary header handling which is server-driven. + CacheKeyHeaders []string +} + +// NewTransport returns a new Transport with provided Cache implementation and +// MarkCachedResponses set to true (default httpcache behavior). +func NewTransport(cache Cache) *Transport { + return &Transport{ + Cache: cache, + MarkCachedResponses: true, + } +} + +// Client returns a new http.Client that caches respones. +func (t *Transport) Client() *http.Client { + return &http.Client{Transport: t} +} + +// Execute an http Request and return an http Response. +// +// If there is a fresh Response already in the cache, it will be returned without +// connecting to the server. 
+// +// If there is a stale Response in the cache, then any validators it contains will be +// set on the new request to give the server a chance to respond with 304 Not Modified. +// In this case, the cached Response will be returned. +// +// RoundTrip implements the http.RoundTripper interface for Transport. +func (t *Transport) RoundTrip(req *http.Request) (rep *http.Response, err error) { + cacheKey := cacheKeyWithHeaders(req, t.CacheKeyHeaders) + cacheable := (req.Method == http.MethodGet || req.Method == http.MethodHead) && req.Header.Get("Range") == "" + + var cached *http.Response + if cacheable { + // Attempt to get a cached response + cached, err = cachedResponse(t.Cache, cacheKey, req) + + // RFC 9111 Vary Separation: if not disabled and cached response has Vary + // headers, recalculate the cache key with vary values and try again for the + // correct variant. + if !t.DisableVaryHeaderSeparation && cached != nil && err == nil { + if varyHeaders := allHeaderCSVs(cached.Header, "Vary"); len(varyHeaders) > 0 { + varyKey := cacheKeyWithVary(req, varyHeaders) + if varyKey != cacheKey { + // Try with vary-specific key + cachedVery, varyErr := cachedResponse(t.Cache, varyKey, req) + if varyErr == nil && cachedVery != nil { + // Found a variant match + cached = cachedVery + cacheKey = varyKey + } + } + } + } + + } else { + // RFC 7234 Section 4.4: Invalidate cache on unsafe methods + // Delete the request URI immediately for unsafe methods + t.Cache.Del(cacheKey) + } + + // Use default http transport if not specified + transport := t.Transport + if transport == nil { + transport = http.DefaultTransport + } + + // Handle cached vs uncached response + if cacheable && cached != nil && err == nil { + rep, err = t.processCachedResponse(cached, req, transport, cacheKey) + } else { + rep, err = t.processUncachedRequest(transport, req) + } + + // Error handling for multiple error phases above. 
+ if err != nil { + return nil, err + } + + // RFC 7234 Section 4.4: Invalidate cache for unsafe methods + // After successful response, invalidate related URIs + if isUnsafeMethod(req.Method) { + t.invalidateCache(req, rep) + } + + // Store response in cache if applicable + t.storeResponseInCache(req, rep, cacheKey, cacheable) + return rep, nil +} + +func (t *Transport) processCachedResponse(cached *http.Response, req *http.Request, transport http.RoundTripper, cacheKey string) (rep *http.Response, err error) { + return nil, nil +} + +func (t *Transport) processUncachedRequest(transport http.RoundTripper, req *http.Request) (rep *http.Response, err error) { + return transport.RoundTrip(req) +} + +func (t *Transport) invalidateCache(req *http.Request, resp *http.Response) {} + +func (t *Transport) storeResponseInCache(req *http.Request, resp *http.Response, cacheKey string, cacheable bool) { +} + +//=========================================================================== +// Top Level Helper Functions +//=========================================================================== + const ( nbsp = ' ' ) @@ -36,10 +195,23 @@ func normalize(value string) string { return result } -//=========================================================================== -// Transport -//=========================================================================== +func allHeaderCSVs(headers http.Header, header string) []string { + var vals []string + for _, val := range headers[http.CanonicalHeaderKey(header)] { + fields := strings.Split(val, ",") + for i, field := range fields { + fields[i] = strings.TrimSpace(field) + } + vals = append(vals, fields...) + } + return vals +} -type Transport struct { - Transport http.RoundTripper +// isUnsafeMethod returns true if the HTTP method is considered unsafe. +// RFC 7234 Section 4.4: POST, PUT, DELETE, PATCH are unsafe methods. 
+func isUnsafeMethod(method string) bool { + return method == http.MethodPost || + method == http.MethodPut || + method == http.MethodDelete || + method == http.MethodPatch } diff --git a/httpcache_test.go b/httpcache_test.go index 4d9e9a4..2856908 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -62,3 +62,103 @@ func TestNormalize(t *testing.T) { require.Equal(t, test.expected, result) } } + +func TestAllHeaderCSVs(t *testing.T) { + tests := []struct { + name string + headers http.Header + header string + expected []string + }{ + { + name: "Single header with single value", + headers: http.Header{ + "Accept": []string{"text/html"}, + }, + header: "Accept", + expected: []string{"text/html"}, + }, + { + name: "Single header with single CSV", + headers: http.Header{ + "Accept": []string{"text/html, application/json"}, + }, + header: "Accept", + expected: []string{"text/html", "application/json"}, + }, + { + name: "Multiple headers with single values", + headers: http.Header{ + "Accept": []string{"text/html", "application/xml"}, + }, + header: "Accept", + expected: []string{"text/html", "application/xml"}, + }, + { + name: "Multiple Headers with multiple CSVs", + headers: http.Header{ + "Accept": []string{"text/html, application/json", "application/xml, text/plain, application/yaml"}, + }, + header: "Accept", + expected: []string{"text/html", "application/json", "application/xml", "text/plain", "application/yaml"}, + }, + { + name: "Header not present", + headers: http.Header{ + "Accept-Language": []string{"en,fr,de"}, + }, + header: "Accept", + expected: nil, + }, + { + name: "Canonicalize header name", + headers: http.Header{ + "Accept-Language": []string{"en,fr,de"}, + }, + header: "accept-language", + expected: []string{"en", "fr", "de"}, + }, + { + name: "Trim spaces in CSV values", + headers: http.Header{ + "Accept-Language": []string{"en,fr, de , es "}, + }, + header: "accept-language", + expected: []string{"en", "fr", "de", "es"}, + }, + { + name: 
"Case Sensitive CSV values", + headers: http.Header{ + "Accept-Language": []string{"en-US,fr-AG,es-MX"}, + }, + header: "accept-language", + expected: []string{"en-US", "fr-AG", "es-MX"}, + }, + } + + for _, test := range tests { + result := httpcache.AllHeaderCSVs(test.headers, test.header) + require.Equal(t, test.expected, result, "Failed test: %s", test.name) + } +} + +func TestIsUnsafeMethod(t *testing.T) { + tests := []struct { + method string + require require.BoolAssertionFunc + }{ + {"GET", require.False}, + {"HEAD", require.False}, + {"OPTIONS", require.False}, + {"POST", require.True}, + {"PUT", require.True}, + {"DELETE", require.True}, + {"PATCH", require.True}, + {"TRACE", require.False}, + } + + for _, test := range tests { + result := httpcache.IsUnsafeMethod(test.method) + test.require(t, result, "Failed for method: %q", test.method) + } +} From 1283c939580ee907119d1c9baa1de65b54231d72 Mon Sep 17 00:00:00 2001 From: Benjamin Bengfort Date: Wed, 14 Jan 2026 10:40:53 -0600 Subject: [PATCH 2/3] store in cache with directives --- go.mod | 1 + go.sum | 5 +- httpcache.go | 230 +++++++++++++++++++++++++++++++++++++++++++--- httpcache_test.go | 16 ++++ io.go | 38 ++++++++ io_test.go | 68 ++++++++++++++ 6 files changed, 342 insertions(+), 16 deletions(-) create mode 100644 io.go create mode 100644 io_test.go diff --git a/go.mod b/go.mod index 6680595..ad1e8e2 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/dgraph-io/ristretto/v2 v2.3.0 github.com/stretchr/testify v1.11.1 github.com/syndtr/goleveldb v1.0.0 + go.rtnl.ai/x v1.9.0 ) require ( diff --git a/go.sum b/go.sum index 68510c2..d8dfe04 100644 --- a/go.sum +++ b/go.sum @@ -25,14 +25,17 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= github.com/syndtr/goleveldb v1.0.0/go.mod 
h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= +go.rtnl.ai/x v1.9.0 h1:5M/1fLbVw0mxgnxhB5SyZyv7siyI+Hq2cGRPuMjwnTc= +go.rtnl.ai/x v1.9.0/go.mod h1:ciQ9PaXDtZDznzBrGDBV2yTElKX3aJgtQfi6V8613bo= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= diff --git a/httpcache.go b/httpcache.go index 38489d7..16cf154 100644 --- a/httpcache.go +++ b/httpcache.go @@ -1,9 +1,38 @@ package httpcache import ( + "io" + "log/slog" "net/http" + "net/http/httputil" "strings" "time" + + "go.rtnl.ai/x/httpcc" +) + +// RFC 9111 Section 5.2.2.3: HTTP status codes that are understood by this cache. +// When must-understand directive is present, only responses with these status codes +// can be cached, even if other cache directives would normally prevent caching. 
+var understoodStatusCodes = map[int]bool{ + http.StatusOK: true, // 200 + http.StatusNonAuthoritativeInfo: true, // 203 + http.StatusNoContent: true, // 204 + http.StatusPartialContent: true, // 206 + http.StatusMultipleChoices: true, // 300 + http.StatusMovedPermanently: true, // 301 + http.StatusNotFound: true, // 404 + http.StatusMethodNotAllowed: true, // 405 + http.StatusGone: true, // 410 + http.StatusRequestURITooLong: true, // 414 + http.StatusNotImplemented: true, // 501 +} + +const ( + Vary = "Vary" + Range = "Range" + Authorization = "Authorization" + XVariedPrefix = "X-Varied-" ) //=========================================================================== @@ -11,7 +40,7 @@ import ( //=========================================================================== // Transport is an implementation of http.RoundTripper that will return values from a -// cache here possible (avoiding a network request) and will additionally add +// cache where possible (avoiding a network request) and will additionally add // request cache headers (etag/if-modified-since) to repeated requests allowing servers // to return cache-oriented responses such as 304 Not Modified. type Transport struct { @@ -70,7 +99,7 @@ func NewTransport(cache Cache) *Transport { } } -// Client returns a new http.Client that caches respones. +// Client returns a new http.Client that caches responses. func (t *Transport) Client() *http.Client { return &http.Client{Transport: t} } @@ -87,7 +116,7 @@ func (t *Transport) Client() *http.Client { // RoundTrip implements the http.RoundTripper interface for Transport. 
func (t *Transport) RoundTrip(req *http.Request) (rep *http.Response, err error) { cacheKey := cacheKeyWithHeaders(req, t.CacheKeyHeaders) - cacheable := (req.Method == http.MethodGet || req.Method == http.MethodHead) && req.Header.Get("Range") == "" + cacheable := (req.Method == http.MethodGet || req.Method == http.MethodHead) && req.Header.Get(Range) == "" var cached *http.Response if cacheable { @@ -98,14 +127,14 @@ func (t *Transport) RoundTrip(req *http.Request) (rep *http.Response, err error) // headers, recalculate the cache key with vary values and try again for the // correct variant. if !t.DisableVaryHeaderSeparation && cached != nil && err == nil { - if varyHeaders := allHeaderCSVs(cached.Header, "Vary"); len(varyHeaders) > 0 { + if varyHeaders := allHeaderCSVs(cached.Header, Vary); len(varyHeaders) > 0 { varyKey := cacheKeyWithVary(req, varyHeaders) if varyKey != cacheKey { // Try with vary-specific key - cachedVery, varyErr := cachedResponse(t.Cache, varyKey, req) - if varyErr == nil && cachedVery != nil { - // Found a variant match - cached = cachedVery + cachedVary, varyErr := cachedResponse(t.Cache, varyKey, req) + if varyErr == nil && cachedVary != nil { + // Found a variant match use it as the cached value. 
+ cached = cachedVary cacheKey = varyKey } } @@ -143,7 +172,7 @@ func (t *Transport) RoundTrip(req *http.Request) (rep *http.Response, err error) } // Store response in cache if applicable - t.storeResponseInCache(req, rep, cacheKey, cacheable) + t.cacheResponse(req, rep, cacheKey, cacheable) return rep, nil } @@ -155,9 +184,165 @@ func (t *Transport) processUncachedRequest(transport http.RoundTripper, req *htt return transport.RoundTrip(req) } -func (t *Transport) invalidateCache(req *http.Request, resp *http.Response) {} +func (t *Transport) invalidateCache(req *http.Request, rep *http.Response) {} + +func (t *Transport) cacheResponse(req *http.Request, rep *http.Response, cacheKey string, cacheable bool) { + var ( + err error + repcc *httpcc.ResponseDirective + reqcc *httpcc.RequestDirective + ) + + if repcc, err = httpcc.Response(rep); err != nil { + GetLogger().Warn("could not parse response cache-control directives", slog.Any("error", err)) + repcc = &httpcc.ResponseDirective{} + } + + if reqcc, err = httpcc.Request(req); err != nil { + GetLogger().Warn("could not parse request cache-control directives", slog.Any("error", err)) + reqcc = &httpcc.RequestDirective{} + } + + if !cacheable || !t.canStore(rep, req, reqcc, repcc) { + t.Cache.Del(cacheKey) + return + } + + // RFC 9111 Section 5.2.2.3: must-understand directive + // When must-understand is present and status code is understood, always cache + mustUnderstandAllowsCaching := repcc.MustUnderstand() && understoodStatusCodes[rep.StatusCode] + shouldCache := understoodStatusCodes[rep.StatusCode] || mustUnderstandAllowsCaching + + // Allow custom override via ShouldCache user supplied hook. + if !shouldCache && t.ShouldCache != nil { + shouldCache = t.ShouldCache(rep) + } + + if !shouldCache { + t.Cache.Del(cacheKey) + return + } + + // Store all Vary headers in the response for future cache validation. + storeVaryHeaders(rep, req) + + // Determine the cache key(s) to use. 
+ cacheKeys := []string{cacheKey} + + // Handle Vary headers to store the body across multiple keys. + varyHeaders := allHeaderCSVs(rep.Header, Vary) + if !t.DisableVaryHeaderSeparation && len(varyHeaders) > 0 { + // Store the full response under both the variant key and the base key + // so that RoundTrip can read the base entry (to discover Vary) and then + // re-lookup the variant-specific entry. + if varyKey := cacheKeyWithVary(req, varyHeaders); varyKey != cacheKey { + cacheKeys = append(cacheKeys, varyKey) + } + } + + // Finally store the response in the cache! + // If the request is a GET request, add a CachingReadCloser to the response body + // so that the response is only stored once the body has been fully read; otherwise + // store the response immediately. + // + // NOTE: this means that non-GET requests will not be able to read the body! + // If this is required, we need to add a setting to the Transport to handle this. + if req.Method == http.MethodGet { + // Setup caching on body read completion + t.setupCachingBody(rep, cacheKeys...) + } else { + // Store immediately for non-GET requests + t.store(rep, cacheKeys...) + } +} + +// Stores a response in the cache dumping it as bytes. +// NOTE: expects cache headers to already be set on the response. +func (t *Transport) store(rep *http.Response, cacheKeys ...string) { + if t.Cache == nil { + GetLogger().Error("cannot store response in cache: no cache configured") + return + } + + // Read the response body only once and store it in the cache for all provided keys. + if data, err := httputil.DumpResponse(rep, true); err == nil { + for _, cacheKey := range cacheKeys { + t.Cache.Put(cacheKey, data) + } + } else { + GetLogger().Error("could not dump response for caching", slog.Any("error", err)) + } +} + +// Wraps the response body to be cached when it is fully read. 
+func (t *Transport) setupCachingBody(rep *http.Response, cacheKeys ...string) { + rep.Body = &CachingReadCloser{ + R: rep.Body, + OnEOF: func(r io.Reader) { + // Convert the response body to the cached response by copying the reply + rep := *rep + rep.Body = io.NopCloser(r) + + // Store the response in the cache + t.store(&rep, cacheKeys...) + }, + } +} + +// Determines if a response can be stored in the cache based on request/response +// cache-control directives and status code. This method implements RFC 9111 compliance. +// RFC 9111 Section 3: Storing Responses in Caches +// RFC 9111 Section 5.2.2.3: must-understand directive +// RFC 9111 Section 3.5: Storing Responses to Authenticated Requests +func (t *Transport) canStore(rep *http.Response, req *http.Request, reqcc *httpcc.RequestDirective, repcc *httpcc.ResponseDirective) bool { + // RFC 9111 Section 5.2.2.3: must-understand directive + // When must-understand is present, the cache can only store the response if: + // 1. The status code is understood by the cache, AND + // 2. All other cache directives are comprehended + // + // If must-understand is present and the status code is not understood, + // the cache MUST NOT store the response, even if other directives would permit it. + // + // If must-understand is present and the status code IS understood, + // then no-store is effectively ignored (the response can be cached). + if repcc.MustUnderstand() { + if !understoodStatusCodes[rep.StatusCode] { + // Status code not understood → must not cache. + return false + } + // Status code understood → proceed (overrides no-store). + } else { + // Normal behavior when must-understand is not present. 
+ if repcc.NoStore() || reqcc.NoStore() { + return false + } + } + + // RFC 9111 Section 3.5: Storing Responses to Authenticated Requests + // A shared cache MUST NOT use a cached response to a request with an Authorization + // header field unless the response contains a Cache-Control field with the "public", + // "must-revalidate", or "s-maxage" response directive. + if t.IsPublicCache && req.Header.Get(Authorization) != "" { + _, hasSMaxAge := repcc.SMaxAge() + if !repcc.Public() && !repcc.MustRevalidate() && !hasSMaxAge { + GetLogger().Debug( + "refusing to cache Authorization request in shared cache", + slog.String("url", req.URL.String()), + slog.String("reason", "no public/must-revalidate/s-maxage directive"), + slog.Bool("public", repcc.Public()), + slog.Bool("must-revalidate", repcc.MustRevalidate()), + slog.Bool("has-s-maxage", hasSMaxAge), + ) + return false + } + } -func (t *Transport) storeResponseInCache(req *http.Request, resp *http.Response, cacheKey string, cacheable bool) { + // RFC 9111: Check Cache-Control: private directive + if repcc.Private() && t.IsPublicCache { + return false + } + + return true } //=========================================================================== @@ -198,11 +383,11 @@ func normalize(value string) string { func allHeaderCSVs(headers http.Header, header string) []string { var vals []string for _, val := range headers[http.CanonicalHeaderKey(header)] { - fields := strings.Split(val, ",") - for i, field := range fields { - fields[i] = strings.TrimSpace(field) + for _, field := range strings.Split(val, ",") { + if s := strings.TrimSpace(field); s != "" { + vals = append(vals, s) + } } - vals = append(vals, fields...) } return vals } @@ -215,3 +400,18 @@ func isUnsafeMethod(method string) bool { method == http.MethodDelete || method == http.MethodPatch } + +// Stores the Vary header values in the response for future cache validation. 
+// RFC 9111 Section 4.1: Values are normalized before storage to enable proper matching. +func storeVaryHeaders(rep *http.Response, req *http.Request) { + for _, varyKey := range allHeaderCSVs(rep.Header, Vary) { + varyKey := http.CanonicalHeaderKey(strings.TrimSpace(varyKey)) + if varyKey == "" || varyKey == "*" { + continue + } + + value := normalize(req.Header.Get(varyKey)) + cacheHeader := XVariedPrefix + varyKey + rep.Header.Set(cacheHeader, value) + } +} diff --git a/httpcache_test.go b/httpcache_test.go index 2856908..3034f5a 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -134,6 +134,22 @@ func TestAllHeaderCSVs(t *testing.T) { header: "accept-language", expected: []string{"en-US", "fr-AG", "es-MX"}, }, + { + name: "Filter empty values", + headers: http.Header{ + "Accept-Language": []string{"en,,fr, , es, "}, + }, + header: "accept-language", + expected: []string{"en", "fr", "es"}, + }, + { + name: "Trailing comma", + headers: http.Header{ + "Accept-Language": []string{"en,"}, + }, + header: "accept-language", + expected: []string{"en"}, + }, } for _, test := range tests { diff --git a/io.go b/io.go new file mode 100644 index 0000000..8bdf968 --- /dev/null +++ b/io.go @@ -0,0 +1,38 @@ +package httpcache + +import ( + "bytes" + "io" +) + +// CachingReadCloser wraps a ReadCloser R that calls the OnEOF handler with a full copy +// of the content read from R when EOF is reached. While this does cause the full read +// to be stored in memory, it allows the Transport to cache the response body only once +// it has been fully read. +type CachingReadCloser struct { + // Wrapped ReadCloser (response.Body) + R io.ReadCloser + + // Called when EOF is reached with a reader containing the full content read. + OnEOF func(r io.Reader) + + // Internal buffer to store a copy of the content of R. 
+ buf bytes.Buffer +} + +func (r *CachingReadCloser) Read(p []byte) (n int, err error) { + n, err = r.R.Read(p) + r.buf.Write(p[:n]) + + if err == io.EOF { // a short read (n < len(p)) is not EOF; fire only once the body is complete + if r.OnEOF != nil { + r.OnEOF(bytes.NewReader(r.buf.Bytes())) + } + } + + return n, err +} + +func (r *CachingReadCloser) Close() error { + return r.R.Close() +} diff --git a/io_test.go b/io_test.go new file mode 100644 index 0000000..6388ed3 --- /dev/null +++ b/io_test.go @@ -0,0 +1,68 @@ +package httpcache_test + +import ( + "io" + "testing" + + "github.com/stretchr/testify/require" + "go.rtnl.ai/httpcache" +) + +func TestCachingReadCloser(t *testing.T) { + mock := &MockBody{ + data: []byte("Hello, World!"), + } + + r := &httpcache.CachingReadCloser{ + R: mock, + OnEOF: func(r io.Reader) { + // Read the full content from the reader passed to OnEOF + data, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, []byte("Hello, World!"), data, "OnEOF should receive the full content") + }, + } + + read, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, []byte("Hello, World!"), read, "ReadAll should return the full content") + + require.NoError(t, r.Close(), "Close should not return an error") + + mock.RequireClosed(t) + mock.RequireEmpty(t) +} + +type MockBody struct { + data []byte + readErr error + closed bool + closeErr error +} + +func (mb *MockBody) Read(p []byte) (n int, err error) { + if mb.readErr != nil { + return 0, mb.readErr + } + + if len(mb.data) == 0 { + return 0, io.EOF + } + + n = copy(p, mb.data) + mb.data = mb.data[n:] + return n, nil +} + +func (mb *MockBody) Close() error { + mb.closed = true + return mb.closeErr +} + +func (mb *MockBody) RequireClosed(t *testing.T) { + require.True(t, mb.closed, "expected body to be closed") +} + +func (mb *MockBody) RequireEmpty(t *testing.T) { + require.Empty(t, mb.data, "expected body to be empty") +} From d8f655d7dd9a0c96b71625213e19ad0dbd6f8027 Mon Sep 17 00:00:00 2001 From: Benjamin Bengfort Date: Wed, 14 Jan 
2026 11:09:39 -0600 Subject: [PATCH 3/3] invalidate cache --- export_test.go | 2 + httpcache.go | 121 ++++++++++++++++++++++++++++++++++++++++++++-- httpcache_test.go | 47 ++++++++++++++++++ logger.go | 11 ++++- 4 files changed, 175 insertions(+), 6 deletions(-) diff --git a/export_test.go b/export_test.go index 020d3e0..ea81565 100644 --- a/export_test.go +++ b/export_test.go @@ -11,4 +11,6 @@ var ( CachedResponseWithKey = cachedResponse AllHeaderCSVs = allHeaderCSVs IsUnsafeMethod = isUnsafeMethod + IsSameOrigin = isSameOrigin + GetOrigin = getOrigin ) diff --git a/httpcache.go b/httpcache.go index 16cf154..ddf0ed5 100644 --- a/httpcache.go +++ b/httpcache.go @@ -5,6 +5,7 @@ import ( "log/slog" "net/http" "net/http/httputil" + "net/url" "strings" "time" @@ -29,10 +30,12 @@ var understoodStatusCodes = map[int]bool{ } const ( - Vary = "Vary" - Range = "Range" - Authorization = "Authorization" - XVariedPrefix = "X-Varied-" + Vary = "Vary" + Range = "Range" + Authorization = "Authorization" + Location = "Location" + ContentLocation = "Content-Location" + XVariedPrefix = "X-Varied-" ) //=========================================================================== @@ -184,7 +187,103 @@ func (t *Transport) processUncachedRequest(transport http.RoundTripper, req *htt return transport.RoundTrip(req) } -func (t *Transport) invalidateCache(req *http.Request, rep *http.Response) {} +//=========================================================================== +// Cache Invalidation +//=========================================================================== + +// Invalidates cache entries per RFC 9111 Section 4.4 +// When receiving a non-error response to an unsafe method, invalidate: +// 1. The effective Request-URI +// 2. URIs in Location and Content-Location response headers (if present and same-origin) +// +// RFC 9111 restricts invalidation to same-origin URIs for security. 
+func (t *Transport) invalidateCache(req *http.Request, rep *http.Response) {
+	// RFC 9111 Section 4.4: Only invalidate on non-error responses
+	if rep.StatusCode >= 400 {
+		GetLogger().Debug(
+			"skipping cache invalidation for error response",
+			slog.String("url", req.URL.String()),
+			slog.Int("status", rep.StatusCode),
+		)
+		return
+	}
+
+	// Always invalidate the Request-URI
+	t.invalidateURI(req.URL, GetLogger().With(slog.String("source", "request-uri")))
+
+	// Invalidate Location header URI (RFC 9111 Section 4.4)
+	if location := rep.Header.Get(Location); location != "" {
+		if err := t.invalidateHeaderURI(req.URL, location, Location); err != nil {
+			GetLogger().Debug(
+				"could not invalidate Location header URI",
+				slog.String("location", location),
+				slog.Any("error", err),
+			)
+		}
+	}
+
+	// Invalidate Content-Location header URI (RFC 9111 Section 4.4)
+	if location := rep.Header.Get(ContentLocation); location != "" {
+		if err := t.invalidateHeaderURI(req.URL, location, ContentLocation); err != nil {
+			GetLogger().Debug(
+				"could not invalidate Content-Location header URI",
+				slog.String("content-location", location),
+				slog.Any("error", err),
+			)
+		}
+	}
+}
+
+func (t *Transport) invalidateURI(url *url.URL, logger *slog.Logger) {
+	if logger == nil {
+		logger = GetLogger()
+	}
+
+	// Invalidate GET request for this URL
+	getKey := cacheKey(&http.Request{
+		Method: http.MethodGet,
+		URL:    url,
+	})
+	t.Cache.Del(getKey)
+	logger.Debug("invalidated cache entry", slog.String("key", getKey), slog.String("url", url.String()))
+
+	// Also invalidate HEAD request if different key
+	headKey := cacheKey(&http.Request{
+		Method: http.MethodHead,
+		URL:    url,
+	})
+	if headKey != getKey {
+		t.Cache.Del(headKey)
+		logger.Debug("invalidated cache entry", slog.String("key", headKey), slog.String("url", url.String()))
+	}
+}
+
+func (t *Transport) invalidateHeaderURI(requestURL *url.URL, value string, header string) (err error) {
+	// Parse the header value as a URI (may be relative or absolute)
+	var targetURL *url.URL
+	if targetURL, err = requestURL.Parse(value); err != nil {
+		return err
+	}
+
+	// RFC 9111 Section 4.4: Only invalidate same-origin URIs
+	// Origin = scheme + host (host includes port if present)
+	if !isSameOrigin(requestURL, targetURL) {
+		GetLogger().Debug(
+			"skipping cache invalidation for cross-origin URI",
+			slog.String("header", header),
+			slog.String("request-origin", getOrigin(requestURL)),
+			slog.String("target-origin", getOrigin(targetURL)),
+		)
+		return nil
+	}
+
+	t.invalidateURI(targetURL, GetLogger().With(slog.String("source", header)))
+	return nil
+}
+
+//===========================================================================
+// Response Caching
+//===========================================================================
 
 func (t *Transport) cacheResponse(req *http.Request, rep *http.Response, cacheKey string, cacheable bool) {
 	var (
@@ -415,3 +514,15 @@ func storeVaryHeaders(rep *http.Response, req *http.Request) {
 		rep.Header.Set(cacheHeader, value)
 	}
 }
+
+// Checks if two URLs have the same origin: scheme + host (host includes port if
+// present). Both are compared case-insensitively (RFC 9110 Section 4.2.3).
+func isSameOrigin(u1, u2 *url.URL) bool {
+	return strings.EqualFold(u1.Scheme, u2.Scheme) && strings.EqualFold(u1.Host, u2.Host)
+}
+
+// Returns the origin string for a URL (scheme://host).
+// Only used for debugging and logging.
+func getOrigin(u *url.URL) string {
+	return u.Scheme + "://" + u.Host
+}
diff --git a/httpcache_test.go b/httpcache_test.go
index 3034f5a..748e3b1 100644
--- a/httpcache_test.go
+++ b/httpcache_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"io"
 	"net/http"
+	"net/url"
 	"testing"
 
 	"github.com/stretchr/testify/require"
@@ -178,3 +179,50 @@ func TestIsUnsafeMethod(t *testing.T) {
 		test.require(t, result, "Failed for method: %q", test.method)
 	}
 }
+
+func TestIsSameOrigin(t *testing.T) {
+	tests := []struct {
+		url1    string
+		url2    string
+		require require.BoolAssertionFunc
+	}{
+		{"http://example.com/resource", "http://example.com/other", require.True},
+		{"https://example.com/resource", "https://example.com/other", require.True},
+		{"http://example.com/resource", "https://example.com/other", require.False},
+		{"http://example.com/resource", "http://other.com/resource", require.False},
+		{"http://example.com/resource", "http://example.com:8080/resource", require.False},
+		{"http://sub.example.com/resource", "http://example.com/resource", require.False},
+		{"http://EXAMPLE.com/resource", "http://example.com/other", require.True},
+	}
+
+	for _, test := range tests {
+		u1, err := url.Parse(test.url1)
+		require.NoError(t, err, "Failed to parse url1: %q", test.url1)
+
+		u2, err := url.Parse(test.url2)
+		require.NoError(t, err, "Failed to parse url2: %q", test.url2)
+
+		result := httpcache.IsSameOrigin(u1, u2)
+		test.require(t, result, "Failed for URLs: %q and %q", test.url1, test.url2)
+	}
+}
+
+func TestGetOrigin(t *testing.T) {
+	tests := []struct {
+		input    string
+		expected string
+	}{
+		{"http://example.com/resource", "http://example.com"},
+		{"https://example.com:8080/resource", "https://example.com:8080"},
+		{"ftp://ftp.example.com/files", "ftp://ftp.example.com"},
+		{"http://localhost:3000/path", "http://localhost:3000"},
+	}
+
+	for _, test := range tests {
+		u, err := url.Parse(test.input)
+		require.NoError(t, err, "Failed to parse input URL: %q", test.input)
+
+		origin := httpcache.GetOrigin(u)
+		require.Equal(t, test.expected, origin, "Failed for input URL: %q", test.input)
+	}
+}
diff --git a/logger.go b/logger.go
index 695eadc..d05b1d4 100644
--- a/logger.go
+++ b/logger.go
@@ -13,8 +13,17 @@ var (
 // SetLogger sets a custom slog.Logger instance to be used by httpcache. If not set,
 // the default slog logger will be used. Rotational apps should implement a zerolog
 // slogger for observability.
+//
+// To stop logging either pass slog.New(slog.DiscardHandler) or nil (which will
+// create a discard handler).
 func SetLogger(l *slog.Logger) {
-	logger = l
+	if l != nil {
+		logger = l
+		return
+	}
+
+	// Set a discard handler if nil is provided.
+	logger = slog.New(slog.DiscardHandler)
 }
 
 // GetLogger returns the configured logger or the default slog logger.