diff --git a/client/fetcher.go b/client/fetcher.go index 774c62ef8..687e31256 100644 --- a/client/fetcher.go +++ b/client/fetcher.go @@ -16,13 +16,19 @@ package client import ( "context" + "errors" "fmt" "io" + "math/rand/v2" + "net" "net/http" "net/url" "os" "path" + "strconv" "strings" + "syscall" + "time" "log/slog" @@ -30,6 +36,20 @@ import ( "github.com/transparency-dev/tessera/internal/fetcher" ) +// TransientError indicates that an error is temporary and the operation can be retried. +type TransientError struct { + Err error + RetryAfter time.Duration // Optional, parsed from header if available +} + +func (e TransientError) Error() string { + return fmt.Sprintf("transient error: %v", e.Err) +} + +func (e TransientError) Unwrap() error { + return e.Err +} + // NewHTTPFetcher creates a new HTTPFetcher for the log rooted at the given URL, using // the provided HTTP client. // @@ -61,6 +81,29 @@ func (h *HTTPFetcher) SetAuthorizationHeader(v string) { h.authHeader = v } +func isTransientNetworkError(err error) bool { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { + return true + } + } + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return true + } + var errno syscall.Errno + if errors.As(err, &errno) { + switch errno { + case syscall.ECONNRESET, syscall.ECONNABORTED, syscall.ECONNREFUSED: + return true + } + } + return false +} + func (h HTTPFetcher) fetch(ctx context.Context, p string) ([]byte, error) { u, err := h.rootURL.Parse(p) if err != nil { @@ -75,24 +118,70 @@ func (h HTTPFetcher) fetch(ctx context.Context, p string) ([]byte, error) { } r, err := h.c.Do(req) if err != nil { - return nil, fmt.Errorf("get(%q): %v", u.String(), err) + if isTransientNetworkError(err) { + return nil, TransientError{Err: err} + } + return nil, err } + defer func() { + // Drain the body to ensure the underlying TCP connection can be returned + // to the keep-alive pool and reused for future requests. + // Limit the drain to avoid hanging on large or infinite responses. + _, _ = io.Copy(io.Discard, io.LimitReader(r.Body, 4096)) + + if err := r.Body.Close(); err != nil { + slog.ErrorContext(ctx, "resp.Body.Close", slog.Any("error", err)) + } + }() + switch r.StatusCode { case http.StatusOK: - // All good, continue below + data, err := io.ReadAll(r.Body) + if err != nil { + if isTransientNetworkError(err) { + return nil, TransientError{Err: err} + } + return nil, err + } + return data, nil case http.StatusNotFound: // Need to return ErrNotExist here, by contract. return nil, fmt.Errorf("get(%q): %w", u.String(), os.ErrNotExist) + case http.StatusTooManyRequests, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + var retryAfter time.Duration + if ra := r.Header.Get("Retry-After"); ra != "" { + retryAfter = parseRetryAfter(ra) + } + return nil, TransientError{ + Err: fmt.Errorf("get(%q): status %d", u.String(), r.StatusCode), + RetryAfter: retryAfter, + } default: return nil, fmt.Errorf("get(%q): %v", u.String(), r.StatusCode) } +} - defer func() { - if err := r.Body.Close(); err != nil { - slog.ErrorContext(ctx, "resp.Body.Close", slog.Any("error", err)) +// parseRetryAfter parses the Retry-After header and returns a time.Duration. +func parseRetryAfter(retryAfter string) time.Duration { + if retryAfter == "" { + return 0 + } + d, err := http.ParseTime(retryAfter) + if err == nil { + dur := time.Until(d) + if dur <= 0 { + return time.Nanosecond } - }() - return io.ReadAll(r.Body) + return dur + } + s, err := strconv.Atoi(retryAfter) + if err == nil { + if s <= 0 { + return time.Nanosecond + } + return time.Duration(s) * time.Second + } + return 0 } func (h HTTPFetcher) ReadCheckpoint(ctx context.Context) ([]byte, error) { @@ -111,6 +200,123 @@ func (h HTTPFetcher) ReadEntryBundle(ctx context.Context, i uint64, p uint8) ([] }) } +// retryOpts holds the configuration for retry logic. +type retryOpts struct { + maxRetries int + initialBackoff time.Duration + maxBackoff time.Duration +} + +// RetryOption is a function that modifies retryOpts. +type RetryOption func(*retryOpts) + +// WithMaxRetries sets the maximum number of retries. +func WithMaxRetries(n int) RetryOption { + return func(o *retryOpts) { o.maxRetries = n } +} + +// WithInitialBackoff sets the initial backoff duration. +func WithInitialBackoff(d time.Duration) RetryOption { + return func(o *retryOpts) { o.initialBackoff = d } +} + +// WithMaxBackoff sets the maximum backoff duration. +func WithMaxBackoff(d time.Duration) RetryOption { + return func(o *retryOpts) { o.maxBackoff = d } +} + +func defaultRetryOpts() retryOpts { + return retryOpts{ + maxRetries: 5, + initialBackoff: 100 * time.Millisecond, + maxBackoff: 2 * time.Second, + } +} + +// WithTileRetry decorates a TileFetcherFunc with retry logic. +func WithTileRetry(f TileFetcherFunc, opts ...RetryOption) TileFetcherFunc { + o := defaultRetryOpts() + for _, opt := range opts { + opt(&o) + } + return func(ctx context.Context, level, index uint64, p uint8) ([]byte, error) { + return retry(ctx, o, func() ([]byte, error) { + return f(ctx, level, index, p) + }) + } +} + +// WithEntryBundleRetry decorates an EntryBundleFetcherFunc with retry logic. +func WithEntryBundleRetry(f EntryBundleFetcherFunc, opts ...RetryOption) EntryBundleFetcherFunc { + o := defaultRetryOpts() + for _, opt := range opts { + opt(&o) + } + return func(ctx context.Context, bundleIndex uint64, p uint8) ([]byte, error) { + return retry(ctx, o, func() ([]byte, error) { + return f(ctx, bundleIndex, p) + }) + } +} + +// WithCheckpointRetry decorates a CheckpointFetcherFunc with retry logic. +func WithCheckpointRetry(f CheckpointFetcherFunc, opts ...RetryOption) CheckpointFetcherFunc { + o := defaultRetryOpts() + for _, opt := range opts { + opt(&o) + } + return func(ctx context.Context) ([]byte, error) { + return retry(ctx, o, func() ([]byte, error) { + return f(ctx) + }) + } +} + +// retry retries the function f with exponential backoff up to maxRetries. +func retry[T any](ctx context.Context, opts retryOpts, f func() (T, error)) (T, error) { + var backoff = opts.initialBackoff + var err error + var res T + for attempt := 0; attempt <= opts.maxRetries; attempt++ { + res, err = f() + if err == nil { + return res, nil + } + + var tErr TransientError + if errors.As(err, &tErr) { + if attempt == opts.maxRetries { + break + } + delay := backoff + if tErr.RetryAfter > 0 { + if tErr.RetryAfter > opts.maxBackoff { + return res, fmt.Errorf("Retry-After %v exceeds maxBackoff %v: %w", tErr.RetryAfter, opts.maxBackoff, err) + } + delay = tErr.RetryAfter + } + timer := time.NewTimer(delay) + select { + case <-ctx.Done(): + timer.Stop() + return res, ctx.Err() + case <-timer.C: + if tErr.RetryAfter == 0 { + nextBackoff := backoff * 2 + var jitter time.Duration + if n := int64(backoff); n > 0 { + jitter = time.Duration(rand.Int64N(n)) + } + backoff = min(nextBackoff+jitter, opts.maxBackoff) + } + continue + } + } + return res, err + } + return res, fmt.Errorf("after %d retries: %w", opts.maxRetries, err) +} + // FileFetcher knows how to fetch log artifacts from a filesystem rooted at Root. type FileFetcher struct { Root string diff --git a/client/fetcher_test.go b/client/fetcher_test.go index 3c0c69d3c..409927ed7 100644 --- a/client/fetcher_test.go +++ b/client/fetcher_test.go @@ -3,7 +3,15 @@ package client import ( "context" "errors" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "syscall" "testing" + "time" ) func TestFileFetcherContextCancellation(t *testing.T) { @@ -31,3 +39,294 @@ func TestFileFetcherContextCancellation(t *testing.T) { t.Errorf("ReadEntryBundle: got error %v, want %v", err, context.Canceled) } } + +func TestHTTPFetcherRetry(t *testing.T) { + tests := []struct { + name string + responses []int + retryAfter string + expectedError error + wantAttempts int + minDuration time.Duration + }{ + { + name: "SuccessFirstTry", + responses: []int{http.StatusOK}, + wantAttempts: 1, + }, + { + name: "RetryThenSuccess", + responses: []int{http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusOK}, + wantAttempts: 3, + }, + { + name: "MaxRetriesExceeded", + responses: []int{http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable}, + expectedError: errors.New("after 5 retries"), + wantAttempts: 6, + }, + { + name: "NotFoundNoRetry", + responses: []int{http.StatusNotFound}, + expectedError: os.ErrNotExist, + wantAttempts: 1, + }, + { + name: "RetryAfterRespected", + responses: []int{http.StatusTooManyRequests, http.StatusOK}, + retryAfter: "1", // 1 second + wantAttempts: 2, + minDuration: time.Second, + }, + { + name: "RetryAfterPastDate", + responses: []int{http.StatusTooManyRequests, http.StatusOK}, + retryAfter: "Wed, 21 Oct 2015 07:28:00 GMT", + wantAttempts: 2, + minDuration: 0, + }, + { + name: "RetryAfterZero", + responses: []int{http.StatusTooManyRequests, http.StatusOK}, + retryAfter: "0", + wantAttempts: 2, + minDuration: 0, + }, + { + name: "RetryAfterRFC850", + responses: []int{http.StatusTooManyRequests, http.StatusOK}, + retryAfter: "Sunday, 06-Nov-94 08:49:37 GMT", + wantAttempts: 2, + minDuration: 0, + }, + { + name: "RetryAfterANSIC", + responses: []int{http.StatusTooManyRequests, http.StatusOK}, + retryAfter: "Sun Nov 6 08:49:37 1994", + wantAttempts: 2, + minDuration: 0, + }, + { + name: "RetryAfterExceedsMaxBackoff", + responses: []int{http.StatusTooManyRequests}, + retryAfter: "3600", // 1 hour + expectedError: errors.New("exceeds maxBackoff"), + wantAttempts: 1, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if attempts < len(tc.responses) { + status := tc.responses[attempts] + attempts++ + if tc.retryAfter != "" && status == http.StatusTooManyRequests { + w.Header().Set("Retry-After", tc.retryAfter) + } + w.WriteHeader(status) + if status == http.StatusOK { + _, _ = w.Write([]byte("data")) + } + return + } + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + u, _ := url.Parse(server.URL) + fetcher, err := NewHTTPFetcher(u, nil) + if err != nil { + t.Fatal(err) + } + + // Decorate the fetch call using the helper + decoratedFetch := func(ctx context.Context) ([]byte, error) { + return retry(ctx, defaultRetryOpts(), func() ([]byte, error) { + return fetcher.fetch(ctx, "/") + }) + } + + ctx := context.Background() + + startTime := time.Now() + _, err = decoratedFetch(ctx) + duration := time.Since(startTime) + + if tc.expectedError != nil { + if err == nil { + t.Errorf("expected error %v, got nil", tc.expectedError) + } else if !errors.Is(err, tc.expectedError) && !strings.Contains(err.Error(), tc.expectedError.Error()) { + t.Errorf("expected error containing %v, got %v", tc.expectedError, err) + } + } else if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if attempts != tc.wantAttempts { + t.Errorf("got %d attempts, want %d", attempts, tc.wantAttempts) + } + + if tc.minDuration > 0 && duration < tc.minDuration { + t.Errorf("expected retry delay of at least %v, took %v", tc.minDuration, duration) + } + }) + } +} + +func TestWithTileRetry(t *testing.T) { + tests := []struct { + name string + responses []error + options []RetryOption + expectedError error + wantAttempts int + }{ + { + name: "SuccessFirstTry", + responses: []error{nil}, + wantAttempts: 1, + }, + { + name: "RetryThenSuccess", + responses: []error{TransientError{Err: errors.New("temporary")}, TransientError{Err: errors.New("temporary")}, nil}, + wantAttempts: 3, + }, + { + name: "MaxRetriesExceeded", + responses: []error{TransientError{Err: errors.New("temporary")}, TransientError{Err: errors.New("temporary")}, TransientError{Err: errors.New("temporary")}, TransientError{Err: errors.New("temporary")}, TransientError{Err: errors.New("temporary")}, TransientError{Err: errors.New("temporary")}}, + expectedError: errors.New("after 5 retries"), + wantAttempts: 6, + }, + { + name: "NonTransientErrorNoRetry", + responses: []error{errors.New("fatal")}, + expectedError: errors.New("fatal"), + wantAttempts: 1, + }, + { + name: "CustomMaxRetries", + responses: []error{TransientError{Err: errors.New("temporary")}, TransientError{Err: errors.New("temporary")}, TransientError{Err: errors.New("temporary")}}, + options: []RetryOption{WithMaxRetries(2)}, + expectedError: errors.New("after 2 retries"), + wantAttempts: 3, + }, + { + name: "InitialBackoffZero", + responses: []error{TransientError{Err: errors.New("temporary")}, nil}, + options: []RetryOption{WithInitialBackoff(0)}, + wantAttempts: 2, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + attempts := 0 + dummyFetcher := func(ctx context.Context, level, index uint64, p uint8) ([]byte, error) { + if attempts < len(tc.responses) { + err := tc.responses[attempts] + attempts++ + if err != nil { + return nil, err + } + return []byte("data"), nil + } + return nil, errors.New("unexpected call") + } + + decorated := WithTileRetry(dummyFetcher, tc.options...) + + _, err := decorated(context.Background(), 0, 0, 0) + + if tc.expectedError != nil { + if err == nil { + t.Errorf("expected error %v, got nil", tc.expectedError) + } else if !strings.Contains(err.Error(), tc.expectedError.Error()) { + t.Errorf("expected error containing %v, got %v", tc.expectedError, err) + } + } else if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if attempts != tc.wantAttempts { + t.Errorf("got %d attempts, want %d", attempts, tc.wantAttempts) + } + }) + } +} + +type myNetError struct { + timeout bool +} + +func (e myNetError) Error() string { return "error" } +func (e myNetError) Timeout() bool { return e.timeout } +func (e myNetError) Temporary() bool { return false } + +func TestIsTransientNetworkError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "NilError", + err: nil, + want: false, + }, + { + name: "GenericError", + err: errors.New("generic error"), + want: false, + }, + { + name: "ContextCanceled", + err: context.Canceled, + want: false, + }, + { + name: "ContextDeadlineExceeded", + err: context.DeadlineExceeded, + want: false, + }, + { + name: "TimeoutError", + err: myNetError{timeout: true}, + want: true, + }, + { + name: "NonTimeoutNetError", + err: myNetError{timeout: false}, + want: false, + }, + { + name: "UnexpectedEOF", + err: io.ErrUnexpectedEOF, + want: true, + }, + { + name: "EOF", + err: io.EOF, + want: true, + }, + { + name: "ConnReset", + err: syscall.ECONNRESET, + want: true, + }, + { + name: "ConnRefused", + err: syscall.ECONNREFUSED, + want: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := isTransientNetworkError(tc.err); got != tc.want { + t.Errorf("isTransientNetworkError() = %v, want %v", got, tc.want) + } + }) + } +}