diff --git a/README.md b/README.md index 5b2b7c1..d8dd36c 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,8 @@ go get -u github.com/alexmerren/httpcache Here's an example of using the `httpcache` module to cache responses: + + ```go func main() { // Create a new SQLite database to store HTTP responses. @@ -69,3 +71,20 @@ func main() { ## ❓ Questions and Support Any questions can be submitted via [GitHub Issues](https://www.github.com/alexmerren/httpcache/issues). Feel free to start contributing or asking any questions required! + +## Roadmap + +- Implement cache interface in separate cache modules: + - [x] Sqlite + - [ ] Duckdb + - [ ] Redis +- Make library thread safe to access cache. +- Implement other features from https://pypi.org/project/requests-cache/ + - [ ] Cache-control header for expiration time on individual records. + - [ ] Match headers to save different response on header value. + - [ ] Stale-if-error to use stale response if request errors out. + - [ ] (maybe) ignored parameters so don't save API keys? +- Write full test suite! +- Write benchmark tests +- Write documentation and contribution guide + diff --git a/cache.go b/cache.go index 3b700f6..0cbf381 100644 --- a/cache.go +++ b/cache.go @@ -12,24 +12,34 @@ import ( type Cache interface { // Save a response for a HTTP request using [context.Background]. - Save(response *http.Response, expiryTime *time.Duration) error + Save(response *http.Response) error // Read a saved response for a HTTP request using [context.Background]. - Read(request *http.Request) (*http.Response, error) + Read(request *http.Request) (*ReadResult, error) - // Save a response for a HTTP request with a [context.Context]. expiryTime - // is the duration from [time.Now] to expire the response. - SaveContext(ctx context.Context, response *http.Response, expiryTime *time.Duration) error + // Delete a saved response using [context.Background]. + Delete(response *http.Response) error + + // Save a response for a HTTP request with a [context.Context]. + SaveContext(ctx context.Context, response *http.Response) error // Read a saved response for a HTTP request with a [context.Context]. If no - // response is saved for the corresponding request, or the expiryTime has - // been surpassed, then return [ErrNoResponse]. - ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) + // response is saved for the corresponding request. + ReadContext(ctx context.Context, request *http.Request) (*ReadResult, error) + + // Delete a saved response with a [context.Context]. + DeleteContext(ctx context.Context, response *http.Response) error +} + +// ReadResult is the result of a successful read operation on the cache. +type ReadResult struct { + response *http.Response + createdAt *time.Time } var ( - // ErrNoResponse describes when the cache does not have a response stored. + // ErrNoResult describes when the cache does not have a response stored. // [Transport] will check if ErrNoResponse is returned from [Cache.Read]. If // ErrNoResponse is returned, then the request/response will be saved with [Save]. - ErrNoResponse = errors.New("no stored response") + ErrNoResult = errors.New("found no result from read") ) diff --git a/caches/duckdb/duckdb.go b/caches/duckdb/duckdb.go new file mode 100644 index 0000000..67a0dca --- /dev/null +++ b/caches/duckdb/duckdb.go @@ -0,0 +1,7 @@ +package duckdb + +import "database/sql" + +type DuckdbCache struct { + Database *sql.DB +} diff --git a/caches/duckdb/go.mod b/caches/duckdb/go.mod new file mode 100644 index 0000000..708a8e8 --- /dev/null +++ b/caches/duckdb/go.mod @@ -0,0 +1,3 @@ +module github.com/alexmerren/httpcache/caches/duckdb + +go 1.22.3 diff --git a/caches/sqlite/go.mod b/caches/sqlite/go.mod new file mode 100644 index 0000000..41b6822 --- /dev/null +++ b/caches/sqlite/go.mod @@ -0,0 +1,16 @@ +module github.com/alexmerren/httpcache/caches/sqlite + +go 1.22.3 + +require ( + github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/alexmerren/httpcache v0.4.0 + github.com/mattn/go-sqlite3 v1.14.32 + github.com/stretchr/testify v1.11.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/caches/sqlite/go.sum b/caches/sqlite/go.sum new file mode 100644 index 0000000..6c6db57 --- /dev/null +++ b/caches/sqlite/go.sum @@ -0,0 +1,17 @@ +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/alexmerren/httpcache v0.4.0 h1:NBVrDmJW7QFJ7xP2Qf9gV6+uiXVvVDBiGiCkcdbMbkU= +github.com/alexmerren/httpcache v0.4.0/go.mod h1:v8A6Vrn/8alCPq6577Ti5q/WerAQthkysMaPb9SJPsc= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/sqlite_cache.go b/caches/sqlite/sqlite.go similarity index 69% rename from sqlite_cache.go rename to caches/sqlite/sqlite.go index a1df3cf..569d34f 100644 --- a/sqlite_cache.go +++ b/caches/sqlite/sqlite.go @@ -1,4 +1,4 @@ -package httpcache +package sqlite import ( "bytes" @@ -11,6 +11,7 @@ import ( "path/filepath" "time" + "github.com/alexmerren/httpcache" _ "github.com/mattn/go-sqlite3" ) @@ -21,7 +22,7 @@ const ( request_method TEXT NOT NULL, response_body BLOB NOT NULL, status_code INTEGER NOT NULL, - expiry_time INTEGER)` + created_at INTEGER NOT NULL)` saveRequestQuery = ` INSERT OR REPLACE INTO responses ( @@ -29,11 +30,11 @@ const ( request_method, response_body, status_code, - expiry_time) + created_at) VALUES (?, ?, ?, ?, ?)` readRequestQuery = ` - SELECT response_body, status_code, expiry_time + SELECT response_body, status_code, created_at FROM responses WHERE request_url = ? AND request_method = ?` ) @@ -42,16 +43,16 @@ const ( // constructed correctly. var ErrNoDatabase = errors.New("no database connection") -// SqliteCache is a default implementation of [Cache] which creates a local -// SQLite cache to persist and to query HTTP responses. -type SqliteCache struct { +// SqliteDriver is a default implementation of [Cache] which creates a local +// SQLite driver to persist and to query HTTP responses. +type SqliteDriver struct { Database *sql.DB } -// NewSqliteCache creates a new SQLite database with a certain name. This name +// NewSqliteDriver creates a new SQLite database with a certain name. This name // is the filename of the database. If the file does not exist, then we create // it. If the file is in a non-existent directory, we create the directory. -func NewSqliteCache(databaseName string) (*SqliteCache, error) { +func NewSqliteDriver(databaseName string) (*SqliteDriver, error) { fileExists, err := doesFileExist(databaseName) if err != nil { return nil, err @@ -74,36 +75,29 @@ func NewSqliteCache(databaseName string) (*SqliteCache, error) { return nil, err } - return &SqliteCache{ + return &SqliteDriver{ Database: conn, }, nil } -func (s *SqliteCache) Save(response *http.Response, expiryTime *time.Duration) error { - return s.SaveContext(context.Background(), response, expiryTime) +func (s *SqliteDriver) Save(response *http.Response) error { + return s.SaveContext(context.Background(), response) } -func (s *SqliteCache) Read(request *http.Request) (*http.Response, error) { +func (s *SqliteDriver) Read(request *http.Request) (*http.Response, error) { return s.ReadContext(context.Background(), request) } -func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response, expiryTime *time.Duration) error { +func (s *SqliteDriver) SaveContext(ctx context.Context, response *httpcache.ReadResult) error { responseBody := bytes.NewBuffer(nil) _, err := io.Copy(responseBody, response.Body) if err != nil { return err } - var expiryTimestamp *int64 - if expiryTime != nil { - calculatedExpiry := time.Now().Add(*expiryTime).Unix() - expiryTimestamp = &calculatedExpiry - } else { - expiryTimestamp = nil - } - // Reset the response body stream to the beginning to be read again. response.Body = io.NopCloser(responseBody) + now := time.Now().UnixMilli() requestUrl := generateUrl(response.Request) _, err = s.Database.Exec( @@ -112,7 +106,7 @@ func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response, response.Request.Method, responseBody.Bytes(), response.StatusCode, - expiryTimestamp, + now, ) if err != nil { return err @@ -121,7 +115,7 @@ func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response, return nil } -func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) { +func (s *SqliteDriver) ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) { if s.Database == nil { return nil, ErrNoDatabase } @@ -131,20 +125,16 @@ func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (* var responseBody []byte var responseStatusCode int - var expiryTime *int64 + var createdAt *int64 - err := row.Scan(&responseBody, &responseStatusCode, &expiryTime) + err := row.Scan(&responseBody, &responseStatusCode, &createdAt) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return nil, ErrNoResponse + return nil, httpcache.ErrNoResult } return nil, err } - if expiryTime != nil && time.Now().Unix() > *expiryTime { - return nil, ErrNoResponse - } - return &http.Response{ Request: request, Body: io.NopCloser(bytes.NewReader(responseBody)), diff --git a/config.go b/config.go index b54716b..b3164cd 100644 --- a/config.go +++ b/config.go @@ -1,43 +1,27 @@ package httpcache import ( - "net/http" "time" ) -var ( - defaultAllowedStatusCodes = []int{http.StatusOK} - defaultAllowedMethods = []string{http.MethodGet} - defaultExpiryTime = time.Duration(60*24*7) * time.Minute -) - -// DefaultConfig creates a [Config] with default values, namely: -// - AllowedStatusCodes: [http.StatusOK] -// - AllowedMethods: [http.MethodGet] -// - ExpiryTime: 7 days. -var DefaultConfig = NewConfigBuilder(). - WithAllowedStatusCodes(defaultAllowedStatusCodes). - WithAllowedMethods(defaultAllowedMethods). - WithExpiryTime(defaultExpiryTime). - Build() +// DefaultConfig creates a [Config] with a configuration of: +// - Saves responses regardless of response status code, or HTTP request method; +// - Saved responses never expire. +var DefaultConfig = NewConfigBuilder().Build() -// Config describes the configuration to use when saving and reading responses -// from [Cache] using the [Transport]. +// Config describes the configuration to use when saving to, and reading responses +// from the [Cache]. type Config struct { // AllowedStatusCodes describes if a HTTP response should be saved by // checking that it's status code is accepted by [Cache]. If the HTTP // response's status code is not in AllowedStatusCodes, then do not persist. - // - // This is a required field. - AllowedStatusCodes []int + AllowedStatusCodes *[]int // AllowedMethods describes if a HTTP response should be saved by checking // if the HTTP request's method is accepted by the [Cache]. If the HTTP // request's method is not in AllowedMethods, then do not persist. - // - // This is a required field. - AllowedMethods []string + AllowedMethods *[]string // ExpiryTime describes when a HTTP response should be considered invalid. ExpiryTime *time.Duration diff --git a/config_builder.go b/config_builder.go index efa19f9..7268aff 100644 --- a/config_builder.go +++ b/config_builder.go @@ -13,14 +13,10 @@ type configBuilder struct { // allowedStatusCodes only persists HTTP responses that have an appropriate // status code (i.e. 200). - // - // This is a required field. allowedStatusCodes *[]int // allowedMethods only persists HTTP responses that use an appropriate HTTP // method (i.e. "GET"). - // - // This is a required field. allowedMethods *[]string // expiryTime invalidates HTTP responses after a duration has elapsed from @@ -47,12 +43,12 @@ func (c *configBuilder) WithExpiryTime(expiryDuration time.Duration) *configBuil return c } -// Build constructs a [Config] that is ready to be consumed by [Transport]. If +// Build constructs a [Config] that is ready to be consumed by [Cache]. If // the configuration passed by [configBuilder] is invalid, it will panic. func (c *configBuilder) Build() *Config { return &Config{ - AllowedStatusCodes: *c.allowedStatusCodes, - AllowedMethods: *c.allowedMethods, + AllowedStatusCodes: c.allowedStatusCodes, + AllowedMethods: c.allowedMethods, ExpiryTime: c.expiryTime, } } diff --git a/sqlite_cache_test.go b/sqlite_cache_test.go deleted file mode 100644 index 2d96b01..0000000 --- a/sqlite_cache_test.go +++ /dev/null @@ -1,323 +0,0 @@ -package httpcache_test - -import ( - "database/sql" - "errors" - "io" - "net/http" - "net/url" - "strings" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/alexmerren/httpcache" - "github.com/stretchr/testify/assert" -) - -const ( - testHost = "www.test.com" - testPath = "/test-path" - testBody = "this is a test body" -) - -const ( - insertQuery = `INSERT OR REPLACE INTO responses \( request_url, request_method, response_body, status_code, expiry_time\) VALUES \(\?, \?, \?, \?, \?\)` - selectQuery = `SELECT response_body, status_code, expiry_time FROM responses WHERE request_url = \? AND request_method = \?` -) - -func Test_Save_HappyPath(t *testing.T) { - // Given - response := aDummyResponse() - db, mockDatabase, closeFunc := aDatabaseMock(t) - defer closeFunc() - - mockDatabase. - ExpectExec(insertQuery). - WithArgs( - testHost+testPath, - http.MethodGet, - []byte(testBody), - http.StatusOK, - nil, - ). - WillReturnResult(sqlmock.NewResult(1, 1)) - - subject := &httpcache.SqliteCache{Database: db} - - // When - err := subject.Save(response, nil) - - // Then - assert.Nil(t, err) - if err := mockDatabase.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func Test_Save_HappyPathWithExpiryTime(t *testing.T) { - // Given - expiryTimestamp := time.Duration(1) * time.Minute - - response := aDummyResponse() - db, mockDatabase, closeFunc := aDatabaseMock(t) - defer closeFunc() - - mockDatabase. - ExpectExec(insertQuery). - WithArgs( - testHost+testPath, - http.MethodGet, - []byte(testBody), - http.StatusOK, - time.Now().Add(expiryTimestamp).Unix(), - ). - WillReturnResult(sqlmock.NewResult(1, 1)) - - subject := &httpcache.SqliteCache{Database: db} - - // When - err := subject.Save(response, &expiryTimestamp) - - // Then - assert.Nil(t, err) - if err := mockDatabase.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func Test_Save_ExecFails(t *testing.T) { - // Given - expiryTimestamp := time.Duration(1) * time.Minute - - response := aDummyResponse() - db, mockDatabase, closeFunc := aDatabaseMock(t) - defer closeFunc() - - mockDatabase. - ExpectExec(insertQuery). - WithArgs( - testHost+testPath, - http.MethodGet, - []byte(testBody), - http.StatusOK, - time.Now().Add(expiryTimestamp).Unix(), - ). - WillReturnError(errors.New("dummy error")) - - subject := &httpcache.SqliteCache{Database: db} - - // When - err := subject.Save(response, &expiryTimestamp) - - // Then - assert.NotNil(t, err) - if err := mockDatabase.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func Test_Read_HappyPath(t *testing.T) { - // Given - request := aDummyRequest() - db, mockDatabase, closeFunc := aDatabaseMock(t) - defer closeFunc() - - mockRows := sqlmock. - NewRows([]string{"response_body", "status_code", "expiry_time"}). - AddRow([]byte(testBody), 200, nil) - - mockDatabase. - ExpectQuery(selectQuery). - WithArgs(testHost+testPath, http.MethodGet). - WillReturnRows(mockRows) - - subject := &httpcache.SqliteCache{Database: db} - - // When - response, err := subject.Read(request) - - // Then - assert.Nil(t, err) - assert.NotNil(t, response) - - responseBody, err := io.ReadAll(response.Body) - defer response.Body.Close() - assert.Nil(t, err) - assert.Equal(t, string(responseBody), testBody) - - if err := mockDatabase.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func Test_Read_HappyPathWithExpiryTime(t *testing.T) { - // Given - expiryTimestamp := time.Duration(1) * time.Minute - - request := aDummyRequest() - db, mockDatabase, closeFunc := aDatabaseMock(t) - defer closeFunc() - - mockRows := sqlmock. - NewRows([]string{"response_body", "status_code", "expiry_time"}). - AddRow([]byte(testBody), 200, time.Now().Add(expiryTimestamp).Unix()) - - mockDatabase. - ExpectQuery(selectQuery). - WithArgs(testHost+testPath, http.MethodGet). - WillReturnRows(mockRows) - - subject := &httpcache.SqliteCache{Database: db} - - // When - response, err := subject.Read(request) - - // Then - assert.Nil(t, err) - assert.NotNil(t, response) - - responseBody, err := io.ReadAll(response.Body) - defer response.Body.Close() - assert.Nil(t, err) - assert.Equal(t, string(responseBody), testBody) - - if err := mockDatabase.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func Test_Read_ExpiryTimeHasPassed(t *testing.T) { - // Given - expiryTimestamp := -time.Duration(10) * time.Minute - - request := aDummyRequest() - db, mockDatabase, closeFunc := aDatabaseMock(t) - defer closeFunc() - - mockRows := sqlmock. - NewRows([]string{"response_body", "status_code", "expiry_time"}). - AddRow([]byte(testBody), 200, time.Now().Add(expiryTimestamp).Unix()) - - mockDatabase. - ExpectQuery(selectQuery). - WithArgs(testHost+testPath, http.MethodGet). - WillReturnRows(mockRows) - - subject := &httpcache.SqliteCache{Database: db} - - // When - response, err := subject.Read(request) - - // Then - assert.NotNil(t, err) - assert.ErrorIs(t, err, httpcache.ErrNoResponse) - assert.Nil(t, response) - - if err := mockDatabase.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func Test_Read_NoRowsFromScan(t *testing.T) { - // Given - request := aDummyRequest() - db, mockDatabase, closeFunc := aDatabaseMock(t) - defer closeFunc() - - mockDatabase. - ExpectQuery(selectQuery). - WithArgs(testHost+testPath, http.MethodGet). - WillReturnError(sql.ErrNoRows) - - subject := &httpcache.SqliteCache{Database: db} - - // When - response, err := subject.Read(request) - - // Then - assert.NotNil(t, err) - assert.ErrorIs(t, err, httpcache.ErrNoResponse) - assert.Nil(t, response) - - if err := mockDatabase.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func Test_Read_ScanError(t *testing.T) { - // Given - request := aDummyRequest() - db, mockDatabase, closeFunc := aDatabaseMock(t) - defer closeFunc() - - mockDatabase. - ExpectQuery(selectQuery). - WithArgs(testHost+testPath, http.MethodGet). - WillReturnError(sql.ErrConnDone) - - subject := &httpcache.SqliteCache{Database: db} - - // When - response, err := subject.Read(request) - - // Then - assert.NotNil(t, err) - assert.ErrorIs(t, err, sql.ErrConnDone) - assert.Nil(t, response) - - if err := mockDatabase.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func aDatabaseMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock, func() error) { - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("%v", err) - } - - return db, mock, db.Close -} - -func aDummyResponse() *http.Response { - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(testBody)), - Request: &http.Request{ - URL: &url.URL{ - Scheme: "https", - Host: testHost, - Opaque: "", - User: nil, - Path: testPath, - RawPath: "", - OmitHost: false, - ForceQuery: false, - RawQuery: "", - Fragment: "", - RawFragment: "", - }, - Method: http.MethodGet, - }, - } -} - -func aDummyRequest() *http.Request { - return &http.Request{ - URL: &url.URL{ - Scheme: "https", - Host: testHost, - Opaque: "", - User: nil, - Path: testPath, - RawPath: "", - OmitHost: false, - ForceQuery: false, - RawQuery: "", - Fragment: "", - RawFragment: "", - }, - Method: http.MethodGet, - } -} diff --git a/transport.go b/transport.go index a1cc82e..96a6299 100644 --- a/transport.go +++ b/transport.go @@ -3,6 +3,7 @@ package httpcache import ( "errors" "net/http" + "time" ) var ( @@ -27,13 +28,13 @@ type Transport struct { // cache handles persisting HTTP responses. cache Cache - // config describes which HTTP responses to cache and how they are cached. + // config handles logic on saving and reading HTTP responses. config *Config } // NewTransport creates a [Transport]. If the cache is nil, return // [ErrMissingCache]. If the config is nil, return [ErrMissingConfig]. -func NewTransport(config *Config, cache Cache) (*Transport, error) { +func NewTransport(cache Cache, config *Config) (*Transport, error) { if cache == nil { return nil, ErrMissingCache } @@ -50,42 +51,77 @@ func NewTransport(config *Config, cache Cache) (*Transport, error) { // RoundTrip wraps the [http.DefaultTransport] RoundTrip to execute HTTP // requests and persists, if necessary, the responses. If [Cache] returns -// [ErrNoResponse], then execute a HTTP request and persist the response if -// passing the criteria in [Transport.shouldSaveResponse]. +// [ErrNoResponse], then execute a HTTP request and persist the response func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { - response, err := t.cache.Read(request) - if err == nil { - return response, nil - } - - if !errors.Is(err, ErrNoResponse) { + result, err := t.cache.Read(request) + if err != nil && !errors.Is(err, ErrNoResult) { return nil, err } - response, err = t.transport.RoundTrip(request) - if err != nil { - return nil, err + isValidResult := isValidResult(t.config, result) + if isValidResult { + return result.response, nil } - if t.shouldSaveResponse(response.StatusCode, response.Request.Method) { - err = t.cache.Save(response, t.config.ExpiryTime) + resultIsInvalid := result != nil && !isValidResult + if resultIsInvalid { + err := t.cache.Delete(result.response) if err != nil { - response.Body.Close() return nil, err } } + response, err := t.transport.RoundTrip(request) + if err != nil { + return nil, err + } + + if ok := isEligibleToSave(t.config, response); !ok { + return response, nil + } + + err = t.cache.Save(response) + if err != nil { + response.Body.Close() + return nil, err + } + return response, nil } -// shouldSaveResponse is responsible for interpreting configuration values from -// [Config] to determine if a HTTP response should be persisted. Any new values -// added to [Config] can be used here as criteria. -func (t *Transport) shouldSaveResponse(statusCode int, method string) bool { - isAllowedStatusCode := contains(t.config.AllowedStatusCodes, statusCode) - isAllowedMethod := contains(t.config.AllowedMethods, method) +func isValidResult(config *Config, result *ReadResult) bool { + return isAllowedMethod(config, result.response.Request.Method) && + isAllowedStatusCode(config, result.response.StatusCode) && + isNotExpired(config, result.createdAt) +} - return isAllowedStatusCode && isAllowedMethod +func isEligibleToSave(config *Config, response *http.Response) bool { + return isAllowedMethod(config, response.Request.Method) && + isAllowedStatusCode(config, response.StatusCode) +} + +func isNotExpired(config *Config, createdAt *time.Time) bool { + hasExpiryTime := config.ExpiryTime != nil + if !hasExpiryTime { + return true + } + return time.Now().Before(createdAt.Add(*config.ExpiryTime)) +} + +func isAllowedStatusCode(config *Config, statusCode int) bool { + hasAllowedStatusCodes := config.AllowedStatusCodes != nil + if !hasAllowedStatusCodes { + return true + } + return contains(*config.AllowedStatusCodes, statusCode) +} + +func isAllowedMethod(config *Config, method string) bool { + hasAllowedMethods := config.AllowedMethods != nil + if !hasAllowedMethods { + return true + } + return contains(*config.AllowedMethods, method) } func contains[T comparable](slice []T, searchValue T) bool { @@ -94,5 +130,4 @@ func contains[T comparable](slice []T, searchValue T) bool { return true } } - return false } diff --git a/transport_test.go b/transport_test.go deleted file mode 100644 index 96ca46f..0000000 --- a/transport_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package httpcache_test - -import "testing" - -func Test_RoundTrip_HappyPath_ResponseSaved(t *testing.T) { - // given - // when - // then -} - -func Test_RoundTrip_HappyPath_ResponseNotSaved(t *testing.T) { - // given - // when - // then -} - -func Test_RoundTrip_ReadError(t *testing.T) { - // given - // when - // then -} - -func Test_RoundTrip_ResponseHasNotAllowedStatusCode(t *testing.T) { - // given - // when - // then -} - -func Test_RoundTrip_ResponseHasNotAllowedMethod(t *testing.T) { - // given - // when - // then -} - -func Test_RoundTrip_TransportError(t *testing.T) { - // given - // when - // then -} - -func Test_RoundTrip_SaveError(t *testing.T) { - // given - // when - // then -}