diff --git a/Makefile b/Makefile index 7653b7f..02d8971 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ fmt: $(DRUN) $(GO) fmt ./... vet: - $(DRUN) $(GO) vet ./. + $(DRUN) $(GO) vet ./... bash: $(DRUN) /bin/bash \ No newline at end of file diff --git a/README.md b/README.md index 2e21c88..5b2b7c1 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,20 @@ -# HTTPcache +# httpcache [![Go Report Card](https://goreportcard.com/badge/github.com/alexmerren/httpcache)](https://goreportcard.com/report/github.com/alexmerren/httpcache) ![Go Version](https://img.shields.io/badge/go%20version-%3E=1.21-61CFDD.svg?style=flat-square) [![Go Reference](https://pkg.go.dev/badge/github.com/alexmerren/httpcache.svg)](https://pkg.go.dev/github.com/alexmerren/httpcache) -HTTPcache is a fast, local cache for HTTP requests and responses and wraps the default `http.RoundTripper` from Go standard library. +httpcache is a local cache for HTTP requests and responses, wrapping `http.RoundTripper` from Go standard library. ## Features -HTTPcache has a few useful features: +httpcache has a few useful features: -* Store and retrieve HTTP responses for any type of request. -* Expire responses after a customisable time duration. -* Decide when to store responses based on status code. +- Store and retrieve HTTP responses for any type of request; +- Expire responses after a customisable time duration; +- Decide when to store responses based on status code and request method. -If you want to request a feature then open a [GitHub Issue](https://www.github.com/alexmerren/httpcache/issues) today! +If you want to request a feature then please open a [GitHub Issue](https://www.github.com/alexmerren/httpcache/issues) today! ## Quick Start @@ -27,33 +27,42 @@ go get -u github.com/alexmerren/httpcache Here's an example of using the `httpcache` module to cache responses: ```go -package main - func main() { - // Create a new cached round tripper that: - // * Only stores responses with status code 200. - // * Refuses to store responses with status code 404. - cache := httpcache.NewCachedRoundTripper( - httpcache.WithAllowedStatusCodes([]int{200}), - httpcache.WithDeniedStatusCodes([]int{404}), - ) - - // Create HTTP client with cached round tripper. - httpClient := &http.Client{ - Transport: cache, - } - - // Do first request to populate local database. - httpClient.Get("https://www.google.com") - - // Subsequent requests read from database with no outgoing HTTP request. - for _ = range 10 { - response, _ = httpClient.Get("https://www.google.com") - defer response.Body.Close() - responseBody = io.ReadAll(response.body) - - fmt.Println(responseBody) - } + // Create a new SQLite database to store HTTP responses. + cache, _ := httpcache.NewSqliteCache("database.sqlite") + + // Create a config with a behaviour of: + // - Storing responses with status code 200; + // - Storing responses from HTTP requests using method "GET"; + // - Expiring responses after 7 days... + config := httpcache.NewConfigBuilder(). + WithAllowedStatusCodes([]int{http.StatusOK}). + WithAllowedMethods([]string{http.MethodGet}). + WithExpiryTime(time.Duration(60*24*7) * time.Minute). + Build() + + // ... or use the default config. + config = httpcache.DefaultConfig + + // Create a transport with the SQLite cache and config. + cachedTransport, _ := httpcache.NewTransport(config, cache) + + // Create a HTTP client with the cached roundtripper. + httpClient := http.Client{ + Transport: cachedTransport, + } + + // Do first request to populate local database. + httpClient.Get("https://www.google.com") + + // Subsequent requests read from database with no outgoing HTTP request. + for _ = range 10 { + response, _ := httpClient.Get("https://www.google.com") + defer response.Body.Close() + responseBody, _ := io.ReadAll(response.Body) + + fmt.Println(string(responseBody)) + } } ``` diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..3b700f6 --- /dev/null +++ b/cache.go @@ -0,0 +1,35 @@ +package httpcache + +import ( + "context" + "errors" + "net/http" + "time" +) + +// Cache is the entrypoint for saving and reading responses. This can be +// implemented for a custom method to cache responses. +type Cache interface { + + // Save a response for a HTTP request using [context.Background]. + Save(response *http.Response, expiryTime *time.Duration) error + + // Read a saved response for a HTTP request using [context.Background]. + Read(request *http.Request) (*http.Response, error) + + // Save a response for a HTTP request with a [context.Context]. expiryTime + // is the duration from [time.Now] to expire the response. + SaveContext(ctx context.Context, response *http.Response, expiryTime *time.Duration) error + + // Read a saved response for a HTTP request with a [context.Context]. If no + // response is saved for the corresponding request, or the expiryTime has + // been surpassed, then return [ErrNoResponse]. + ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) +} + +var ( + // ErrNoResponse describes when the cache does not have a response stored. + // [Transport] will check if ErrNoResponse is returned from [Cache.Read]. If + // ErrNoResponse is returned, then the request/response will be saved with [Save]. + ErrNoResponse = errors.New("no stored response") +) diff --git a/config.go b/config.go new file mode 100644 index 0000000..b54716b --- /dev/null +++ b/config.go @@ -0,0 +1,44 @@ +package httpcache + +import ( + "net/http" + "time" +) + +var ( + defaultAllowedStatusCodes = []int{http.StatusOK} + defaultAllowedMethods = []string{http.MethodGet} + defaultExpiryTime = time.Duration(60*24*7) * time.Minute +) + +// DefaultConfig creates a [Config] with default values, namely: +// - AllowedStatusCodes: [http.StatusOK] +// - AllowedMethods: [http.MethodGet] +// - ExpiryTime: 7 days. +var DefaultConfig = NewConfigBuilder(). + WithAllowedStatusCodes(defaultAllowedStatusCodes). + WithAllowedMethods(defaultAllowedMethods). + WithExpiryTime(defaultExpiryTime). + Build() + +// Config describes the configuration to use when saving and reading responses +// from [Cache] using the [Transport]. +type Config struct { + + // AllowedStatusCodes describes if a HTTP response should be saved by + // checking that it's status code is accepted by [Cache]. If the HTTP + // response's status code is not in AllowedStatusCodes, then do not persist. + // + // This is a required field. + AllowedStatusCodes []int + + // AllowedMethods describes if a HTTP response should be saved by checking + // if the HTTP request's method is accepted by the [Cache]. If the HTTP + // request's method is not in AllowedMethods, then do not persist. + // + // This is a required field. + AllowedMethods []string + + // ExpiryTime describes when a HTTP response should be considered invalid. + ExpiryTime *time.Duration +} diff --git a/config_builder.go b/config_builder.go new file mode 100644 index 0000000..efa19f9 --- /dev/null +++ b/config_builder.go @@ -0,0 +1,58 @@ +package httpcache + +import "time" + +// If adding new fields to [Config], then add the corresponding field and +// methods to [configBuilder]. Configuration values in [configBuilder] always +// use pointer types. This allows the [Config] decide which fields are +// required. + +// configBuilder is an internal structure to create configs with required +// parameters. +type configBuilder struct { + + // allowedStatusCodes only persists HTTP responses that have an appropriate + // status code (i.e. 200). + // + // This is a required field. + allowedStatusCodes *[]int + + // allowedMethods only persists HTTP responses that use an appropriate HTTP + // method (i.e. "GET"). + // + // This is a required field. + allowedMethods *[]string + + // expiryTime invalidates HTTP responses after a duration has elapsed from + // [time.Now]. Set to nil for no expiry. + expiryTime *time.Duration +} + +func NewConfigBuilder() *configBuilder { + return &configBuilder{} +} + +func (c *configBuilder) WithAllowedStatusCodes(allowedStatusCodes []int) *configBuilder { + c.allowedStatusCodes = &allowedStatusCodes + return c +} + +func (c *configBuilder) WithAllowedMethods(allowedMethods []string) *configBuilder { + c.allowedMethods = &allowedMethods + return c +} + +func (c *configBuilder) WithExpiryTime(expiryDuration time.Duration) *configBuilder { + c.expiryTime = &expiryDuration + return c +} + +// Build constructs a [Config] that is ready to be consumed by [Transport]. If +// the configuration passed by [configBuilder] is invalid, it will panic. +func (c *configBuilder) Build() *Config { + return &Config{ + AllowedStatusCodes: *c.allowedStatusCodes, + AllowedMethods: *c.allowedMethods, + ExpiryTime: c.expiryTime, + } +} diff --git a/file.go b/file.go deleted file mode 100644 index 4e3a4b9..0000000 --- a/file.go +++ /dev/null @@ -1,34 +0,0 @@ -package httpcache - -import ( - "errors" - "os" - "path/filepath" -) - -func doesFileExist(filename string) (bool, error) { - _, err := os.Stat(filename) - if err != nil && errors.Is(err, os.ErrNotExist) { - return false, nil - } - - if err != nil && !errors.Is(err, os.ErrNotExist) { - return false, err - } - - return true, nil -} - -func createFile(filename string) (*os.File, error) { - directory, err := filepath.Abs(filepath.Dir(filename)) - if err != nil { - return nil, err - } - - err = os.MkdirAll(directory, os.ModePerm) - if err != nil { - return nil, err - } - - return os.Create(filename) -} diff --git a/options.go b/options.go deleted file mode 100644 index ac6bc2d..0000000 --- a/options.go +++ /dev/null @@ -1,45 +0,0 @@ -package httpcache - -import ( - "time" -) - -func WithDeniedStatusCodes(deniedStatusCodes []int) func(*CachedRoundTripper) error { - return func(c *CachedRoundTripper) error { - c.deniedStatusCodes = deniedStatusCodes - return nil - } -} - -func WithAllowedStatusCodes(allowedStatusCodes []int) func(*CachedRoundTripper) error { - return func(c *CachedRoundTripper) error { - c.allowedStatusCodes = allowedStatusCodes - return nil - } -} - -func WithExpiryTime(expiryTime time.Duration) func(*CachedRoundTripper) error { - return func(c *CachedRoundTripper) error { - c.expiryTime = expiryTime - return nil - } -} - -func WithName(name string) func(*CachedRoundTripper) error { - return func(c *CachedRoundTripper) error { - sqliteStore, err := newSqliteResponseStore(name) - if err != nil { - return err - } - c.store = sqliteStore - - return nil - } -} - -func WithCacheStore(store ResponseStorer) func(*CachedRoundTripper) error { - return func(c *CachedRoundTripper) error { - c.store = store - return nil - } -} diff --git a/round_tripper.go b/round_tripper.go deleted file mode 100644 index c221fd7..0000000 --- a/round_tripper.go +++ /dev/null @@ -1,119 +0,0 @@ -package httpcache - -import ( - "bytes" - "errors" - "io" - "net/http" - "time" -) - -var ( - defaultDeniedStatusCodes = []int{ - http.StatusNotFound, - http.StatusBadRequest, - http.StatusForbidden, - http.StatusUnauthorized, - http.StatusMethodNotAllowed, - } - - defaultAllowedStatusCodes = []int{ - http.StatusOK, - } - - defaultExpiryTime = time.Duration(0) - - defaultCacheName = "httpcache.sqlite" -) - -type CachedRoundTripper struct { - transport http.RoundTripper - store ResponseStorer - expiryTime time.Duration - deniedStatusCodes []int - allowedStatusCodes []int -} - -func NewCachedRoundTripper(options ...func(*CachedRoundTripper) error) (*CachedRoundTripper, error) { - roundTripper := &CachedRoundTripper{ - transport: http.DefaultTransport, - deniedStatusCodes: defaultDeniedStatusCodes, - allowedStatusCodes: defaultAllowedStatusCodes, - expiryTime: defaultExpiryTime, - } - - for _, optionFunc := range options { - err := optionFunc(roundTripper) - if err != nil { - return nil, err - } - } - - if roundTripper.store != nil { - return roundTripper, nil - } - - store, err := newSqliteResponseStore(defaultCacheName) - if err != nil { - return nil, err - } - roundTripper.store = store - - return roundTripper, nil -} - -func (h *CachedRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { - response, err := h.store.Read(request) - if err == nil { - return response, nil - } - - if !errors.Is(err, ErrNoResponse) { - return nil, err - } - - requestBody := []byte{} - if request.GetBody != nil { - body, err := request.GetBody() - if err != nil { - return nil, err - } - requestBody, err = io.ReadAll(body) - if err != nil { - return nil, err - } - defer body.Close() - } - - response, err = h.transport.RoundTrip(request) - if err != nil { - return nil, err - } - response.Request.Body = io.NopCloser(bytes.NewReader(requestBody)) - - if contains(h.deniedStatusCodes, response.StatusCode) { - return response, nil - } - - if !contains(h.allowedStatusCodes, response.StatusCode) { - return response, nil - } - - err = h.store.Save(response) - if err != nil { - response.Body.Close() - response.Request.Body.Close() - return nil, err - } - - return response, nil -} - -func contains(slice []int, searchValue int) bool { - for index := range slice { - if searchValue == slice[index] { - return true - } - } - return false -} diff --git a/sqlite.go b/sqlite.go deleted file mode 100644 index 5865a04..0000000 --- a/sqlite.go +++ /dev/null @@ -1,154 +0,0 @@ -package httpcache - -import ( - "bytes" - "context" - "database/sql" - "errors" - "fmt" - "io" - "net/http" - "strings" - - _ "github.com/mattn/go-sqlite3" -) - -const ( - createDatabaseQuery = "CREATE TABLE IF NOT EXISTS responses (request_url TEXT PRIMARY KEY, request_method TEXT, request_body TEXT, response_body TEXT, status_code INTEGER)" - saveRequestQuery = "INSERT INTO responses (request_url, request_method, request_body, response_body, status_code) VALUES (?, ?, ?, ?, ?)" - readRequestQuery = "SELECT response_body, status_code FROM responses WHERE request_url = ? AND request_method = ? AND request_body = ?" -) - -type sqliteResponseStore struct { - database *sql.DB -} - -func newSqliteResponseStore(databaseName string) (*sqliteResponseStore, error) { - fileExists, err := doesFileExist(databaseName) - if err != nil { - return nil, err - } - - if !fileExists { - _, err = createFile(databaseName) - if err != nil { - return nil, err - } - } - - conn, err := sql.Open("sqlite3", databaseName) - if err != nil { - return nil, err - } - - _, err = conn.Exec(createDatabaseQuery) - if err != nil { - return nil, err - } - - return &sqliteResponseStore{ - database: conn, - }, nil -} - -func (s *sqliteResponseStore) Save(response *http.Response) error { - return s.SaveContext(context.Background(), response) -} - -func (s *sqliteResponseStore) SaveContext(ctx context.Context, response *http.Response) error { - requestURL := response.Request.URL.Host + response.Request.URL.RequestURI() - requestMethod := response.Request.Method - requestBody := []byte{} - - // If the request has no body (e.g. GET request) then Request.GetBody will - // be nil. Check if it is nil to get request body for other types of request - // (e.g. POST, PATCH). - if response.Request.GetBody != nil { - body, err := response.Request.GetBody() - if err != nil { - return err - } - requestBody, err = io.ReadAll(body) - if err != nil { - return err - } - defer body.Close() - } - - responseBody, err := io.ReadAll(response.Body) - if err != nil { - return err - } - - // Reset the response body stream to the beginning to be read again. - response.Body = io.NopCloser(bytes.NewReader(responseBody)) - responseStatusCode := response.StatusCode - - tx, err := s.database.BeginTx(ctx, nil) - if err != nil { - return err - } - - defer func() { - switch err { - case nil: - err = tx.Commit() - default: - tx.Rollback() - } - }() - - _, err = tx.Exec(saveRequestQuery, requestURL, requestMethod, requestBody, responseBody, responseStatusCode) - if err != nil { - return err - } - - return err -} - -func (s *sqliteResponseStore) Read(request *http.Request) (*http.Response, error) { - return s.ReadContext(context.Background(), request) -} - -func (s *sqliteResponseStore) ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) { - requestURL := request.URL.Host + request.URL.RequestURI() - requestMethod := request.Method - requestBody := []byte{} - - // If the request has no body (e.g. GET request) then Request.GetBody will - // be nil. Check if it is nil to get request body for other types of request - // (e.g. POST, PATCH). - if request.GetBody != nil { - body, err := request.GetBody() - if err != nil { - return nil, err - } - requestBody, err = io.ReadAll(body) - if err != nil { - return nil, err - } - defer body.Close() - } - - responseBody := "" - responseStatusCode := 0 - - if s.database == nil { - return nil, fmt.Errorf("database is nil") - } - - row := s.database.QueryRow(readRequestQuery, requestURL, requestMethod, requestBody) - err := row.Scan(&responseBody, &responseStatusCode) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, ErrNoResponse - } - return nil, err - } - - return &http.Response{ - Request: request, - Body: io.NopCloser(strings.NewReader(responseBody)), - StatusCode: responseStatusCode, - }, nil -} diff --git a/sqlite_cache.go b/sqlite_cache.go new file mode 100644 index 0000000..a1df3cf --- /dev/null +++ b/sqlite_cache.go @@ -0,0 +1,184 @@ +package httpcache + +import ( + "bytes" + "context" + "database/sql" + "errors" + "io" + "net/http" + "os" + "path/filepath" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +const ( + createDatabaseQuery = ` + CREATE TABLE IF NOT EXISTS responses ( + request_url TEXT PRIMARY KEY, + request_method TEXT NOT NULL, + response_body BLOB NOT NULL, + status_code INTEGER NOT NULL, + expiry_time INTEGER)` + + saveRequestQuery = ` + INSERT OR REPLACE INTO responses ( + request_url, + request_method, + response_body, + status_code, + expiry_time) + VALUES (?, ?, ?, ?, ?)` + + readRequestQuery = ` + SELECT response_body, status_code, expiry_time + FROM responses + WHERE request_url = ? AND request_method = ?` +) + +// ErrNoDatabase denotes when the database does not exist or has not been +// constructed correctly. +var ErrNoDatabase = errors.New("no database connection") + +// SqliteCache is a default implementation of [Cache] which creates a local +// SQLite cache to persist and to query HTTP responses. +type SqliteCache struct { + Database *sql.DB +} + +// NewSqliteCache creates a new SQLite database with a certain name. This name +// is the filename of the database. If the file does not exist, then we create +// it. If the file is in a non-existent directory, we create the directory. +func NewSqliteCache(databaseName string) (*SqliteCache, error) { + fileExists, err := doesFileExist(databaseName) + if err != nil { + return nil, err + } + + if !fileExists { + _, err = createFile(databaseName) + if err != nil { + return nil, err + } + } + + conn, err := sql.Open("sqlite3", databaseName) + if err != nil { + return nil, err + } + + _, err = conn.Exec(createDatabaseQuery) + if err != nil { + return nil, err + } + + return &SqliteCache{ + Database: conn, + }, nil +} + +func (s *SqliteCache) Save(response *http.Response, expiryTime *time.Duration) error { + return s.SaveContext(context.Background(), response, expiryTime) +} + +func (s *SqliteCache) Read(request *http.Request) (*http.Response, error) { + return s.ReadContext(context.Background(), request) +} + +func (s *SqliteCache) SaveContext(ctx context.Context, response *http.Response, expiryTime *time.Duration) error { + responseBody := bytes.NewBuffer(nil) + _, err := io.Copy(responseBody, response.Body) + if err != nil { + return err + } + + var expiryTimestamp *int64 + if expiryTime != nil { + calculatedExpiry := time.Now().Add(*expiryTime).Unix() + expiryTimestamp = &calculatedExpiry + } else { + expiryTimestamp = nil + } + + // Reset the response body stream to the beginning to be read again. + response.Body = io.NopCloser(responseBody) + + requestUrl := generateUrl(response.Request) + _, err = s.Database.Exec( + saveRequestQuery, + requestUrl, + response.Request.Method, + responseBody.Bytes(), + response.StatusCode, + expiryTimestamp, + ) + if err != nil { + return err + } + + return nil +} + +func (s *SqliteCache) ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) { + if s.Database == nil { + return nil, ErrNoDatabase + } + + requestUrl := generateUrl(request) + row := s.Database.QueryRow(readRequestQuery, requestUrl, request.Method) + + var responseBody []byte + var responseStatusCode int + var expiryTime *int64 + + err := row.Scan(&responseBody, &responseStatusCode, &expiryTime) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNoResponse + } + return nil, err + } + + if expiryTime != nil && time.Now().Unix() > *expiryTime { + return nil, ErrNoResponse + } + + return &http.Response{ + Request: request, + Body: io.NopCloser(bytes.NewReader(responseBody)), + StatusCode: responseStatusCode, + }, nil +} + +func generateUrl(request *http.Request) string { + return request.URL.Host + request.URL.RequestURI() +} + +func doesFileExist(filename string) (bool, error) { + _, err := os.Stat(filename) + if err != nil && errors.Is(err, os.ErrNotExist) { + return false, nil + } + + if err != nil && !errors.Is(err, os.ErrNotExist) { + return false, err + } + + return true, nil +} + +func createFile(filename string) (*os.File, error) { + directory, err := filepath.Abs(filepath.Dir(filename)) + if err != nil { + return nil, err + } + + err = os.MkdirAll(directory, os.ModePerm) + if err != nil { + return nil, err + } + + return os.Create(filename) +} diff --git a/sqlite_cache_test.go b/sqlite_cache_test.go new file mode 100644 index 0000000..2d96b01 --- /dev/null +++ b/sqlite_cache_test.go @@ -0,0 +1,323 @@ +package httpcache_test + +import ( + "database/sql" + "errors" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/alexmerren/httpcache" + "github.com/stretchr/testify/assert" +) + +const ( + testHost = "www.test.com" + testPath = "/test-path" + testBody = "this is a test body" +) + +const ( + insertQuery = `INSERT OR REPLACE INTO responses \( request_url, request_method, response_body, status_code, expiry_time\) VALUES \(\?, \?, \?, \?, \?\)` + selectQuery = `SELECT response_body, status_code, expiry_time FROM responses WHERE request_url = \? AND request_method = \?` +) + +func Test_Save_HappyPath(t *testing.T) { + // Given + response := aDummyResponse() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockDatabase. + ExpectExec(insertQuery). + WithArgs( + testHost+testPath, + http.MethodGet, + []byte(testBody), + http.StatusOK, + nil, + ). + WillReturnResult(sqlmock.NewResult(1, 1)) + + subject := &httpcache.SqliteCache{Database: db} + + // When + err := subject.Save(response, nil) + + // Then + assert.Nil(t, err) + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Save_HappyPathWithExpiryTime(t *testing.T) { + // Given + expiryTimestamp := time.Duration(1) * time.Minute + + response := aDummyResponse() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockDatabase. + ExpectExec(insertQuery). + WithArgs( + testHost+testPath, + http.MethodGet, + []byte(testBody), + http.StatusOK, + time.Now().Add(expiryTimestamp).Unix(), + ). + WillReturnResult(sqlmock.NewResult(1, 1)) + + subject := &httpcache.SqliteCache{Database: db} + + // When + err := subject.Save(response, &expiryTimestamp) + + // Then + assert.Nil(t, err) + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Save_ExecFails(t *testing.T) { + // Given + expiryTimestamp := time.Duration(1) * time.Minute + + response := aDummyResponse() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockDatabase. + ExpectExec(insertQuery). + WithArgs( + testHost+testPath, + http.MethodGet, + []byte(testBody), + http.StatusOK, + time.Now().Add(expiryTimestamp).Unix(), + ). + WillReturnError(errors.New("dummy error")) + + subject := &httpcache.SqliteCache{Database: db} + + // When + err := subject.Save(response, &expiryTimestamp) + + // Then + assert.NotNil(t, err) + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Read_HappyPath(t *testing.T) { + // Given + request := aDummyRequest() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockRows := sqlmock. + NewRows([]string{"response_body", "status_code", "expiry_time"}). + AddRow([]byte(testBody), 200, nil) + + mockDatabase. + ExpectQuery(selectQuery). + WithArgs(testHost+testPath, http.MethodGet). + WillReturnRows(mockRows) + + subject := &httpcache.SqliteCache{Database: db} + + // When + response, err := subject.Read(request) + + // Then + assert.Nil(t, err) + assert.NotNil(t, response) + + responseBody, err := io.ReadAll(response.Body) + defer response.Body.Close() + assert.Nil(t, err) + assert.Equal(t, string(responseBody), testBody) + + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Read_HappyPathWithExpiryTime(t *testing.T) { + // Given + expiryTimestamp := time.Duration(1) * time.Minute + + request := aDummyRequest() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockRows := sqlmock. + NewRows([]string{"response_body", "status_code", "expiry_time"}). + AddRow([]byte(testBody), 200, time.Now().Add(expiryTimestamp).Unix()) + + mockDatabase. + ExpectQuery(selectQuery). + WithArgs(testHost+testPath, http.MethodGet). + WillReturnRows(mockRows) + + subject := &httpcache.SqliteCache{Database: db} + + // When + response, err := subject.Read(request) + + // Then + assert.Nil(t, err) + assert.NotNil(t, response) + + responseBody, err := io.ReadAll(response.Body) + defer response.Body.Close() + assert.Nil(t, err) + assert.Equal(t, string(responseBody), testBody) + + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Read_ExpiryTimeHasPassed(t *testing.T) { + // Given + expiryTimestamp := -time.Duration(10) * time.Minute + + request := aDummyRequest() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockRows := sqlmock. + NewRows([]string{"response_body", "status_code", "expiry_time"}). + AddRow([]byte(testBody), 200, time.Now().Add(expiryTimestamp).Unix()) + + mockDatabase. + ExpectQuery(selectQuery). + WithArgs(testHost+testPath, http.MethodGet). + WillReturnRows(mockRows) + + subject := &httpcache.SqliteCache{Database: db} + + // When + response, err := subject.Read(request) + + // Then + assert.NotNil(t, err) + assert.ErrorIs(t, err, httpcache.ErrNoResponse) + assert.Nil(t, response) + + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Read_NoRowsFromScan(t *testing.T) { + // Given + request := aDummyRequest() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockDatabase. + ExpectQuery(selectQuery). + WithArgs(testHost+testPath, http.MethodGet). + WillReturnError(sql.ErrNoRows) + + subject := &httpcache.SqliteCache{Database: db} + + // When + response, err := subject.Read(request) + + // Then + assert.NotNil(t, err) + assert.ErrorIs(t, err, httpcache.ErrNoResponse) + assert.Nil(t, response) + + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func Test_Read_ScanError(t *testing.T) { + // Given + request := aDummyRequest() + db, mockDatabase, closeFunc := aDatabaseMock(t) + defer closeFunc() + + mockDatabase. + ExpectQuery(selectQuery). + WithArgs(testHost+testPath, http.MethodGet). + WillReturnError(sql.ErrConnDone) + + subject := &httpcache.SqliteCache{Database: db} + + // When + response, err := subject.Read(request) + + // Then + assert.NotNil(t, err) + assert.ErrorIs(t, err, sql.ErrConnDone) + assert.Nil(t, response) + + if err := mockDatabase.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func aDatabaseMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock, func() error) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("%v", err) + } + + return db, mock, db.Close +} + +func aDummyResponse() *http.Response { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(testBody)), + Request: &http.Request{ + URL: &url.URL{ + Scheme: "https", + Host: testHost, + Opaque: "", + User: nil, + Path: testPath, + RawPath: "", + OmitHost: false, + ForceQuery: false, + RawQuery: "", + Fragment: "", + RawFragment: "", + }, + Method: http.MethodGet, + }, + } +} + +func aDummyRequest() *http.Request { + return &http.Request{ + URL: &url.URL{ + Scheme: "https", + Host: testHost, + Opaque: "", + User: nil, + Path: testPath, + RawPath: "", + OmitHost: false, + ForceQuery: false, + RawQuery: "", + Fragment: "", + RawFragment: "", + }, + Method: http.MethodGet, + } +} diff --git a/sqlite_test.go b/sqlite_test.go deleted file mode 100644 index e70566b..0000000 --- a/sqlite_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package httpcache - -import ( - "database/sql" - "io" - "net/http" - "net/url" - "strings" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" -) - -const ( - testHost = "www.test.com" - testPath = "/test-path" - testBody = "this is a test body" -) - -const ( - insertQuery = `INSERT INTO responses \(request_url, request_method, request_body, response_body, status_code\) VALUES \(\?, \?, \?, \?, \?\)` - selectQuery = `SELECT response_body, status_code FROM responses WHERE request_url = \? AND request_method = \? AND request_body = \?` -) - -func Test_Save_HappyPath(t *testing.T) { - // Given - response := aDummyResponse() - db, mock, closeFunc := aDatabaseMock(t) - defer closeFunc() - - subject := &sqliteResponseStore{database: db} - - mock.ExpectBegin() - mock.ExpectExec(insertQuery).WithArgs( - testHost+testPath, - http.MethodGet, - []byte{}, - []byte(testBody), - http.StatusOK, - ).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - // When - err := subject.Save(response) - - // Then - assert.Nil(t, err) - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func Test_Read_HappyPath(t *testing.T) { - // Given - request := aDummyRequest() - db, mock, closeFunc := aDatabaseMock(t) - defer closeFunc() - - subject := &sqliteResponseStore{database: db} - - mockRows := sqlmock.NewRows([]string{"response_body", "status_code"}). - AddRow("mock-body", 200). - AddRow("mock-body", 200) - mock.ExpectQuery(selectQuery).WithArgs( - testHost+testPath, - http.MethodGet, - []byte{}, - ).WillReturnRows(mockRows) - - // When - response, err := subject.Read(request) - - // Then - assert.Nil(t, err) - assert.NotNil(t, response) - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func aDatabaseMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock, func() error) { - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("%v", err) - } - - return db, mock, db.Close -} - -func aDummyResponse() *http.Response { - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(testBody)), - Request: &http.Request{ - URL: &url.URL{ - Scheme: "https", - Host: testHost, - Opaque: "", - User: nil, - Path: testPath, - RawPath: "", - OmitHost: false, - ForceQuery: false, - RawQuery: "", - Fragment: "", - RawFragment: "", - }, - Method: http.MethodGet, - }, - } -} - -func aDummyRequest() *http.Request { - return &http.Request{ - URL: &url.URL{ - Scheme: "https", - Host: testHost, - Opaque: "", - User: nil, - Path: testPath, - RawPath: "", - OmitHost: false, - ForceQuery: false, - RawQuery: "", - Fragment: "", - RawFragment: "", - }, - Method: http.MethodGet, - } -} diff --git a/store.go b/store.go deleted file mode 100644 index 6143ff7..0000000 --- a/store.go +++ /dev/null @@ -1,18 +0,0 @@ -package httpcache - -import ( - "context" - "errors" - "net/http" -) - -type ResponseStorer interface { - Save(response *http.Response) error - Read(request *http.Request) (*http.Response, error) - SaveContext(ctx context.Context, response *http.Response) error - ReadContext(ctx context.Context, request *http.Request) (*http.Response, error) -} - -var ( - ErrNoResponse = errors.New("no stored response") -) diff --git a/transport.go b/transport.go new file mode 100644 index 0000000..a1cc82e --- /dev/null +++ b/transport.go @@ -0,0 +1,98 @@ +package httpcache + +import ( + "errors" + "net/http" +) + +var ( + // ErrMissingCache will be returned if cache is not set when creating an + // instance of [Transport]. + ErrMissingCache = errors.New("cache not set when creating transport") + + // ErrMissingConfig will be returned if config is not set when creating an + // instance of [Transport]. + ErrMissingConfig = errors.New("config not set when creating transport") +) + +// Transport is the main interface of the package. It uses [Cache] to persist +// HTTP response, and and uses configuration values from [Config] to interpret +// requests and responses. +type Transport struct { + + // transport handles HTTP requests. This is hardcoded to be + // [http.DefaultTransport]. + transport http.RoundTripper + + // cache handles persisting HTTP responses. + cache Cache + + // config describes which HTTP responses to cache and how they are cached. + config *Config +} + +// NewTransport creates a [Transport]. If the cache is nil, return +// [ErrMissingCache]. If the config is nil, return [ErrMissingConfig]. +func NewTransport(config *Config, cache Cache) (*Transport, error) { + if cache == nil { + return nil, ErrMissingCache + } + if config == nil { + return nil, ErrMissingConfig + } + + return &Transport{ + transport: http.DefaultTransport, + cache: cache, + config: config, + }, nil +} + +// RoundTrip wraps the [http.DefaultTransport] RoundTrip to execute HTTP +// requests and persists, if necessary, the responses. If [Cache] returns +// [ErrNoResponse], then execute a HTTP request and persist the response if +// passing the criteria in [Transport.shouldSaveResponse]. +func (t *Transport) RoundTrip(request *http.Request) (*http.Response, error) { + response, err := t.cache.Read(request) + if err == nil { + return response, nil + } + + if !errors.Is(err, ErrNoResponse) { + return nil, err + } + + response, err = t.transport.RoundTrip(request) + if err != nil { + return nil, err + } + + if t.shouldSaveResponse(response.StatusCode, response.Request.Method) { + err = t.cache.Save(response, t.config.ExpiryTime) + if err != nil { + response.Body.Close() + return nil, err + } + } + + return response, nil +} + +// shouldSaveResponse is responsible for interpreting configuration values from +// [Config] to determine if a HTTP response should be persisted. Any new values +// added to [Config] can be used here as criteria. +func (t *Transport) shouldSaveResponse(statusCode int, method string) bool { + isAllowedStatusCode := contains(t.config.AllowedStatusCodes, statusCode) + isAllowedMethod := contains(t.config.AllowedMethods, method) + + return isAllowedStatusCode && isAllowedMethod +} + +func contains[T comparable](slice []T, searchValue T) bool { + for _, value := range slice { + if value == searchValue { + return true + } + } + return false +} diff --git a/transport_test.go b/transport_test.go new file mode 100644 index 0000000..96ca46f --- /dev/null +++ b/transport_test.go @@ -0,0 +1,45 @@ +package httpcache_test + +import "testing" + +func Test_RoundTrip_HappyPath_ResponseSaved(t *testing.T) { + // given + // when + // then +} + +func Test_RoundTrip_HappyPath_ResponseNotSaved(t *testing.T) { + // given + // when + // then +} + +func Test_RoundTrip_ReadError(t *testing.T) { + // given + // when + // then +} + +func Test_RoundTrip_ResponseHasNotAllowedStatusCode(t *testing.T) { + // given + // when + // then +} + +func Test_RoundTrip_ResponseHasNotAllowedMethod(t *testing.T) { + // given + // when + // then +} + +func Test_RoundTrip_TransportError(t *testing.T) { + // given + // when + // then +} + +func Test_RoundTrip_SaveError(t *testing.T) { + // given + // when + // then +}