From 3f08e1450865344271b472f2a95506d6ef5b13a1 Mon Sep 17 00:00:00 2001 From: Alex Merren Date: Mon, 20 Jan 2025 14:22:50 +0000 Subject: [PATCH 01/12] Fix tests and add default configuration --- store.go => cache.go | 12 ++- config.go | 61 +++++++++++++ file.go | 34 ------- options.go | 45 ---------- round_tripper.go | 119 ------------------------- sqlite.go => sqlite_cache.go | 45 ++++++++-- sqlite_test.go => sqlite_cache_test.go | 4 +- transport.go | 96 ++++++++++++++++++++ 8 files changed, 208 insertions(+), 208 deletions(-) rename store.go => cache.go (78%) create mode 100644 config.go delete mode 100644 file.go delete mode 100644 options.go delete mode 100644 round_tripper.go rename sqlite.go => sqlite_cache.go (76%) rename sqlite_test.go => sqlite_cache_test.go (96%) create mode 100644 transport.go diff --git a/store.go b/cache.go similarity index 78% rename from store.go rename to cache.go index 6143ff7..4ba7505 100644 --- a/store.go +++ b/cache.go @@ -6,13 +6,23 @@ import ( "net/http" ) -type ResponseStorer interface { +// Add doc +type Cache interface { + + // Add doc Save(response *http.Response) error + + // Add doc Read(request *http.Request) (*http.Response, error) + + // Add doc SaveContext(ctx context.Context, response *http.Response) error + + // Add doc ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) } var ( + // Add doc ErrNoResponse = errors.New("no stored response") ) diff --git a/config.go b/config.go new file mode 100644 index 0000000..8c68594 --- /dev/null +++ b/config.go @@ -0,0 +1,61 @@ +package httpcache + +import "net/http" + +// Add doc +type Config struct { + + // Add doc + Cache Cache + + // Add doc + DeniedStatusCodes *[]int + + // Add doc + AllowedStatusCodes *[]int +} + +// Add doc +func NewConfig() *Config { + return &Config{} +} + +// Add doc +func (c *Config) WithDeniedStatusCodes(deniedStatusCodes []int) *Config { + c.DeniedStatusCodes = &deniedStatusCodes + return c +} + +// Add doc +func (c *Config) WithAllowedStatusCodes(allowedStatusCodes []int) *Config { + c.AllowedStatusCodes = &allowedStatusCodes + return c +} + +// Add doc +func (c *Config) WithCache(cache Cache) *Config { + c.Cache = cache + return c +} + +func DefaultConfig() (*Config, error) { + cache, err := NewSqliteCache("httpcache.sqlite") + if err != nil { + return nil, err + } + + defaultConfig := NewConfig(). + WithDeniedStatusCodes([]int{ + http.StatusNotFound, + http.StatusBadRequest, + http.StatusForbidden, + http.StatusUnauthorized, + http.StatusMethodNotAllowed, + }). + WithAllowedStatusCodes([]int{ + http.StatusOK, + }). + WithCache(cache) + + return defaultConfig, nil +} diff --git a/file.go b/file.go deleted file mode 100644 index 4e3a4b9..0000000 --- a/file.go +++ /dev/null @@ -1,34 +0,0 @@ -package httpcache - -import ( - "errors" - "os" - "path/filepath" -) - -func doesFileExist(filename string) (bool, error) { - _, err := os.Stat(filename) - if err != nil && errors.Is(err, os.ErrNotExist) { - return false, nil - } - - if err != nil && !errors.Is(err, os.ErrNotExist) { - return false, err - } - - return true, nil -} - -func createFile(filename string) (*os.File, error) { - directory, err := filepath.Abs(filepath.Dir(filename)) - if err != nil { - return nil, err - } - - err = os.MkdirAll(directory, os.ModePerm) - if err != nil { - return nil, err - } - - return os.Create(filename) -} diff --git a/options.go b/options.go deleted file mode 100644 index ac6bc2d..0000000 --- a/options.go +++ /dev/null @@ -1,45 +0,0 @@ -package httpcache - -import ( - "time" -) - -func WithDeniedStatusCodes(deniedStatusCodes []int) func(*CachedRoundTripper) error { - return func(c *CachedRoundTripper) error { - c.deniedStatusCodes = deniedStatusCodes - return nil - } -} - -func WithAllowedStatusCodes(allowedStatusCodes []int) func(*CachedRoundTripper) error { - return func(c *CachedRoundTripper) error { - c.allowedStatusCodes = allowedStatusCodes - return nil - } -} - -func WithExpiryTime(expiryTime time.Duration) func(*CachedRoundTripper) error { - return func(c *CachedRoundTripper) error { - c.expiryTime = expiryTime - return nil - } -} - -func WithName(name string) func(*CachedRoundTripper) error { - return func(c *CachedRoundTripper) error { - sqliteStore, err := newSqliteResponseStore(name) - if err != nil { - return err - } - c.store = sqliteStore - - return nil - } -} - -func WithCacheStore(store ResponseStorer) func(*CachedRoundTripper) error { - return func(c *CachedRoundTripper) error { - c.store = store - return nil - } -} diff --git a/round_tripper.go b/round_tripper.go deleted file mode 100644 index c221fd7..0000000 --- a/round_tripper.go +++ /dev/null @@ -1,119 +0,0 @@ -package httpcache - -import ( - "bytes" - "errors" - "io" - "net/http" - "time" -) - -var ( - defaultDeniedStatusCodes = []int{ - http.StatusNotFound, - http.StatusBadRequest, - http.StatusForbidden, - http.StatusUnauthorized, - http.StatusMethodNotAllowed, - } - - defaultAllowedStatusCodes = []int{ - http.StatusOK, - } - - defaultExpiryTime = time.Duration(0) - - defaultCacheName = "httpcache.sqlite" -) - -type CachedRoundTripper struct { - transport http.RoundTripper - store ResponseStorer - expiryTime time.Duration - deniedStatusCodes []int - allowedStatusCodes []int -} - -func NewCachedRoundTripper(options ...func(*CachedRoundTripper) error) (*CachedRoundTripper, error) { - roundTripper := &CachedRoundTripper{ - transport: http.DefaultTransport, - deniedStatusCodes: defaultDeniedStatusCodes, - allowedStatusCodes: defaultAllowedStatusCodes, - expiryTime: defaultExpiryTime, - } - - for _, optionFunc := range options { - err := optionFunc(roundTripper) - if err != nil { - return nil, err - } - } - - if roundTripper.store != nil { - return roundTripper, nil - } - - store, err := newSqliteResponseStore(defaultCacheName) - if err != nil { - return nil, err - } - roundTripper.store = store - - return roundTripper, nil -} - -func (h *CachedRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { - response, err := h.store.Read(request) - if err == nil { - return response, nil - } - - if !errors.Is(err, ErrNoResponse) { - return nil, err - } - - requestBody := []byte{} - if request.GetBody != nil { - body, err := request.GetBody() - if err != nil { - return nil, err - } - requestBody, err = io.ReadAll(body) - if err != nil { - return nil, err - } - defer body.Close() - } - - response, err = h.transport.RoundTrip(request) - if err != nil { - return nil, err - } - response.Request.Body = io.NopCloser(bytes.NewReader(requestBody)) - - if contains(h.deniedStatusCodes, response.StatusCode) { - return response, nil - } - - if !contains(h.allowedStatusCodes, response.StatusCode) { - return response, nil - } - - err = h.store.Save(response) - if err != nil { - response.Body.Close() - response.Request.Body.Close() - return nil, err - } - - return response, nil -} - -func contains(slice []int, searchValue int) bool { - for index := range slice { - if searchValue == slice[index] { - return true - } - } - return false -} diff --git a/sqlite.go b/sqlite_cache.go similarity index 76% rename from sqlite.go rename to sqlite_cache.go index 5865a04..2c9e090 100644 --- a/sqlite.go +++ b/sqlite_cache.go @@ -8,6 +8,8 @@ import ( "fmt" "io" "net/http" + "os" + "path/filepath" "strings" _ "github.com/mattn/go-sqlite3" @@ -19,11 +21,13 @@ const ( readRequestQuery = "SELECT response_body, status_code FROM responses WHERE request_url = ? AND request_method = ? AND request_body = ?" ) -type sqliteResponseStore struct { +// Add doc +type SqliteCache struct { database *sql.DB } -func newSqliteResponseStore(databaseName string) (*sqliteResponseStore, error) { +// Add doc +func NewSqliteCache(databaseName string) (*SqliteCache, error) { fileExists, err := doesFileExist(databaseName) if err != nil { return nil, err @@ -46,16 +50,16 @@ func newSqliteResponseStore(databaseName string) (*sqliteResponseStore, error) { return nil, err } - return &sqliteResponseStore{ + return &SqliteCache{ database: conn, }, nil } -func (s *sqliteResponseStore) Save(response *http.Response) error { +func (s *SqliteCache) Save(response *http.Response) error { return s.SaveContext(context.Background(), response) } -func (s *sqliteResponseStore) SaveContext(ctx context.Context, response *http.Response) error { +func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response) error { requestURL := response.Request.URL.Host + response.Request.URL.RequestURI() requestMethod := response.Request.Method requestBody := []byte{} @@ -106,11 +110,11 @@ func (s *sqliteResponseStore) SaveContext(ctx context.Context, response *http.Re return err } -func (s *sqliteResponseStore) Read(request *http.Request) (*http.Response, error) { +func (s *SqliteCache) Read(request *http.Request) (*http.Response, error) { return s.ReadContext(context.Background(), request) } -func (s *sqliteResponseStore) ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) { +func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) { requestURL := request.URL.Host + request.URL.RequestURI() requestMethod := request.Method requestBody := []byte{} @@ -152,3 +156,30 @@ func (s *sqliteResponseStore) ReadContext(ctx context.Context, request *http.Req StatusCode: responseStatusCode, }, nil } + +func doesFileExist(filename string) (bool, error) { + _, err := os.Stat(filename) + if err != nil && errors.Is(err, os.ErrNotExist) { + return false, nil + } + + if err != nil && !errors.Is(err, os.ErrNotExist) { + return false, err + } + + return true, nil +} + +func createFile(filename string) (*os.File, error) { + directory, err := filepath.Abs(filepath.Dir(filename)) + if err != nil { + return nil, err + } + + err = os.MkdirAll(directory, os.ModePerm) + if err != nil { + return nil, err + } + + return os.Create(filename) +} diff --git a/sqlite_test.go b/sqlite_cache_test.go similarity index 96% rename from sqlite_test.go rename to sqlite_cache_test.go index e70566b..93cc8f9 100644 --- a/sqlite_test.go +++ b/sqlite_cache_test.go @@ -29,7 +29,7 @@ func Test_Save_HappyPath(t *testing.T) { db, mock, closeFunc := aDatabaseMock(t) defer closeFunc() - subject := &sqliteResponseStore{database: db} + subject := &SqliteCache{database: db} mock.ExpectBegin() mock.ExpectExec(insertQuery).WithArgs( @@ -57,7 +57,7 @@ func Test_Read_HappyPath(t *testing.T) { db, mock, closeFunc := aDatabaseMock(t) defer closeFunc() - subject := &sqliteResponseStore{database: db} + subject := &SqliteCache{database: db} mockRows := sqlmock.NewRows([]string{"response_body", "status_code"}). AddRow("mock-body", 200). diff --git a/transport.go b/transport.go new file mode 100644 index 0000000..d93db24 --- /dev/null +++ b/transport.go @@ -0,0 +1,96 @@ +package httpcache + +import ( + "bytes" + "errors" + "io" + "net/http" +) + +// Add doc +var ErrNoCache = errors.New("no cache set") + +// Add doc +type Transport struct { + + // Add doc + transport http.RoundTripper + + // Add doc + cache Cache + + // Add doc + config *Config +} + +// Add doc +func NewTransport(config *Config) (*Transport, error) { + transport := &Transport{ + transport: http.DefaultTransport, + cache: config.Cache, + config: config, + } + + if config.Cache == nil { + return nil, ErrNoCache + } + + return transport, nil +} + +// Add doc +func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { + response, err := t.cache.Read(request) + if err == nil { + return response, nil + } + + if !errors.Is(err, ErrNoResponse) { + return nil, err + } + + requestBody := []byte{} + if request.GetBody != nil { + body, err := request.GetBody() + if err != nil { + return nil, err + } + requestBody, err = io.ReadAll(body) + if err != nil { + return nil, err + } + defer body.Close() + } + + response, err = t.transport.RoundTrip(request) + if err != nil { + return nil, err + } + response.Request.Body = io.NopCloser(bytes.NewReader(requestBody)) + + if contains(*t.config.DeniedStatusCodes, response.StatusCode) { + return response, nil + } + + if !contains(*t.config.AllowedStatusCodes, response.StatusCode) { + return response, nil + } + + err = t.cache.Save(response) + if err != nil { + response.Body.Close() + response.Request.Body.Close() + return nil, err + } + + return response, nil +} + +func contains(slice []int, searchValue int) bool { + for index := range slice { + if searchValue == slice[index] { + return true + } + } + return false +} From 40701778e344066890b8025507759b92458dec5a Mon Sep 17 00:00:00 2001 From: Alex Merren Date: Mon, 20 Jan 2025 14:42:16 +0000 Subject: [PATCH 02/12] Add new empty tests for transport --- sqlite_cache.go | 10 +++++----- sqlite_cache_test.go | 7 ++++--- transport.go | 13 ++++++++----- transport_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 13 deletions(-) create mode 100644 transport_test.go diff --git a/sqlite_cache.go b/sqlite_cache.go index 2c9e090..eb978fa 100644 --- a/sqlite_cache.go +++ b/sqlite_cache.go @@ -23,7 +23,7 @@ const ( // Add doc type SqliteCache struct { - database *sql.DB + Database *sql.DB } // Add doc @@ -51,7 +51,7 @@ func NewSqliteCache(databaseName string) (*SqliteCache, error) { } return &SqliteCache{ - database: conn, + Database: conn, }, nil } @@ -88,7 +88,7 @@ func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response) response.Body = io.NopCloser(bytes.NewReader(responseBody)) responseStatusCode := response.StatusCode - tx, err := s.database.BeginTx(ctx, nil) + tx, err := s.Database.BeginTx(ctx, nil) if err != nil { return err } @@ -137,11 +137,11 @@ func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (* responseBody := "" responseStatusCode := 0 - if s.database == nil { + if s.Database == nil { return nil, fmt.Errorf("database is nil") } - row := s.database.QueryRow(readRequestQuery, requestURL, requestMethod, requestBody) + row := s.Database.QueryRow(readRequestQuery, requestURL, requestMethod, requestBody) err := row.Scan(&responseBody, &responseStatusCode) if err != nil { if errors.Is(err, sql.ErrNoRows) { diff --git a/sqlite_cache_test.go b/sqlite_cache_test.go index 93cc8f9..9b391c2 100644 --- a/sqlite_cache_test.go +++ b/sqlite_cache_test.go @@ -1,4 +1,4 @@ -package httpcache +package httpcache_test import ( "database/sql" @@ -9,6 +9,7 @@ import ( "testing" "github.com/DATA-DOG/go-sqlmock" + "github.com/alexmerren/httpcache" "github.com/stretchr/testify/assert" ) @@ -29,7 +30,7 @@ func Test_Save_HappyPath(t *testing.T) { db, mock, closeFunc := aDatabaseMock(t) defer closeFunc() - subject := &SqliteCache{database: db} + subject := &httpcache.SqliteCache{Database: db} mock.ExpectBegin() mock.ExpectExec(insertQuery).WithArgs( @@ -57,7 +58,7 @@ func Test_Read_HappyPath(t *testing.T) { db, mock, closeFunc := aDatabaseMock(t) defer closeFunc() - subject := &SqliteCache{database: db} + subject := &httpcache.SqliteCache{Database: db} mockRows := sqlmock.NewRows([]string{"response_body", "status_code"}). AddRow("mock-body", 200). diff --git a/transport.go b/transport.go index d93db24..d725c70 100644 --- a/transport.go +++ b/transport.go @@ -68,11 +68,7 @@ func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { } response.Request.Body = io.NopCloser(bytes.NewReader(requestBody)) - if contains(*t.config.DeniedStatusCodes, response.StatusCode) { - return response, nil - } - - if !contains(*t.config.AllowedStatusCodes, response.StatusCode) { + if !t.shouldSaveResponse(response.StatusCode) { return response, nil } @@ -86,6 +82,13 @@ func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { return response, nil } +func (t *Transport) shouldSaveResponse(statusCode int) bool { + isDeniedStatusCode := contains(*t.config.DeniedStatusCodes, statusCode) + isAllowedStatusCode := contains(*t.config.AllowedStatusCodes, statusCode) + + return isDeniedStatusCode || !isAllowedStatusCode +} + func contains(slice []int, searchValue int) bool { for index := range slice { if searchValue == slice[index] { diff --git a/transport_test.go b/transport_test.go new file mode 100644 index 0000000..8c49682 --- /dev/null +++ b/transport_test.go @@ -0,0 +1,39 @@ +package httpcache_test + +import "testing" + +func Test_RoundTrip_ResponseSavedHappyPath(t *testing.T) { + // given + // when + // then +} + +func Test_RoundTrip_ResponseNotSavedHappyPath(t *testing.T) { + // given + // when + // then +} + +func Test_RoundTrip_ReadFailure(t *testing.T) { + // given + // when + // then +} + +func Test_RoundTrip_ResponseHasDeniedCode(t *testing.T) { + // given + // when + // then +} + +func Test_RoundTrip_ResponseHasMiscCode(t *testing.T) { + // given + // when + // then +} + +func Test_RoundTrip_SaveFailure(t *testing.T) { + // given + // when + // then +} From 0e5621130dfa87b70cd43f62bfb4a46186537af9 Mon Sep 17 00:00:00 2001 From: Alex Merren Date: Mon, 20 Jan 2025 15:01:19 +0000 Subject: [PATCH 03/12] Add allowed methods --- config.go | 22 +++++++++------------- transport.go | 14 +++++++------- transport_test.go | 18 ++++++++++++------ 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/config.go b/config.go index 8c68594..dd9c0a9 100644 --- a/config.go +++ b/config.go @@ -9,10 +9,10 @@ type Config struct { Cache Cache // Add doc - DeniedStatusCodes *[]int + AllowedStatusCodes *[]int // Add doc - AllowedStatusCodes *[]int + AllowedMethods *[]string } // Add doc @@ -21,14 +21,14 @@ func NewConfig() *Config { } // Add doc -func (c *Config) WithDeniedStatusCodes(deniedStatusCodes []int) *Config { - c.DeniedStatusCodes = &deniedStatusCodes +func (c *Config) WithAllowedStatusCodes(allowedStatusCodes []int) *Config { + c.AllowedStatusCodes = &allowedStatusCodes return c } // Add doc -func (c *Config) WithAllowedStatusCodes(allowedStatusCodes []int) *Config { - c.AllowedStatusCodes = &allowedStatusCodes +func (c *Config) WithAllowedMethods(allowedMethods []string) *Config { + c.AllowedMethods = &allowedMethods return c } @@ -45,16 +45,12 @@ func DefaultConfig() (*Config, error) { } defaultConfig := NewConfig(). - WithDeniedStatusCodes([]int{ - http.StatusNotFound, - http.StatusBadRequest, - http.StatusForbidden, - http.StatusUnauthorized, - http.StatusMethodNotAllowed, - }). WithAllowedStatusCodes([]int{ http.StatusOK, }). + WithAllowedMethods([]string{ + http.MethodGet, + }). WithCache(cache) return defaultConfig, nil diff --git a/transport.go b/transport.go index d725c70..e02fdc1 100644 --- a/transport.go +++ b/transport.go @@ -68,7 +68,7 @@ func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { } response.Request.Body = io.NopCloser(bytes.NewReader(requestBody)) - if !t.shouldSaveResponse(response.StatusCode) { + if !t.shouldSaveResponse(response.StatusCode, response.Request.Method) { return response, nil } @@ -82,16 +82,16 @@ func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { return response, nil } -func (t *Transport) shouldSaveResponse(statusCode int) bool { - isDeniedStatusCode := contains(*t.config.DeniedStatusCodes, statusCode) +func (t *Transport) shouldSaveResponse(statusCode int, method string) bool { isAllowedStatusCode := contains(*t.config.AllowedStatusCodes, statusCode) + isAllowedMethod := contains(*t.config.AllowedMethods, method) - return isDeniedStatusCode || !isAllowedStatusCode + return !isAllowedStatusCode || !isAllowedMethod } -func contains(slice []int, searchValue int) bool { - for index := range slice { - if searchValue == slice[index] { +func contains[T comparable](slice []T, searchValue T) bool { + for _, value := range slice { + if value == searchValue { return true } } diff --git a/transport_test.go b/transport_test.go index 8c49682..96ca46f 100644 --- a/transport_test.go +++ b/transport_test.go @@ -2,37 +2,43 @@ package httpcache_test import "testing" -func Test_RoundTrip_ResponseSavedHappyPath(t *testing.T) { +func Test_RoundTrip_HappyPath_ResponseSaved(t *testing.T) { // given // when // then } -func Test_RoundTrip_ResponseNotSavedHappyPath(t *testing.T) { +func Test_RoundTrip_HappyPath_ResponseNotSaved(t *testing.T) { // given // when // then } -func Test_RoundTrip_ReadFailure(t *testing.T) { +func Test_RoundTrip_ReadError(t *testing.T) { // given // when // then } -func Test_RoundTrip_ResponseHasDeniedCode(t *testing.T) { +func Test_RoundTrip_ResponseHasNotAllowedStatusCode(t *testing.T) { // given // when // then } -func Test_RoundTrip_ResponseHasMiscCode(t *testing.T) { +func Test_RoundTrip_ResponseHasNotAllowedMethod(t *testing.T) { // given // when // then } -func Test_RoundTrip_SaveFailure(t *testing.T) { +func Test_RoundTrip_TransportError(t *testing.T) { + // given + // when + // then +} + +func Test_RoundTrip_SaveError(t *testing.T) { // given // when // then From 57a203dc342444373d23bd03821c14320dae45a3 Mon Sep 17 00:00:00 2001 From: alexmerren Date: Mon, 10 Feb 2025 21:54:17 +0000 Subject: [PATCH 04/12] Remove request body from table, simplify code a LOT --- sqlite_cache.go | 77 +++++++++----------------------------------- sqlite_cache_test.go | 26 +++++++-------- transport.go | 19 +---------- 3 files changed, 29 insertions(+), 93 deletions(-) diff --git a/sqlite_cache.go b/sqlite_cache.go index eb978fa..8ce7d74 100644 --- a/sqlite_cache.go +++ b/sqlite_cache.go @@ -10,15 +10,14 @@ import ( "net/http" "os" "path/filepath" - "strings" _ "github.com/mattn/go-sqlite3" ) const ( - createDatabaseQuery = "CREATE TABLE IF NOT EXISTS responses (request_url TEXT PRIMARY KEY, request_method TEXT, request_body TEXT, response_body TEXT, status_code INTEGER)" - saveRequestQuery = "INSERT INTO responses (request_url, request_method, request_body, response_body, status_code) VALUES (?, ?, ?, ?, ?)" - readRequestQuery = "SELECT response_body, status_code FROM responses WHERE request_url = ? AND request_method = ? AND request_body = ?" + createDatabaseQuery = "CREATE TABLE IF NOT EXISTS responses (request_url TEXT PRIMARY KEY, request_method TEXT, response_body BLOB, status_code INTEGER)" + saveRequestQuery = "INSERT INTO responses (request_url, request_method, response_body, status_code) VALUES (?, ?, ?, ?)" + readRequestQuery = "SELECT response_body, status_code FROM responses WHERE request_url = ? AND request_method = ?" ) // Add doc @@ -60,49 +59,20 @@ func (s *SqliteCache) Save(response *http.Response) error { } func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response) error { - requestURL := response.Request.URL.Host + response.Request.URL.RequestURI() + requestUrl := response.Request.URL.Host + response.Request.URL.RequestURI() requestMethod := response.Request.Method - requestBody := []byte{} - // If the request has no body (e.g. GET request) then Request.GetBody will - // be nil. Check if it is nil to get request body for other types of request - // (e.g. POST, PATCH). - if response.Request.GetBody != nil { - body, err := response.Request.GetBody() - if err != nil { - return err - } - requestBody, err = io.ReadAll(body) - if err != nil { - return err - } - defer body.Close() - } - - responseBody, err := io.ReadAll(response.Body) + responseBody := bytes.NewBuffer(nil) + _, err := io.Copy(responseBody, response.Body) if err != nil { return err } // Reset the response body stream to the beginning to be read again. - response.Body = io.NopCloser(bytes.NewReader(responseBody)) + response.Body = io.NopCloser(responseBody) responseStatusCode := response.StatusCode - tx, err := s.Database.BeginTx(ctx, nil) - if err != nil { - return err - } - - defer func() { - switch err { - case nil: - err = tx.Commit() - default: - tx.Rollback() - } - }() - - _, err = tx.Exec(saveRequestQuery, requestURL, requestMethod, requestBody, responseBody, responseStatusCode) + _, err = s.Database.Exec(saveRequestQuery, requestUrl, requestMethod, responseBody.Bytes(), responseStatusCode) if err != nil { return err } @@ -115,33 +85,16 @@ func (s *SqliteCache) Read(request *http.Request) (*http.Response, error) { } func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) { - requestURL := request.URL.Host + request.URL.RequestURI() - requestMethod := request.Method - requestBody := []byte{} - - // If the request has no body (e.g. GET request) then Request.GetBody will - // be nil. Check if it is nil to get request body for other types of request - // (e.g. POST, PATCH). - if request.GetBody != nil { - body, err := request.GetBody() - if err != nil { - return nil, err - } - requestBody, err = io.ReadAll(body) - if err != nil { - return nil, err - } - defer body.Close() - } - - responseBody := "" - responseStatusCode := 0 - if s.Database == nil { return nil, fmt.Errorf("database is nil") } - row := s.Database.QueryRow(readRequestQuery, requestURL, requestMethod, requestBody) + requestURL := request.URL.Host + request.URL.RequestURI() + requestMethod := request.Method + row := s.Database.QueryRow(readRequestQuery, requestURL, requestMethod) + + var responseBody []byte + var responseStatusCode int err := row.Scan(&responseBody, &responseStatusCode) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -152,7 +105,7 @@ func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (* return &http.Response{ Request: request, - Body: io.NopCloser(strings.NewReader(responseBody)), + Body: io.NopCloser(bytes.NewReader(responseBody)), StatusCode: responseStatusCode, }, nil } diff --git a/sqlite_cache_test.go b/sqlite_cache_test.go index 9b391c2..2d27270 100644 --- a/sqlite_cache_test.go +++ b/sqlite_cache_test.go @@ -20,8 +20,8 @@ const ( ) const ( - insertQuery = `INSERT INTO responses \(request_url, request_method, request_body, response_body, status_code\) VALUES \(\?, \?, \?, \?, \?\)` - selectQuery = `SELECT response_body, status_code FROM responses WHERE request_url = \? AND request_method = \? AND request_body = \?` + insertQuery = `INSERT INTO responses \(request_url, request_method, response_body, status_code\) VALUES \(\?, \?, \?, \?\)` + selectQuery = `SELECT response_body, status_code FROM responses WHERE request_url = \? AND request_method = \?` ) func Test_Save_HappyPath(t *testing.T) { @@ -30,17 +30,14 @@ func Test_Save_HappyPath(t *testing.T) { db, mock, closeFunc := aDatabaseMock(t) defer closeFunc() - subject := &httpcache.SqliteCache{Database: db} - - mock.ExpectBegin() mock.ExpectExec(insertQuery).WithArgs( testHost+testPath, http.MethodGet, - []byte{}, []byte(testBody), http.StatusOK, ).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() + + subject := &httpcache.SqliteCache{Database: db} // When err := subject.Save(response) @@ -58,23 +55,26 @@ func Test_Read_HappyPath(t *testing.T) { db, mock, closeFunc := aDatabaseMock(t) defer closeFunc() - subject := &httpcache.SqliteCache{Database: db} - - mockRows := sqlmock.NewRows([]string{"response_body", "status_code"}). - AddRow("mock-body", 200). - AddRow("mock-body", 200) + mockRows := sqlmock.NewRows([]string{"response_body", "status_code"}).AddRow([]byte(testBody), 200) mock.ExpectQuery(selectQuery).WithArgs( testHost+testPath, http.MethodGet, - []byte{}, ).WillReturnRows(mockRows) + subject := &httpcache.SqliteCache{Database: db} + // When response, err := subject.Read(request) // Then assert.Nil(t, err) assert.NotNil(t, response) + + responseBody, err := io.ReadAll(response.Body) + defer response.Body.Close() + assert.Nil(t, err) + assert.Equal(t, string(responseBody), testBody) + if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } diff --git a/transport.go b/transport.go index e02fdc1..4cb3c93 100644 --- a/transport.go +++ b/transport.go @@ -1,9 +1,7 @@ package httpcache import ( - "bytes" "errors" - "io" "net/http" ) @@ -49,24 +47,10 @@ func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { return nil, err } - requestBody := []byte{} - if request.GetBody != nil { - body, err := request.GetBody() - if err != nil { - return nil, err - } - requestBody, err = io.ReadAll(body) - if err != nil { - return nil, err - } - defer body.Close() - } - response, err = t.transport.RoundTrip(request) if err != nil { return nil, err } - response.Request.Body = io.NopCloser(bytes.NewReader(requestBody)) if !t.shouldSaveResponse(response.StatusCode, response.Request.Method) { return response, nil @@ -75,7 +59,6 @@ func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { err = t.cache.Save(response) if err != nil { response.Body.Close() - response.Request.Body.Close() return nil, err } @@ -86,7 +69,7 @@ func (t *Transport) shouldSaveResponse(statusCode int, method string) bool { isAllowedStatusCode := contains(*t.config.AllowedStatusCodes, statusCode) isAllowedMethod := contains(*t.config.AllowedMethods, method) - return !isAllowedStatusCode || !isAllowedMethod + return isAllowedStatusCode && isAllowedMethod } func contains[T comparable](slice []T, searchValue T) bool { From 7e3ba029b3064a2f4ed5342e6834c8d3b77d6bbe Mon Sep 17 00:00:00 2001 From: alexmerren Date: Mon, 10 Feb 2025 22:00:47 +0000 Subject: [PATCH 05/12] More simplification, invert conditions to read cleaner --- sqlite_cache.go | 12 ++++-------- transport.go | 14 ++++++-------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/sqlite_cache.go b/sqlite_cache.go index 8ce7d74..5f5e079 100644 --- a/sqlite_cache.go +++ b/sqlite_cache.go @@ -59,9 +59,6 @@ func (s *SqliteCache) Save(response *http.Response) error { } func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response) error { - requestUrl := response.Request.URL.Host + response.Request.URL.RequestURI() - requestMethod := response.Request.Method - responseBody := bytes.NewBuffer(nil) _, err := io.Copy(responseBody, response.Body) if err != nil { @@ -70,9 +67,9 @@ func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response) // Reset the response body stream to the beginning to be read again. response.Body = io.NopCloser(responseBody) - responseStatusCode := response.StatusCode - _, err = s.Database.Exec(saveRequestQuery, requestUrl, requestMethod, responseBody.Bytes(), responseStatusCode) + requestUrl := response.Request.URL.Host + response.Request.URL.RequestURI() + _, err = s.Database.Exec(saveRequestQuery, requestUrl, response.Request.Method, responseBody.Bytes(), response.StatusCode) if err != nil { return err } @@ -89,9 +86,8 @@ func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (* return nil, fmt.Errorf("database is nil") } - requestURL := request.URL.Host + request.URL.RequestURI() - requestMethod := request.Method - row := s.Database.QueryRow(readRequestQuery, requestURL, requestMethod) + requestUrl := request.URL.Host + request.URL.RequestURI() + row := s.Database.QueryRow(readRequestQuery, requestUrl, request.Method) var responseBody []byte var responseStatusCode int diff --git a/transport.go b/transport.go index 4cb3c93..73cc6ac 100644 --- a/transport.go +++ b/transport.go @@ -52,14 +52,12 @@ func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { return nil, err } - if !t.shouldSaveResponse(response.StatusCode, response.Request.Method) { - return response, nil - } - - err = t.cache.Save(response) - if err != nil { - response.Body.Close() - return nil, err + if t.shouldSaveResponse(response.StatusCode, response.Request.Method) { + err = t.cache.Save(response) + if err != nil { + response.Body.Close() + return nil, err + } } return response, nil From a39a8698726babccdac7265900b64b81ed021fd2 Mon Sep 17 00:00:00 2001 From: alexmerren Date: Mon, 10 Feb 2025 22:03:45 +0000 Subject: [PATCH 06/12] Add method to extract URL from request struct --- sqlite_cache.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sqlite_cache.go b/sqlite_cache.go index 5f5e079..6523715 100644 --- a/sqlite_cache.go +++ b/sqlite_cache.go @@ -68,7 +68,7 @@ func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response) // Reset the response body stream to the beginning to be read again. response.Body = io.NopCloser(responseBody) - requestUrl := response.Request.URL.Host + response.Request.URL.RequestURI() + requestUrl := generateUrl(response.Request) _, err = s.Database.Exec(saveRequestQuery, requestUrl, response.Request.Method, responseBody.Bytes(), response.StatusCode) if err != nil { return err @@ -86,7 +86,7 @@ func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (* return nil, fmt.Errorf("database is nil") } - requestUrl := request.URL.Host + request.URL.RequestURI() + requestUrl := generateUrl(request) row := s.Database.QueryRow(readRequestQuery, requestUrl, request.Method) var responseBody []byte @@ -106,6 +106,10 @@ func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (* }, nil } +func generateUrl(request *http.Request) string { + return request.URL.Host + request.URL.RequestURI() +} + func doesFileExist(filename string) (bool, error) { _, err := os.Stat(filename) if err != nil && errors.Is(err, os.ErrNotExist) { From b20246497c84ae6f4392c70f88a380634ae4777a Mon Sep 17 00:00:00 2001 From: alexmerren Date: Sun, 9 Mar 2025 16:46:40 +0000 Subject: [PATCH 07/12] Add configuration builder, and add documentation --- Makefile | 2 +- cache.go | 16 +++++---- config.go | 85 +++++++++++++++++++++-------------------------- config_builder.go | 58 ++++++++++++++++++++++++++++++++ sqlite_cache.go | 17 ++++++---- transport.go | 56 ++++++++++++++++++++----------- 6 files changed, 154 insertions(+), 80 deletions(-) create mode 100644 config_builder.go diff --git a/Makefile b/Makefile index 7653b7f..02d8971 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ fmt: $(DRUN) $(GO) fmt ./... vet: - $(DRUN) $(GO) vet ./. + $(DRUN) $(GO) vet ./... bash: $(DRUN) /bin/bash \ No newline at end of file diff --git a/cache.go b/cache.go index 4ba7505..a1caba4 100644 --- a/cache.go +++ b/cache.go @@ -6,23 +6,27 @@ import ( "net/http" ) -// Add doc +// Cache is the entrypoint for saving and reading responses. This can be +// implemented for a custom method to cache responses. type Cache interface { - // Add doc + // Save a response for a HTTP request using [context.Background]. Save(response *http.Response) error - // Add doc + // Read a saved response for a HTTP request using [context.Background]. Read(request *http.Request) (*http.Response, error) - // Add doc + // Save a response for a HTTP request with a [context.Context]. SaveContext(ctx context.Context, response *http.Response) error - // Add doc + // Read a saved response for a HTTP request with a [context.Context]. If no + // response is saved for the corresponding request, return [ErrNoResponse]. ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) } var ( - // Add doc + // ErrNoResponse describes when the cache does not have a response stored. + // [Transport] will check if ErrNoResponse is returned from [Cache.Read]. If + // ErrNoResponse is returned, then the request/response will be saved with [Save]. ErrNoResponse = errors.New("no stored response") ) diff --git a/config.go b/config.go index dd9c0a9..f7f87fc 100644 --- a/config.go +++ b/config.go @@ -1,57 +1,48 @@ package httpcache -import "net/http" - -// Add doc +import ( + "net/http" + "time" +) + +var ( + defaultAllowedStatusCodes = []int{http.StatusOK} + defaultAllowedMethods = []string{http.MethodGet} + defaultExpiryTime = time.Duration(60*24*7) * time.Minute +) + +// Config describes the configuration to use when saving and reading responses +// from [Cache] using the [Transport]. type Config struct { - // Add doc - Cache Cache - - // Add doc - AllowedStatusCodes *[]int - - // Add doc - AllowedMethods *[]string -} - -// Add doc -func NewConfig() *Config { - return &Config{} -} - -// Add doc -func (c *Config) WithAllowedStatusCodes(allowedStatusCodes []int) *Config { - c.AllowedStatusCodes = &allowedStatusCodes - return c -} - -// Add doc -func (c *Config) WithAllowedMethods(allowedMethods []string) *Config { - c.AllowedMethods = &allowedMethods - return c -} - -// Add doc -func (c *Config) WithCache(cache Cache) *Config { - c.Cache = cache - return c + // AllowedStatusCodes describes if a HTTP response should be saved by + // checking that it's status code is accepted by [Cache]. If the HTTP + // response's status code is not in AllowedStatusCodes, then do not persist. + // + // This is a required field. + AllowedStatusCodes []int + + // AllowedMethods describes if a HTTP response should be saved by checking + // if the HTTP request's method is accepted by the [Cache]. If the HTTP + // request's method is not in AllowedMethods, then do not persist. + // + // This is a required field. + AllowedMethods []string + + // ExpiryTime describes when a HTTP response should be considered invalid. + ExpiryTime *time.Duration } +// DefaultConfig creates a [Config] with default values, namely: +// - AllowedStatusCodes: [http.StatusOK] +// - AllowedMethods: [http.MethodGet] +// - ExpiryTime: 7 days. func DefaultConfig() (*Config, error) { - cache, err := NewSqliteCache("httpcache.sqlite") - if err != nil { - return nil, err - } - - defaultConfig := NewConfig(). - WithAllowedStatusCodes([]int{ - http.StatusOK, - }). - WithAllowedMethods([]string{ - http.MethodGet, - }). - WithCache(cache) + defaultConfig := NewConfigBuilder(). + WithAllowedStatusCodes(defaultAllowedStatusCodes). + WithAllowedMethods(defaultAllowedMethods). + WithExpiryTime(defaultExpiryTime). + Build() return defaultConfig, nil } diff --git a/config_builder.go b/config_builder.go new file mode 100644 index 0000000..efa19f9 --- /dev/null +++ b/config_builder.go @@ -0,0 +1,58 @@ +package httpcache + +import "time" + +// If adding new fields to [Config], then add the corresponding field and +// methods to [configBuilder]. Configuration values in [configBuilder] always +// use pointer types. This allows the [Config] decide which fields are +// required. + +// configBuilder is an internal structure to create configs with required +// parameters. +type configBuilder struct { + + // allowedStatusCodes only persists HTTP responses that have an appropriate + // status code (i.e. 200). + // + // This is a required field. + allowedStatusCodes *[]int + + // allowedMethods only persists HTTP responses that use an appropriate HTTP + // method (i.e. "GET"). + // + // This is a required field. + allowedMethods *[]string + + // expiryTime invalidates HTTP responses after a duration has elapsed from + // [time.Now]. Set to nil for no expiry. + expiryTime *time.Duration +} + +func NewConfigBuilder() *configBuilder { + return &configBuilder{} +} + +func (c *configBuilder) WithAllowedStatusCodes(allowedStatusCodes []int) *configBuilder { + c.allowedStatusCodes = &allowedStatusCodes + return c +} + +func (c *configBuilder) WithAllowedMethods(allowedMethods []string) *configBuilder { + c.allowedMethods = &allowedMethods + return c +} + +func (c *configBuilder) WithExpiryTime(expiryDuration time.Duration) *configBuilder { + c.expiryTime = &expiryDuration + return c +} + +// Build constructs a [Config] that is ready to be consumed by [Transport]. If +// the configuration passed by [configBuilder] is invalid, it will panic. +func (c *configBuilder) Build() *Config { + return &Config{ + AllowedStatusCodes: *c.allowedStatusCodes, + AllowedMethods: *c.allowedMethods, + ExpiryTime: c.expiryTime, + } +} diff --git a/sqlite_cache.go b/sqlite_cache.go index 6523715..36f5c7a 100644 --- a/sqlite_cache.go +++ b/sqlite_cache.go @@ -20,12 +20,15 @@ const ( readRequestQuery = "SELECT response_body, status_code FROM responses WHERE request_url = ? AND request_method = ?" ) -// Add doc +// SqliteCache is a default implementation of [Cache] which creates a local +// SQLite cache to persist and to query HTTP responses. type SqliteCache struct { Database *sql.DB } -// Add doc +// NewSqliteCache creates a new SQLite database with a certain name. This name +// is the filename of the database. If the file does not exist, then we create +// it. If the file is in a non-existent directory, we create the directory. func NewSqliteCache(databaseName string) (*SqliteCache, error) { fileExists, err := doesFileExist(databaseName) if err != nil { @@ -58,6 +61,10 @@ func (s *SqliteCache) Save(response *http.Response) error { return s.SaveContext(context.Background(), response) } +func (s *SqliteCache) Read(request *http.Request) (*http.Response, error) { + return s.ReadContext(context.Background(), request) +} + func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response) error { responseBody := bytes.NewBuffer(nil) _, err := io.Copy(responseBody, response.Body) @@ -74,11 +81,7 @@ func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response) return err } - return err -} - -func (s *SqliteCache) Read(request *http.Request) (*http.Response, error) { - return s.ReadContext(context.Background(), request) + return nil } func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) { diff --git a/transport.go b/transport.go index 73cc6ac..c48a58c 100644 --- a/transport.go +++ b/transport.go @@ -5,38 +5,53 @@ import ( "net/http" ) -// Add doc -var ErrNoCache = errors.New("no cache set") +var ( + // ErrMissingCache will be returned if cache is not set when creating an + // instance of [Transport]. + ErrMissingCache = errors.New("cache not set when creating transport") + + // ErrMissingConfig will be returned if config is not set when creating an + // instance of [Transport]. + ErrMissingConfig = errors.New("config not set when creating transport") +) -// Add doc +// Transport is the main interface of the package. It uses [Cache] to persist +// HTTP response, and and uses configuration values from [Config] to interpret +// requests and responses. type Transport struct { - // Add doc + // transport handles HTTP requests. This is hardcoded to be + // [http.DefaultTransport]. transport http.RoundTripper - // Add doc + // cache handles persisting HTTP responses. cache Cache - // Add doc + // config describes which HTTP responses to cache and how they are cached. config *Config } -// Add doc -func NewTransport(config *Config) (*Transport, error) { - transport := &Transport{ - transport: http.DefaultTransport, - cache: config.Cache, - config: config, +// NewTransport creates a [Transport]. If the cache is nil, return +// [ErrMissingCache]. If the config is nil, return [ErrMissingConfig]. +func NewTransport(config *Config, cache Cache) (*Transport, error) { + if cache == nil { + return nil, ErrMissingCache } - - if config.Cache == nil { - return nil, ErrNoCache + if config == nil { + return nil, ErrMissingConfig } - return transport, nil + return &Transport{ + transport: http.DefaultTransport, + cache: cache, + config: config, + }, nil } -// Add doc +// RoundTrip wraps the [http.DefaultTransport] RoundTrip to execute HTTP +// requests and persists, if necessary, the responses. If [Cache] returns +// [ErrNoResponse], then execute a HTTP request and persist the response if +// passing the criteria in [Transport.shouldSaveResponse]. func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { response, err := t.cache.Read(request) if err == nil { @@ -63,9 +78,12 @@ func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { return response, nil } +// shouldSaveResponse is responsible for interpreting configuration values from +// [Config] to determine if a HTTP response should be persisted. Any new values +// added to [Config] can be used here as criteria. func (t *Transport) shouldSaveResponse(statusCode int, method string) bool { - isAllowedStatusCode := contains(*t.config.AllowedStatusCodes, statusCode) - isAllowedMethod := contains(*t.config.AllowedMethods, method) + isAllowedStatusCode := contains(t.config.AllowedStatusCodes, statusCode) + isAllowedMethod := contains(t.config.AllowedMethods, method) return isAllowedStatusCode && isAllowedMethod } From f013fd10f73b633e332513c79d0862d68a718b22 Mon Sep 17 00:00:00 2001 From: alexmerren Date: Sun, 9 Mar 2025 17:49:50 +0000 Subject: [PATCH 08/12] Add expiry time and add tests to verify behaviour --- cache.go | 5 +- sqlite_cache.go | 64 ++++++++++-- sqlite_cache_test.go | 227 +++++++++++++++++++++++++++++++++++++++---- transport.go | 9 +- 4 files changed, 274 insertions(+), 31 deletions(-) diff --git a/cache.go b/cache.go index a1caba4..0a9be1e 100644 --- a/cache.go +++ b/cache.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "time" ) // Cache is the entrypoint for saving and reading responses. This can be @@ -11,13 +12,13 @@ import ( type Cache interface { // Save a response for a HTTP request using [context.Background]. - Save(response *http.Response) error + Save(response *http.Response, expiryTime *time.Time) error // Read a saved response for a HTTP request using [context.Background]. Read(request *http.Request) (*http.Response, error) // Save a response for a HTTP request with a [context.Context]. - SaveContext(ctx context.Context, response *http.Response) error + SaveContext(ctx context.Context, response *http.Response, expiryTime *time.Time) error // Read a saved response for a HTTP request with a [context.Context]. If no // response is saved for the corresponding request, return [ErrNoResponse]. diff --git a/sqlite_cache.go b/sqlite_cache.go index 36f5c7a..73b99f7 100644 --- a/sqlite_cache.go +++ b/sqlite_cache.go @@ -5,21 +5,44 @@ import ( "context" "database/sql" "errors" - "fmt" "io" "net/http" "os" "path/filepath" + "time" _ "github.com/mattn/go-sqlite3" ) const ( - createDatabaseQuery = "CREATE TABLE IF NOT EXISTS responses (request_url TEXT PRIMARY KEY, request_method TEXT, response_body BLOB, status_code INTEGER)" - saveRequestQuery = "INSERT INTO responses (request_url, request_method, response_body, status_code) VALUES (?, ?, ?, ?)" - readRequestQuery = "SELECT response_body, status_code FROM responses WHERE request_url = ? AND request_method = ?" + createDatabaseQuery = ` + CREATE TABLE IF NOT EXISTS responses ( + request_url TEXT PRIMARY KEY, + request_method TEXT NOT NULL, + response_body BLOB NOT NULL, + status_code INTEGER NOT NULL, + expiry_time INTEGER)` + + saveRequestQuery = ` + INSERT INTO responses ( + request_url, + request_method, + response_body, + status_code, + expiry_time) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT REPLACE` + + readRequestQuery = ` + SELECT response_body, status_code, expiry_time + FROM responses + WHERE request_url = ? AND request_method = ?` ) +// ErrNoDatabase denotes when the database does not exist or has not been +// constructed correctly. +var ErrNoDatabase = errors.New("no database connection") + // SqliteCache is a default implementation of [Cache] which creates a local // SQLite cache to persist and to query HTTP responses. type SqliteCache struct { @@ -57,26 +80,41 @@ func NewSqliteCache(databaseName string) (*SqliteCache, error) { }, nil } -func (s *SqliteCache) Save(response *http.Response) error { - return s.SaveContext(context.Background(), response) +func (s *SqliteCache) Save(response *http.Response, expiryTime *time.Duration) error { + return s.SaveContext(context.Background(), response, expiryTime) } func (s *SqliteCache) Read(request *http.Request) (*http.Response, error) { return s.ReadContext(context.Background(), request) } -func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response) error { +func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response, expiryTime *time.Duration) error { responseBody := bytes.NewBuffer(nil) _, err := io.Copy(responseBody, response.Body) if err != nil { return err } + var expiryTimestamp *int64 + if expiryTime != nil { + calculatedExpiry := time.Now().Add(*expiryTime).Unix() + expiryTimestamp = &calculatedExpiry + } else { + expiryTimestamp = nil + } + // Reset the response body stream to the beginning to be read again. response.Body = io.NopCloser(responseBody) requestUrl := generateUrl(response.Request) - _, err = s.Database.Exec(saveRequestQuery, requestUrl, response.Request.Method, responseBody.Bytes(), response.StatusCode) + _, err = s.Database.Exec( + saveRequestQuery, + requestUrl, + response.Request.Method, + responseBody.Bytes(), + response.StatusCode, + expiryTimestamp, + ) if err != nil { return err } @@ -86,7 +124,7 @@ func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response) func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) { if s.Database == nil { - return nil, fmt.Errorf("database is nil") + return nil, ErrNoDatabase } requestUrl := generateUrl(request) @@ -94,7 +132,9 @@ func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (* var responseBody []byte var responseStatusCode int - err := row.Scan(&responseBody, &responseStatusCode) + var expiryTime *int64 + + err := row.Scan(&responseBody, &responseStatusCode, &expiryTime) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNoResponse @@ -102,6 +142,10 @@ func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (* return nil, err } + if expiryTime != nil && time.Now().Unix() > *expiryTime { + return nil, ErrNoResponse + } + return &http.Response{ Request: request, Body: io.NopCloser(bytes.NewReader(responseBody)), diff --git a/sqlite_cache_test.go b/sqlite_cache_test.go index 2d27270..7f9f3b3 100644 --- a/sqlite_cache_test.go +++ b/sqlite_cache_test.go @@ -2,11 +2,13 @@ package httpcache_test import ( "database/sql" + "errors" "io" "net/http" "net/url" "strings" "testing" + "time" "github.com/DATA-DOG/go-sqlmock" "github.com/alexmerren/httpcache" @@ -20,31 +22,97 @@ const ( ) const ( - insertQuery = `INSERT INTO responses \(request_url, request_method, response_body, status_code\) VALUES \(\?, \?, \?, \?\)` - selectQuery = `SELECT response_body, status_code FROM responses WHERE request_url = \? AND request_method = \?` + insertQuery = `INSERT INTO responses \( request_url, request_method, response_body, status_code, expiry_time\) VALUES \(\?, \?, \?, \?, \?\) ON CONFLICT REPLACE` + selectQuery = `SELECT response_body, status_code, expiry_time FROM responses WHERE request_url = \? AND request_method = \?` ) func Test_Save_HappyPath(t *testing.T) { // Given response := aDummyResponse() - db, mock, closeFunc := aDatabaseMock(t) + db, mockDatabase, closeFunc := aDatabaseMock(t) defer closeFunc() - mock.ExpectExec(insertQuery).WithArgs( - testHost+testPath, - http.MethodGet, - []byte(testBody), - http.StatusOK, - ).WillReturnResult(sqlmock.NewResult(1, 1)) + mockDatabase. + ExpectExec(insertQuery). + WithArgs( + testHost+testPath, + http.MethodGet, + []byte(testBody), + http.StatusOK, + nil, + ). + WillReturnResult(sqlmock.NewResult(1, 1)) subject := &httpcache.SqliteCache{Database: db} // When - err := subject.Save(response) + err := subject.Save(response, nil) // Then assert.Nil(t, err) - if err := mock.ExpectationsWereMet(); err != nil { + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Save_HappyPathWithExpiryTime(t *testing.T) { + // Given + expiryTimestamp := time.Duration(1) * time.Minute + + response := aDummyResponse() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockDatabase. + ExpectExec(insertQuery). + WithArgs( + testHost+testPath, + http.MethodGet, + []byte(testBody), + http.StatusOK, + time.Now().Add(expiryTimestamp).Unix(), + ). + WillReturnResult(sqlmock.NewResult(1, 1)) + + subject := &httpcache.SqliteCache{Database: db} + + // When + err := subject.Save(response, &expiryTimestamp) + + // Then + assert.Nil(t, err) + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Save_ExecFails(t *testing.T) { + // Given + expiryTimestamp := time.Duration(1) * time.Minute + + response := aDummyResponse() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockDatabase. + ExpectExec(insertQuery). + WithArgs( + testHost+testPath, + http.MethodGet, + []byte(testBody), + http.StatusOK, + time.Now().Add(expiryTimestamp).Unix(), + ). + WillReturnError(errors.New("dummy error")) + + subject := &httpcache.SqliteCache{Database: db} + + // When + err := subject.Save(response, &expiryTimestamp) + + // Then + assert.NotNil(t, err) + if err := mockDatabase.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } @@ -52,14 +120,17 @@ func Test_Save_HappyPath(t *testing.T) { func Test_Read_HappyPath(t *testing.T) { // Given request := aDummyRequest() - db, mock, closeFunc := aDatabaseMock(t) + db, mockDatabase, closeFunc := aDatabaseMock(t) defer closeFunc() - mockRows := sqlmock.NewRows([]string{"response_body", "status_code"}).AddRow([]byte(testBody), 200) - mock.ExpectQuery(selectQuery).WithArgs( - testHost+testPath, - http.MethodGet, - ).WillReturnRows(mockRows) + mockRows := sqlmock. + NewRows([]string{"response_body", "status_code", "expiry_time"}). + AddRow([]byte(testBody), 200, nil) + + mockDatabase. + ExpectQuery(selectQuery). + WithArgs(testHost+testPath, http.MethodGet). + WillReturnRows(mockRows) subject := &httpcache.SqliteCache{Database: db} @@ -75,7 +146,127 @@ func Test_Read_HappyPath(t *testing.T) { assert.Nil(t, err) assert.Equal(t, string(responseBody), testBody) - if err := mock.ExpectationsWereMet(); err != nil { + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Read_HappyPathWithExpiryTime(t *testing.T) { + // Given + expiryTimestamp := time.Duration(1) * time.Minute + + request := aDummyRequest() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockRows := sqlmock. + NewRows([]string{"response_body", "status_code", "expiry_time"}). + AddRow([]byte(testBody), 200, time.Now().Add(expiryTimestamp).Unix()) + + mockDatabase. + ExpectQuery(selectQuery). + WithArgs(testHost+testPath, http.MethodGet). + WillReturnRows(mockRows) + + subject := &httpcache.SqliteCache{Database: db} + + // When + response, err := subject.Read(request) + + // Then + assert.Nil(t, err) + assert.NotNil(t, response) + + responseBody, err := io.ReadAll(response.Body) + defer response.Body.Close() + assert.Nil(t, err) + assert.Equal(t, string(responseBody), testBody) + + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Read_ExpiryTimeHasPassed(t *testing.T) { + // Given + expiryTimestamp := -time.Duration(10) * time.Minute + + request := aDummyRequest() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockRows := sqlmock. + NewRows([]string{"response_body", "status_code", "expiry_time"}). + AddRow([]byte(testBody), 200, time.Now().Add(expiryTimestamp).Unix()) + + mockDatabase. + ExpectQuery(selectQuery). + WithArgs(testHost+testPath, http.MethodGet). + WillReturnRows(mockRows) + + subject := &httpcache.SqliteCache{Database: db} + + // When + response, err := subject.Read(request) + + // Then + assert.NotNil(t, err) + assert.ErrorIs(t, err, httpcache.ErrNoResponse) + assert.Nil(t, response) + + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Read_NoRowsFromScan(t *testing.T) { + // Given + request := aDummyRequest() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockDatabase. + ExpectQuery(selectQuery). + WithArgs(testHost+testPath, http.MethodGet). + WillReturnError(sql.ErrNoRows) + + subject := &httpcache.SqliteCache{Database: db} + + // When + response, err := subject.Read(request) + + // Then + assert.NotNil(t, err) + assert.ErrorIs(t, err, httpcache.ErrNoResponse) + assert.Nil(t, response) + + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Read_ScanError(t *testing.T) { + // Given + request := aDummyRequest() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockDatabase. + ExpectQuery(selectQuery). + WithArgs(testHost+testPath, http.MethodGet). + WillReturnError(sql.ErrConnDone) + + subject := &httpcache.SqliteCache{Database: db} + + // When + response, err := subject.Read(request) + + // Then + assert.NotNil(t, err) + assert.ErrorIs(t, err, sql.ErrConnDone) + assert.Nil(t, response) + + if err := mockDatabase.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } diff --git a/transport.go b/transport.go index c48a58c..372d706 100644 --- a/transport.go +++ b/transport.go @@ -3,6 +3,7 @@ package httpcache import ( "errors" "net/http" + "time" ) var ( @@ -68,7 +69,13 @@ func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { } if t.shouldSaveResponse(response.StatusCode, response.Request.Method) { - err = t.cache.Save(response) + var expiryTimestamp time.Time + + if t.config.ExpiryTime != nil { + expiryTimestamp = time.Now().Add(*t.config.ExpiryTime) + } + + err = t.cache.Save(response, &expiryTimestamp) if err != nil { response.Body.Close() return nil, err From 279c592ff3d372007a86b6d369eea7c44bf64d25 Mon Sep 17 00:00:00 2001 From: alexmerren Date: Sun, 9 Mar 2025 18:34:48 +0000 Subject: [PATCH 09/12] Fix tests, update README, fix regular flow --- README.md | 69 +++++++++++++++++++++++--------------------- cache.go | 4 +-- config.go | 24 +++++++-------- sqlite_cache.go | 5 ++-- sqlite_cache_test.go | 2 +- transport.go | 9 +----- 6 files changed, 52 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index 2e21c88..5cf124b 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,20 @@ -# HTTPcache +# httpcache [![Go Report Card](https://goreportcard.com/badge/github.com/alexmerren/httpcache)](https://goreportcard.com/report/github.com/alexmerren/httpcache) ![Go Version](https://img.shields.io/badge/go%20version-%3E=1.21-61CFDD.svg?style=flat-square) [![Go Reference](https://pkg.go.dev/badge/github.com/alexmerren/httpcache.svg)](https://pkg.go.dev/github.com/alexmerren/httpcache) -HTTPcache is a fast, local cache for HTTP requests and responses and wraps the default `http.RoundTripper` from Go standard library. +httpcache is a local cache for HTTP requests and responses, wrapping `http.RoundTripper` from Go standard library. ## Features -HTTPcache has a few useful features: +httpcache has a few useful features: -* Store and retrieve HTTP responses for any type of request. -* Expire responses after a customisable time duration. -* Decide when to store responses based on status code. +- Store and retrieve HTTP responses for any type of request; +- Expire responses after a customisable time duration; +- Decide when to store responses based on status code and request method. -If you want to request a feature then open a [GitHub Issue](https://www.github.com/alexmerren/httpcache/issues) today! +If you want to request a feature then please open a [GitHub Issue](https://www.github.com/alexmerren/httpcache/issues) today! ## Quick Start @@ -27,33 +27,36 @@ go get -u github.com/alexmerren/httpcache Here's an example of using the `httpcache` module to cache responses: ```go -package main - func main() { - // Create a new cached round tripper that: - // * Only stores responses with status code 200. - // * Refuses to store responses with status code 404. - cache := httpcache.NewCachedRoundTripper( - httpcache.WithAllowedStatusCodes([]int{200}), - httpcache.WithDeniedStatusCodes([]int{404}), - ) - - // Create HTTP client with cached round tripper. - httpClient := &http.Client{ - Transport: cache, - } - - // Do first request to populate local database. - httpClient.Get("https://www.google.com") - - // Subsequent requests read from database with no outgoing HTTP request. - for _ = range 10 { - response, _ = httpClient.Get("https://www.google.com") - defer response.Body.Close() - responseBody = io.ReadAll(response.body) - - fmt.Println(responseBody) - } + // Create a new SQLite database to store HTTP responses. + cache, _ := httpcache.NewSqliteCache("database.sqlite") + + // Use the default config with the SQLite database. + // The default config has the behaviour of: + // - Storing responses with status code 200; + // - Storing responses from HTTP requests using method "GET"; + // - Expiring responses after 7 days. + cachedTransport, _ := httpcache.NewTransport( + httpcache.DefaultConfig, + cache, + ) + + // Create a HTTP client with the cached roundtripper. + httpClient := http.Client{ + Transport: cachedTransport, + } + + // Do first request to populate local database. + httpClient.Get("https://www.google.com") + + // Subsequent requests read from database with no outgoing HTTP request. + for _ = range 10 { + response, _ := httpClient.Get("https://www.google.com") + defer response.Body.Close() + responseBody, _ := io.ReadAll(response.Body) + + fmt.Println(string(responseBody)) + } } ``` diff --git a/cache.go b/cache.go index 0a9be1e..124ff1d 100644 --- a/cache.go +++ b/cache.go @@ -12,13 +12,13 @@ import ( type Cache interface { // Save a response for a HTTP request using [context.Background]. - Save(response *http.Response, expiryTime *time.Time) error + Save(response *http.Response, expiryTime *time.Duration) error // Read a saved response for a HTTP request using [context.Background]. Read(request *http.Request) (*http.Response, error) // Save a response for a HTTP request with a [context.Context]. - SaveContext(ctx context.Context, response *http.Response, expiryTime *time.Time) error + SaveContext(ctx context.Context, response *http.Response, expiryTime *time.Duration) error // Read a saved response for a HTTP request with a [context.Context]. If no // response is saved for the corresponding request, return [ErrNoResponse]. diff --git a/config.go b/config.go index f7f87fc..b54716b 100644 --- a/config.go +++ b/config.go @@ -11,6 +11,16 @@ var ( defaultExpiryTime = time.Duration(60*24*7) * time.Minute ) +// DefaultConfig creates a [Config] with default values, namely: +// - AllowedStatusCodes: [http.StatusOK] +// - AllowedMethods: [http.MethodGet] +// - ExpiryTime: 7 days. +var DefaultConfig = NewConfigBuilder(). + WithAllowedStatusCodes(defaultAllowedStatusCodes). + WithAllowedMethods(defaultAllowedMethods). + WithExpiryTime(defaultExpiryTime). + Build() + // Config describes the configuration to use when saving and reading responses // from [Cache] using the [Transport]. type Config struct { @@ -32,17 +42,3 @@ type Config struct { // ExpiryTime describes when a HTTP response should be considered invalid. ExpiryTime *time.Duration } - -// DefaultConfig creates a [Config] with default values, namely: -// - AllowedStatusCodes: [http.StatusOK] -// - AllowedMethods: [http.MethodGet] -// - ExpiryTime: 7 days. -func DefaultConfig() (*Config, error) { - defaultConfig := NewConfigBuilder(). - WithAllowedStatusCodes(defaultAllowedStatusCodes). - WithAllowedMethods(defaultAllowedMethods). - WithExpiryTime(defaultExpiryTime). - Build() - - return defaultConfig, nil -} diff --git a/sqlite_cache.go b/sqlite_cache.go index 73b99f7..a1df3cf 100644 --- a/sqlite_cache.go +++ b/sqlite_cache.go @@ -24,14 +24,13 @@ const ( expiry_time INTEGER)` saveRequestQuery = ` - INSERT INTO responses ( + INSERT OR REPLACE INTO responses ( request_url, request_method, response_body, status_code, expiry_time) - VALUES (?, ?, ?, ?, ?) - ON CONFLICT REPLACE` + VALUES (?, ?, ?, ?, ?)` readRequestQuery = ` SELECT response_body, status_code, expiry_time diff --git a/sqlite_cache_test.go b/sqlite_cache_test.go index 7f9f3b3..2d96b01 100644 --- a/sqlite_cache_test.go +++ b/sqlite_cache_test.go @@ -22,7 +22,7 @@ const ( ) const ( - insertQuery = `INSERT INTO responses \( request_url, request_method, response_body, status_code, expiry_time\) VALUES \(\?, \?, \?, \?, \?\) ON CONFLICT REPLACE` + insertQuery = `INSERT OR REPLACE INTO responses \( request_url, request_method, response_body, status_code, expiry_time\) VALUES \(\?, \?, \?, \?, \?\)` selectQuery = `SELECT response_body, status_code, expiry_time FROM responses WHERE request_url = \? AND request_method = \?` ) diff --git a/transport.go b/transport.go index 372d706..a1cc82e 100644 --- a/transport.go +++ b/transport.go @@ -3,7 +3,6 @@ package httpcache import ( "errors" "net/http" - "time" ) var ( @@ -69,13 +68,7 @@ func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { } if t.shouldSaveResponse(response.StatusCode, response.Request.Method) { - var expiryTimestamp time.Time - - if t.config.ExpiryTime != nil { - expiryTimestamp = time.Now().Add(*t.config.ExpiryTime) - } - - err = t.cache.Save(response, &expiryTimestamp) + err = t.cache.Save(response, t.config.ExpiryTime) if err != nil { response.Body.Close() return nil, err From 11f72ea1203e540e8a268c684e821fe59a580aea Mon Sep 17 00:00:00 2001 From: alexmerren Date: Sun, 9 Mar 2025 18:39:13 +0000 Subject: [PATCH 10/12] Fix README --- README.md | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 5cf124b..4ebe31f 100644 --- a/README.md +++ b/README.md @@ -31,15 +31,21 @@ func main() { // Create a new SQLite database to store HTTP responses. cache, _ := httpcache.NewSqliteCache("database.sqlite") - // Use the default config with the SQLite database. - // The default config has the behaviour of: + // Create a config with a behaviour of: // - Storing responses with status code 200; // - Storing responses from HTTP requests using method "GET"; // - Expiring responses after 7 days. - cachedTransport, _ := httpcache.NewTransport( - httpcache.DefaultConfig, - cache, - ) + config := httpcache.NewConfigBuilder(). + WithAllowedStatusCodes([]int{http.StatusOK}). + WithAllowedMethods([]string{http.MethodGet}). + WithExpiryTime(time.Duration(60*24*7) * time.Minute). + Build() + + // ... or use the default config + config = httpcache.DefaultConfig + + // Create a transport with the SQLite cache and config + cachedTransport, _ := httpcache.NewTransport(config, cache) // Create a HTTP client with the cached roundtripper. httpClient := http.Client{ From 823df38ea0140a4cce0884d30e34fc1ddfa3a7f2 Mon Sep 17 00:00:00 2001 From: alexmerren Date: Sun, 9 Mar 2025 18:42:26 +0000 Subject: [PATCH 11/12] Update doc --- cache.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cache.go b/cache.go index 124ff1d..3b700f6 100644 --- a/cache.go +++ b/cache.go @@ -17,11 +17,13 @@ type Cache interface { // Read a saved response for a HTTP request using [context.Background]. Read(request *http.Request) (*http.Response, error) - // Save a response for a HTTP request with a [context.Context]. + // Save a response for a HTTP request with a [context.Context]. expiryTime + // is the duration from [time.Now] to expire the response. SaveContext(ctx context.Context, response *http.Response, expiryTime *time.Duration) error // Read a saved response for a HTTP request with a [context.Context]. If no - // response is saved for the corresponding request, return [ErrNoResponse]. + // response is saved for the corresponding request, or the expiryTime has + // been surpassed, then return [ErrNoResponse]. ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) } From d868150c351da7b0643e47e627f1aab1abd89f7c Mon Sep 17 00:00:00 2001 From: alexmerren Date: Sun, 9 Mar 2025 18:47:20 +0000 Subject: [PATCH 12/12] fix wording in README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 4ebe31f..5b2b7c1 100644 --- a/README.md +++ b/README.md @@ -34,17 +34,17 @@ func main() { // Create a config with a behaviour of: // - Storing responses with status code 200; // - Storing responses from HTTP requests using method "GET"; - // - Expiring responses after 7 days. + // - Expiring responses after 7 days... config := httpcache.NewConfigBuilder(). WithAllowedStatusCodes([]int{http.StatusOK}). WithAllowedMethods([]string{http.MethodGet}). WithExpiryTime(time.Duration(60*24*7) * time.Minute). Build() - // ... or use the default config + // ... or use the default config. config = httpcache.DefaultConfig - // Create a transport with the SQLite cache and config + // Create a transport with the SQLite cache and config. cachedTransport, _ := httpcache.NewTransport(config, cache) // Create a HTTP client with the cached roundtripper.