From 917a81e1d834e3e06e6a20502e1ae1ab2e461bdc Mon Sep 17 00:00:00 2001 From: Dean Karn Date: Wed, 27 Mar 2024 10:12:52 -0700 Subject: [PATCH] Customized retrier (#50) --- CHANGELOG.md | 11 +- README.md | 4 +- _examples/net/http/retrier/main.go | 59 ++++++ ascii/helpers.go | 21 ++ errors/do.go | 3 + errors/retrier.go | 179 +++++++++++++++++ errors/retrier_test.go | 164 ++++++++++++++++ errors/retryable.go | 15 ++ net/http/helpers.go | 29 ++- net/http/helpers_go1.18.go | 26 ++- net/http/retrier.go | 301 +++++++++++++++++++++++++++++ net/http/retrier_test.go | 191 ++++++++++++++++++ net/http/retryable.go | 39 +++- 13 files changed, 1033 insertions(+), 9 deletions(-) create mode 100644 _examples/net/http/retrier/main.go create mode 100644 ascii/helpers.go create mode 100644 errors/retrier.go create mode 100644 errors/retrier_test.go create mode 100644 net/http/retrier.go create mode 100644 net/http/retrier_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 7413461..fd3707c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [5.29.0] - 2024-03-24 +### Added +- `asciiext` package for ASCII related functions. +- `errorsext.Retrier` configurable retry helper for any fallible operation. +- `httpext.Retrier` configurable retry helper for HTTP requests and parsing of responses. +- `httpext.DecodeResponseAny` non-generic helper for decoding HTTP responses. +- `httpext.HasRetryAfter` helper for checking if a response has a `Retry-After` header and returning duration to wait. + ## [5.28.1] - 2024-02-14 ### Fixed - Additional supported types, cast to `sql.Valuer` supported types, they need to be returned to the driver for evaluation. @@ -120,7 +128,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `timext.NanoTime` for fast low level monotonic time with nanosecond precision. -[Unreleased]: https://github.com/go-playground/pkg/compare/v5.28.1...HEAD +[Unreleased]: https://github.com/go-playground/pkg/compare/v5.29.0...HEAD +[5.29.0]: https://github.com/go-playground/pkg/compare/v5.28.1..v5.29.0 [5.28.1]: https://github.com/go-playground/pkg/compare/v5.28.0..v5.28.1 [5.28.0]: https://github.com/go-playground/pkg/compare/v5.27.0..v5.28.0 [5.27.0]: https://github.com/go-playground/pkg/compare/v5.26.0..v5.27.0 diff --git a/README.md b/README.md index ec5fac2..ed94ff6 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # pkg -![Project status](https://img.shields.io/badge/version-5.28.0-green.svg) +![Project status](https://img.shields.io/badge/version-5.29.0-green.svg) [![Lint & Test](https://github.com/go-playground/pkg/actions/workflows/go.yml/badge.svg)](https://github.com/go-playground/pkg/actions/workflows/go.yml) [![Coverage Status](https://coveralls.io/repos/github/go-playground/pkg/badge.svg?branch=master)](https://coveralls.io/github/go-playground/pkg?branch=master) [![GoDoc](https://godoc.org/github.com/go-playground/pkg?status.svg)](https://pkg.go.dev/mod/github.com/go-playground/pkg/v5) @@ -23,7 +23,7 @@ This is a place to put common reusable code that is not quite a library but exte - Generic Mutex and RWMutex. - Bytes helper placeholders units eg. MB, MiB, GB, ... - Detachable context. -- Error retryable helper functions. +- Retrier for helping with any fallible operation. - Proper RFC3339Nano definition. - unsafe []byte->string & string->[]byte helper functions. - HTTP helper functions and constant placeholders. diff --git a/_examples/net/http/retrier/main.go b/_examples/net/http/retrier/main.go new file mode 100644 index 0000000..e080f7f --- /dev/null +++ b/_examples/net/http/retrier/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "time" + + appext "github.com/go-playground/pkg/v5/app" + errorsext "github.com/go-playground/pkg/v5/errors" + httpext "github.com/go-playground/pkg/v5/net/http" + . "github.com/go-playground/pkg/v5/values/result" +) + +// customize as desired to meet your needs including custom retryable status codes, errors etc. +var retrier = httpext.NewRetryer() + +func main() { + ctx := appext.Context().Build() + + type Test struct { + Date time.Time + } + var count int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if count < 2 { + count++ + http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) + return + } + _ = httpext.JSON(w, http.StatusOK, Test{Date: time.Now().UTC()}) + })) + defer server.Close() + + // fetch response + fn := func(ctx context.Context) Result[*http.Request, error] { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + if err != nil { + return Err[*http.Request, error](err) + } + return Ok[*http.Request, error](req) + } + + var result Test + err := retrier.Do(ctx, fn, &result, http.StatusOK) + if err != nil { + panic(err) + } + fmt.Printf("Response: %+v\n", result) + + // `Retrier` configuration is copy and so the base `Retrier` can be used and even customized for one-off requests. + // eg for this request we change the max attempts from the default configuration. + err = retrier.MaxAttempts(errorsext.MaxAttempts, 2).Do(ctx, fn, &result, http.StatusOK) + if err != nil { + panic(err) + } + fmt.Printf("Response: %+v\n", result) +} diff --git a/ascii/helpers.go b/ascii/helpers.go new file mode 100644 index 0000000..bbd69ac --- /dev/null +++ b/ascii/helpers.go @@ -0,0 +1,21 @@ +package asciiext + +// IsAlphanumeric returns true if the byte is an ASCII letter or digit. +func IsAlphanumeric(c byte) bool { + return IsLower(c) || IsUpper(c) || IsDigit(c) +} + +// IsUpper returns true if the byte is an ASCII uppercase letter. +func IsUpper(c byte) bool { + return c >= 'A' && c <= 'Z' +} + +// IsLower returns true if the byte is an ASCII lowercase letter. +func IsLower(c byte) bool { + return c >= 'a' && c <= 'z' +} + +// IsDigit returns true if the byte is an ASCII digit. +func IsDigit(c byte) bool { + return c >= '0' && c <= '9' +} diff --git a/errors/do.go b/errors/do.go index 8baa49c..a2a45df 100644 --- a/errors/do.go +++ b/errors/do.go @@ -5,6 +5,7 @@ package errorsext import ( "context" + optionext "github.com/go-playground/pkg/v5/values/option" resultext "github.com/go-playground/pkg/v5/values/result" ) @@ -21,6 +22,8 @@ type IsRetryableFn[E any] func(err E) (reason string, isRetryable bool) type OnRetryFn[E any] func(ctx context.Context, originalErr E, reason string, attempt int) optionext.Option[E] // DoRetryable will execute the provided functions code and automatically retry using the provided retry function. +// +// Deprecated: use `errorsext.Retrier` instead which corrects design issues with the current implementation. func DoRetryable[T, E any](ctx context.Context, isRetryFn IsRetryableFn[E], onRetryFn OnRetryFn[E], fn RetryableFn[T, E]) resultext.Result[T, E] { var attempt int for { diff --git a/errors/retrier.go b/errors/retrier.go new file mode 100644 index 0000000..5e4e798 --- /dev/null +++ b/errors/retrier.go @@ -0,0 +1,179 @@ +//go:build go1.18 +// +build go1.18 + +package errorsext + +import ( + "context" + "time" + + . "github.com/go-playground/pkg/v5/values/result" +) + +// MaxAttemptsMode is used to set the mode for the maximum number of attempts. +// +// eg. Should the max attempts apply to all errors, just ones not determined to be retryable, reset on retryable errors, etc. +type MaxAttemptsMode uint8 + +const ( + // MaxAttemptsNonRetryableReset will apply the max attempts to all errors not determined to be retryable, but will + // reset the attempts if a retryable error is encountered after a non-retryable error. + MaxAttemptsNonRetryableReset MaxAttemptsMode = iota + + // MaxAttemptsNonRetryable will apply the max attempts to all errors not determined to be retryable. + MaxAttemptsNonRetryable + + // MaxAttempts will apply the max attempts to all errors, even those determined to be retryable. + MaxAttempts + + // MaxAttemptsUnlimited will not apply a maximum number of attempts. + MaxAttemptsUnlimited +) + +// BackoffFn is a function used to apply a backoff strategy to the retryable function. +// +// It accepts `E` in cases where the amount of time to backoff is dynamic, for example when and http request fails +// with a 429 status code, the `Retry-After` header can be used to determine how long to backoff. It is not required +// to use or handle `E` and can be ignored if desired. +type BackoffFn[E any] func(ctx context.Context, attempt int, e E) + +// IsRetryableFn2 is called to determine if the type E is retryable. +type IsRetryableFn2[E any] func(ctx context.Context, e E) (isRetryable bool) + +// EarlyReturnFn is the function that can be used to bypass all retry logic, no matter the MaxAttemptsMode, for when the +// type of `E` will never succeed and should not be retried. +// +// eg. If retrying an HTTP request and getting 400 Bad Request, it's unlikely to ever succeed and should not be retried. +type EarlyReturnFn[E any] func(ctx context.Context, e E) (earlyReturn bool) + +// Retryer is used to retry any fallible operation. +type Retryer[T, E any] struct { + isRetryableFn IsRetryableFn2[E] + isEarlyReturnFn EarlyReturnFn[E] + maxAttemptsMode MaxAttemptsMode + maxAttempts uint8 + bo BackoffFn[E] + timeout time.Duration +} + +// NewRetryer returns a new `Retryer` with sane default values. +// +// The default values are: +// - `MaxAttemptsMode` is `MaxAttemptsNonRetryableReset`. +// - `MaxAttempts` is 5. +// - `Timeout` is 0 no context timeout. +// - `IsRetryableFn` will always return false as `E` is unknown until defined. +// - `BackoffFn` will sleep for 200ms. It's recommended to use exponential backoff for production. +// - `EarlyReturnFn` will be None. +func NewRetryer[T, E any]() Retryer[T, E] { + return Retryer[T, E]{ + isRetryableFn: func(_ context.Context, _ E) bool { return false }, + maxAttemptsMode: MaxAttemptsNonRetryableReset, + maxAttempts: 5, + bo: func(ctx context.Context, attempt int, _ E) { + t := time.NewTimer(time.Millisecond * 200) + defer t.Stop() + select { + case <-ctx.Done(): + case <-t.C: + } + }, + } +} + +// IsRetryableFn sets the `IsRetryableFn` for the `Retryer`. +func (r Retryer[T, E]) IsRetryableFn(fn IsRetryableFn2[E]) Retryer[T, E] { + if fn == nil { + fn = func(_ context.Context, _ E) bool { return false } + } + r.isRetryableFn = fn + return r +} + +// IsEarlyReturnFn sets the `EarlyReturnFn` for the `Retryer`. +// +// NOTE: If the `EarlyReturnFn` and `IsRetryableFn` are both set and a conflicting `IsRetryableFn` will take precedence. +func (r Retryer[T, E]) IsEarlyReturnFn(fn EarlyReturnFn[E]) Retryer[T, E] { + r.isEarlyReturnFn = fn + return r +} + +// MaxAttempts sets the maximum number of attempts for the `Retryer`. +// +// NOTE: Max attempts is optional and if not set will retry indefinitely on retryable errors. +func (r Retryer[T, E]) MaxAttempts(mode MaxAttemptsMode, maxAttempts uint8) Retryer[T, E] { + r.maxAttemptsMode, r.maxAttempts = mode, maxAttempts + return r +} + +// Backoff sets the backoff function for the `Retryer`. +func (r Retryer[T, E]) Backoff(fn BackoffFn[E]) Retryer[T, E] { + if fn == nil { + fn = func(_ context.Context, _ int, _ E) {} + } + r.bo = fn + return r +} + +// Timeout sets the timeout for the `Retryer`. This is the timeout per `RetyableFn` attempt and not the entirety +// of the `Retryer` execution. +// +// A timeout of 0 will disable the timeout and is the default. +func (r Retryer[T, E]) Timeout(timeout time.Duration) Retryer[T, E] { + r.timeout = timeout + return r +} + +// Do will execute the provided functions code and automatically retry using the provided retry function. +func (r Retryer[T, E]) Do(ctx context.Context, fn RetryableFn[T, E]) Result[T, E] { + var attempt int + remaining := r.maxAttempts + for { + var result Result[T, E] + if r.timeout == 0 { + result = fn(ctx) + } else { + ctx, cancel := context.WithTimeout(ctx, r.timeout) + result = fn(ctx) + cancel() + } + if result.IsErr() { + err := result.Err() + isRetryable := r.isRetryableFn(ctx, err) + if !isRetryable && r.isEarlyReturnFn != nil && r.isEarlyReturnFn(ctx, err) { + return result + } + + switch r.maxAttemptsMode { + case MaxAttemptsUnlimited: + goto RETRY + case MaxAttemptsNonRetryableReset: + if isRetryable { + remaining = r.maxAttempts + goto RETRY + } else if remaining > 0 { + remaining-- + } + case MaxAttemptsNonRetryable: + if isRetryable { + goto RETRY + } else if remaining > 0 { + remaining-- + } + case MaxAttempts: + if remaining > 0 { + remaining-- + } + } + if remaining == 0 { + return result + } + + RETRY: + r.bo(ctx, attempt, err) + attempt++ + continue + } + return result + } +} diff --git a/errors/retrier_test.go b/errors/retrier_test.go new file mode 100644 index 0000000..f92674f --- /dev/null +++ b/errors/retrier_test.go @@ -0,0 +1,164 @@ +//go:build go1.18 +// +build go1.18 + +package errorsext + +import ( + "context" + "errors" + "io" + "testing" + "time" + + . "github.com/go-playground/assert/v2" + . "github.com/go-playground/pkg/v5/values/result" +) + +// TODO: Add IsRetryable and Retryable to helper functions. + +func TestRetrierMaxAttempts(t *testing.T) { + var i, j int + result := NewRetryer[int, error]().Backoff(func(ctx context.Context, attempt int, _ error) { + j++ + }).MaxAttempts(MaxAttempts, 3).Do(context.Background(), func(ctx context.Context) Result[int, error] { + i++ + if i > 50 { + panic("infinite loop") + } + return Err[int, error](io.EOF) + }) + Equal(t, result.IsErr(), true) + Equal(t, result.Err(), io.EOF) + Equal(t, i, 3) + Equal(t, j, 2) +} + +func TestRetrierMaxAttemptsNonRetryable(t *testing.T) { + var i, j int + returnErr := io.ErrUnexpectedEOF + result := NewRetryer[int, error]().IsRetryableFn(func(_ context.Context, e error) (isRetryable bool) { + if returnErr == io.EOF { + return false + } else { + return true + } + }).Backoff(func(ctx context.Context, attempt int, _ error) { + j++ + if j == 10 { + returnErr = io.EOF + } + }).MaxAttempts(MaxAttemptsNonRetryable, 3).Do(context.Background(), func(ctx context.Context) Result[int, error] { + i++ + if i > 50 { + panic("infinite loop") + } + return Err[int, error](returnErr) + }) + Equal(t, result.IsErr(), true) + Equal(t, result.Err(), io.EOF) + Equal(t, i, 13) + Equal(t, j, 12) +} + +func TestRetrierMaxAttemptsNonRetryableReset(t *testing.T) { + var i, j int + returnErr := io.EOF + result := NewRetryer[int, error]().IsRetryableFn(func(_ context.Context, e error) (isRetryable bool) { + if returnErr == io.EOF { + return false + } else { + return true + } + }).Backoff(func(ctx context.Context, attempt int, _ error) { + j++ + if j == 2 { + returnErr = io.ErrUnexpectedEOF + } else if j == 10 { + returnErr = io.EOF + } + }).MaxAttempts(MaxAttemptsNonRetryableReset, 3).Do(context.Background(), func(ctx context.Context) Result[int, error] { + i++ + if i > 50 { + panic("infinite loop") + } + return Err[int, error](returnErr) + }) + Equal(t, result.IsErr(), true) + Equal(t, result.Err(), io.EOF) + Equal(t, i, 13) + Equal(t, j, 12) +} + +func TestRetrierMaxAttemptsUnlimited(t *testing.T) { + var i, j int + r := NewRetryer[int, error]().Backoff(func(ctx context.Context, attempt int, _ error) { + j++ + }).MaxAttempts(MaxAttemptsUnlimited, 0) + + PanicMatches(t, func() { + r.Do(context.Background(), func(ctx context.Context) Result[int, error] { + i++ + if i > 50 { + panic("infinite loop") + } + return Err[int, error](io.EOF) + }) + }, "infinite loop") +} + +func TestRetrierMaxAttemptsTimeout(t *testing.T) { + result := NewRetryer[int, error]().Backoff(func(ctx context.Context, attempt int, _ error) { + }).MaxAttempts(MaxAttempts, 1).Timeout(time.Second). + Do(context.Background(), func(ctx context.Context) Result[int, error] { + select { + case <-ctx.Done(): + return Err[int, error](ctx.Err()) + case <-time.After(time.Second * 3): + return Err[int, error](io.EOF) + } + }) + Equal(t, result.IsErr(), true) + Equal(t, result.Err(), context.DeadlineExceeded) +} + +func TestRetrierEarlyReturn(t *testing.T) { + var earlyReturnCount int + + r := NewRetryer[int, error]().Backoff(func(ctx context.Context, attempt int, _ error) { + }).MaxAttempts(MaxAttempts, 5).Timeout(time.Second). + IsEarlyReturnFn(func(ctx context.Context, err error) bool { + earlyReturnCount++ + return errors.Is(err, io.EOF) + }).Backoff(nil) + + result := r.Do(context.Background(), func(ctx context.Context) Result[int, error] { + return Err[int, error](io.EOF) + }) + Equal(t, result.IsErr(), true) + Equal(t, result.Err(), io.EOF) + Equal(t, earlyReturnCount, 1) + + // now let try with retryable overriding early return TL;DR retryable should take precedence over early return + earlyReturnCount = 0 + isRetryableCount := 0 + result = r.IsRetryableFn(func(ctx context.Context, err error) (isRetryable bool) { + isRetryableCount++ + return errors.Is(err, io.EOF) + }).Do(context.Background(), func(ctx context.Context) Result[int, error] { + return Err[int, error](io.EOF) + }) + Equal(t, result.IsErr(), true) + Equal(t, result.Err(), io.EOF) + Equal(t, earlyReturnCount, 0) + Equal(t, isRetryableCount, 5) + + // while here let's check the first test case again, `Retrier` should be a copy and original still intact. + isRetryableCount = 0 + result = r.Do(context.Background(), func(ctx context.Context) Result[int, error] { + return Err[int, error](io.EOF) + }) + Equal(t, result.IsErr(), true) + Equal(t, result.Err(), io.EOF) + Equal(t, earlyReturnCount, 1) + Equal(t, isRetryableCount, 0) +} diff --git a/errors/retryable.go b/errors/retryable.go index 4774c5c..46756bb 100644 --- a/errors/retryable.go +++ b/errors/retryable.go @@ -42,6 +42,9 @@ func IsRetryableHTTP(err error) (retryType string, isRetryable bool) { // IsRetryableNetwork returns if the provided error is a retryable network related error. It also returns the // type, in string form, for optional logging and metrics use. func IsRetryableNetwork(err error) (retryType string, isRetryable bool) { + if IsRetryable(err) { + return "retryable", true + } if IsTemporary(err) { return "temporary", true } @@ -51,6 +54,18 @@ func IsRetryableNetwork(err error) (retryType string, isRetryable bool) { return IsTemporaryConnection(err) } +// IsRetryable returns true if the provided error is considered retryable by testing if it +// complies with an interface implementing `Retryable() bool` or `IsRetryable bool` and calling the function. +func IsRetryable(err error) bool { + var t interface{ IsRetryable() bool } + if errors.As(err, &t) && t.IsRetryable() { + return true + } + + var t2 interface{ Retryable() bool } + return errors.As(err, &t2) && t2.Retryable() +} + // IsTemporary returns true if the provided error is considered retryable temporary error by testing if it // complies with an interface implementing `Temporary() bool` and calling the function. func IsTemporary(err error) bool { diff --git a/net/http/helpers.go b/net/http/helpers.go index 745ff3c..884c33d 100644 --- a/net/http/helpers.go +++ b/net/http/helpers.go @@ -4,6 +4,7 @@ import ( "compress/gzip" "encoding/json" "encoding/xml" + "errors" "io" "mime" "net" @@ -12,6 +13,7 @@ import ( "path/filepath" "strings" + bytesext "github.com/go-playground/pkg/v5/bytes" ioext "github.com/go-playground/pkg/v5/io" ) @@ -210,7 +212,7 @@ func DecodeMultipartForm(r *http.Request, qp QueryParamsOption, maxMemory int64, } // DecodeJSON decodes the request body into the provided struct and limits the request size via -// an ioext.LimitReader using the maxMemory param. +// an ioext.LimitReader using the maxBytes param. // // The Content-Type e.g. "application/json" and http method are not checked. // @@ -244,7 +246,7 @@ func decodeJSON(headers http.Header, body io.Reader, qp QueryParamsOption, value } // DecodeXML decodes the request body into the provided struct and limits the request size via -// an ioext.LimitReader using the maxMemory param. +// an ioext.LimitReader using the maxBytes param. // // The Content-Type e.g. "application/xml" and http method are not checked. // @@ -295,7 +297,7 @@ const ( // Decode takes the request and attempts to discover its content type via // the http headers and then decode the request body into the provided struct. // Example if header was "application/json" would decode using -// json.NewDecoder(ioext.LimitReader(r.Body, maxMemory)).Decode(v). +// json.NewDecoder(ioext.LimitReader(r.Body, maxBytes)).Decode(v). // // This default to parsing query params if includeQueryParams=true and no other content type matches. // @@ -322,3 +324,24 @@ func Decode(r *http.Request, qp QueryParamsOption, maxMemory int64, v interface{ } return } + +// DecodeResponseAny takes the response and attempts to discover its content type via +// the http headers and then decode the request body into the provided type. +// +// Example if header was "application/json" would decode using +// json.NewDecoder(ioext.LimitReader(r.Body, maxBytes)).Decode(v). +func DecodeResponseAny(r *http.Response, maxMemory bytesext.Bytes, v interface{}) (err error) { + typ := r.Header.Get(ContentType) + if idx := strings.Index(typ, ";"); idx != -1 { + typ = typ[:idx] + } + switch typ { + case nakedApplicationJSON: + err = decodeJSON(r.Header, r.Body, NoQueryParams, nil, maxMemory, v) + case nakedApplicationXML: + err = decodeXML(r.Header, r.Body, NoQueryParams, nil, maxMemory, v) + default: + err = errors.New("unsupported content type") + } + return +} diff --git a/net/http/helpers_go1.18.go b/net/http/helpers_go1.18.go index 8fb0aea..38b4d52 100644 --- a/net/http/helpers_go1.18.go +++ b/net/http/helpers_go1.18.go @@ -5,16 +5,21 @@ package httpext import ( "errors" - bytesext "github.com/go-playground/pkg/v5/bytes" "net/http" + "strconv" "strings" + "time" + + asciiext "github.com/go-playground/pkg/v5/ascii" + bytesext "github.com/go-playground/pkg/v5/bytes" + . "github.com/go-playground/pkg/v5/values/option" ) // DecodeResponse takes the response and attempts to discover its content type via // the http headers and then decode the request body into the provided type. // // Example if header was "application/json" would decode using -// json.NewDecoder(ioext.LimitReader(r.Body, maxMemory)).Decode(v). +// json.NewDecoder(ioext.LimitReader(r.Body, maxBytes)).Decode(v). func DecodeResponse[T any](r *http.Response, maxMemory bytesext.Bytes) (result T, err error) { typ := r.Header.Get(ContentType) if idx := strings.Index(typ, ";"); idx != -1 { @@ -30,3 +35,20 @@ func DecodeResponse[T any](r *http.Response, maxMemory bytesext.Bytes) (result T } return } + +// HasRetryAfter parses the Retry-After header and returns the duration if possible. +func HasRetryAfter(headers http.Header) Option[time.Duration] { + if ra := headers.Get(RetryAfter); ra != "" { + if asciiext.IsDigit(ra[0]) { + if n, err := strconv.ParseInt(ra, 10, 64); err == nil { + return Some(time.Duration(n) * time.Second) + } + } else { + // not a number so must be a date in the future + if t, err := http.ParseTime(ra); err == nil { + return Some(time.Until(t)) + } + } + } + return None[time.Duration]() +} diff --git a/net/http/retrier.go b/net/http/retrier.go new file mode 100644 index 0000000..c6913b9 --- /dev/null +++ b/net/http/retrier.go @@ -0,0 +1,301 @@ +//go:build go1.18 +// +build go1.18 + +package httpext + +import ( + "context" + "errors" + "io" + "net/http" + "strconv" + "time" + + bytesext "github.com/go-playground/pkg/v5/bytes" + errorsext "github.com/go-playground/pkg/v5/errors" + ioext "github.com/go-playground/pkg/v5/io" + typesext "github.com/go-playground/pkg/v5/types" + valuesext "github.com/go-playground/pkg/v5/values" + . "github.com/go-playground/pkg/v5/values/result" +) + +// ErrStatusCode can be used to treat/indicate a status code as an error and ability to indicate if it is retryable. +type ErrStatusCode struct { + // StatusCode is the HTTP response status code that was encountered. + StatusCode int + + // IsRetryableStatusCode indicates if the status code is considered retryable. + IsRetryableStatusCode bool + + // Headers contains the headers from the HTTP response. + Headers http.Header + + // Body is the optional body of the HTTP response. + Body []byte +} + +// Error returns the error message for the status code. +func (e ErrStatusCode) Error() string { + return "status code encountered: " + strconv.Itoa(e.StatusCode) +} + +// IsRetryable returns if the provided status code is considered retryable. +func (e ErrStatusCode) IsRetryable() bool { + return e.IsRetryableStatusCode +} + +// BuildRequestFn2 is a function used to rebuild an HTTP request for use in retryable code. +type BuildRequestFn2 func(ctx context.Context) Result[*http.Request, error] + +// DecodeAnyFn is a function used to decode the response body into the desired type. +type DecodeAnyFn func(ctx context.Context, resp *http.Response, maxMemory bytesext.Bytes, v any) error + +// IsRetryableStatusCodeFn2 is a function used to determine if the provided status code is considered retryable. +type IsRetryableStatusCodeFn2 func(ctx context.Context, code int) bool + +// Retryer is used to retry any fallible operation. +// +// The `Retryer` is designed to be stateless and reusable. Configuration is also copy and so a base `Retryer` can be +// used and changed for one-off requests eg. changing max attempts resulting in a new `Retrier` for that request. +type Retryer struct { + isRetryableFn errorsext.IsRetryableFn2[error] + isRetryableStatusCodeFn IsRetryableStatusCodeFn2 + isEarlyReturnFn errorsext.EarlyReturnFn[error] + decodeFn DecodeAnyFn + backoffFn errorsext.BackoffFn[error] + client *http.Client + timeout time.Duration + maxBytes bytesext.Bytes + mode errorsext.MaxAttemptsMode + maxAttempts uint8 +} + +// NewRetryer returns a new `Retryer` with sane default values. +// +// The default values are: +// - `IsRetryableFn` uses the existing `errorsext.IsRetryableHTTP` function. +// - `MaxAttemptsMode` is `MaxAttemptsNonRetryableReset`. +// - `MaxAttempts` is 5. +// - `BackoffFn` will sleep for 200ms or is successful `Retry-After` header can be parsed. It's recommended to use +// exponential backoff for production with a quick copy-paste-modify of the default function +// - `Timeout` is 0. +// - `IsRetryableStatusCodeFn` is set to the existing `IsRetryableStatusCode` function. +// - `IsEarlyReturnFn` is set to check if the error is an `ErrStatusCode` and if the status code is non-retryable. +// - `Client` is set to `http.DefaultClient`. +// - `MaxBytes` is set to 2MiB. +// - `DecodeAnyFn` is set to the existing `DecodeResponseAny` function that supports JSON and XML. +// +// WARNING: The default functions may receive enhancements or fixes in the future which could change their behavior, +// however every attempt will be made to maintain backwards compatibility or made additive-only if possible. +func NewRetryer() Retryer { + return Retryer{ + client: http.DefaultClient, + maxBytes: 2 * bytesext.MiB, + mode: errorsext.MaxAttemptsNonRetryableReset, + maxAttempts: 5, + isRetryableFn: func(ctx context.Context, err error) (isRetryable bool) { + _, isRetryable = errorsext.IsRetryableHTTP(err) + return + }, + isRetryableStatusCodeFn: func(_ context.Context, code int) bool { return IsRetryableStatusCode(code) }, + isEarlyReturnFn: func(_ context.Context, err error) bool { + var sce ErrStatusCode + if errors.As(err, &sce) { + return IsNonRetryableStatusCode(sce.StatusCode) + } + return false + }, + decodeFn: func(ctx context.Context, resp *http.Response, maxMemory bytesext.Bytes, v any) error { + err := DecodeResponseAny(resp, maxMemory, v) + if err != nil { + return err + } + return nil + }, + backoffFn: func(ctx context.Context, attempt int, err error) { + + wait := time.Millisecond * 200 + + var sce ErrStatusCode + if errors.As(err, &sce) { + if sce.Headers != nil && (sce.StatusCode == http.StatusTooManyRequests || sce.StatusCode == http.StatusServiceUnavailable) { + if ra := HasRetryAfter(sce.Headers); ra.IsSome() { + wait = ra.Unwrap() + } + } + } + + t := time.NewTimer(wait) + defer t.Stop() + select { + case <-ctx.Done(): + case <-t.C: + } + }, + } +} + +// Client sets the `http.Client` for the `Retryer`. +func (r Retryer) Client(client *http.Client) Retryer { + r.client = client + return r +} + +// IsRetryableFn sets the `IsRetryableFn` for the `Retryer`. +func (r Retryer) IsRetryableFn(fn errorsext.IsRetryableFn2[error]) Retryer { + r.isRetryableFn = fn + return r +} + +// IsRetryableStatusCodeFn is called to determine if the status code is retryable. +func (r Retryer) IsRetryableStatusCodeFn(fn IsRetryableStatusCodeFn2) Retryer { + if fn == nil { + fn = func(_ context.Context, _ int) bool { return false } + } + r.isRetryableStatusCodeFn = fn + return r +} + +// IsEarlyReturnFn sets the `EarlyReturnFn` for the `Retryer`. +func (r Retryer) IsEarlyReturnFn(fn errorsext.EarlyReturnFn[error]) Retryer { + r.isEarlyReturnFn = fn + return r +} + +// DecodeFn sets the decode function for the `Retryer`. +func (r Retryer) DecodeFn(fn DecodeAnyFn) Retryer { + if fn == nil { + fn = func(_ context.Context, _ *http.Response, _ bytesext.Bytes, _ any) error { return nil } + } + r.decodeFn = fn + return r +} + +// MaxAttempts sets the maximum number of attempts for the `Retryer`. +// +// NOTE: Max attempts is optional and if not set will retry indefinitely on retryable errors. +func (r Retryer) MaxAttempts(mode errorsext.MaxAttemptsMode, maxAttempts uint8) Retryer { + r.mode, r.maxAttempts = mode, maxAttempts + return r +} + +// Backoff sets the backoff function for the `Retryer`. +func (r Retryer) Backoff(fn errorsext.BackoffFn[error]) Retryer { + r.backoffFn = fn + return r +} + +// MaxBytes sets the maximum memory to use when decoding the response body including: +// - upon unexpected status codes. +// - when decoding the response body. +// - when draining the response body before closing allowing connection re-use. +func (r Retryer) MaxBytes(i bytesext.Bytes) Retryer { + r.maxBytes = i + return r + +} + +// Timeout sets the timeout for the `Retryer`. This is the timeout per `RetyableFn` attempt and not the entirety +// of the `Retryer` execution. +// +// A timeout of 0 will disable the timeout and is the default. +func (r Retryer) Timeout(timeout time.Duration) Retryer { + r.timeout = timeout + return r +} + +// DoResponse will execute the provided functions code and automatically retry before returning the *http.Response +// based on HTTP status code, if defined, and can be used when processing of the response body may not be necessary +// or something custom is required. +// +// NOTE: it is up to the caller to close the response body if a successful request is made. +func (r Retryer) DoResponse(ctx context.Context, fn BuildRequestFn2, expectedResponseCodes ...int) Result[*http.Response, error] { + return errorsext.NewRetryer[*http.Response, error](). + IsRetryableFn(r.isRetryableFn). + MaxAttempts(r.mode, r.maxAttempts). + Backoff(r.backoffFn). + Timeout(r.timeout). + IsEarlyReturnFn(r.isEarlyReturnFn). + Do(ctx, func(ctx context.Context) Result[*http.Response, error] { + req := fn(ctx) + if req.IsErr() { + return Err[*http.Response, error](req.Err()) + } + + resp, err := r.client.Do(req.Unwrap()) + if err != nil { + return Err[*http.Response, error](err) + } + + if len(expectedResponseCodes) > 0 { + for _, code := range expectedResponseCodes { + if resp.StatusCode == code { + goto RETURN + } + } + b, _ := io.ReadAll(ioext.LimitReader(resp.Body, r.maxBytes)) + _ = resp.Body.Close() + return Err[*http.Response, error](ErrStatusCode{ + StatusCode: resp.StatusCode, + IsRetryableStatusCode: r.isRetryableStatusCodeFn(ctx, resp.StatusCode), + Headers: resp.Header, + Body: b, + }) + } + + RETURN: + return Ok[*http.Response, error](resp) + }) +} + +// Do will execute the provided functions code and automatically retry using the provided retry function decoding +// the response body into the desired type `v`, which must be passed as mutable. +func (r Retryer) Do(ctx context.Context, fn BuildRequestFn2, v any, expectedResponseCodes ...int) error { + result := errorsext.NewRetryer[typesext.Nothing, error](). + IsRetryableFn(r.isRetryableFn). + MaxAttempts(r.mode, r.maxAttempts). + Backoff(r.backoffFn). + Timeout(r.timeout). + IsEarlyReturnFn(r.isEarlyReturnFn). + Do(ctx, func(ctx context.Context) Result[typesext.Nothing, error] { + req := fn(ctx) + if req.IsErr() { + return Err[typesext.Nothing, error](req.Err()) + } + + resp, err := r.client.Do(req.Unwrap()) + if err != nil { + return Err[typesext.Nothing, error](err) + } + defer func() { + _, _ = io.Copy(io.Discard, ioext.LimitReader(resp.Body, r.maxBytes)) + _ = resp.Body.Close() + }() + + if len(expectedResponseCodes) > 0 { + for _, code := range expectedResponseCodes { + if resp.StatusCode == code { + goto DECODE + } + } + + b, _ := io.ReadAll(ioext.LimitReader(resp.Body, r.maxBytes)) + return Err[typesext.Nothing, error](ErrStatusCode{ + StatusCode: resp.StatusCode, + IsRetryableStatusCode: r.isRetryableStatusCodeFn(ctx, resp.StatusCode), + Headers: resp.Header, + Body: b, + }) + } + + DECODE: + if err = r.decodeFn(ctx, resp, r.maxBytes, v); err != nil { + return Err[typesext.Nothing, error](err) + } + return Ok[typesext.Nothing, error](valuesext.Nothing) + }) + if result.IsErr() { + return result.Err() + } + return nil +} diff --git a/net/http/retrier_test.go b/net/http/retrier_test.go new file mode 100644 index 0000000..f4f2a8e --- /dev/null +++ b/net/http/retrier_test.go @@ -0,0 +1,191 @@ +//go:build go1.18 +// +build go1.18 + +package httpext + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + . "github.com/go-playground/assert/v2" + errorsext "github.com/go-playground/pkg/v5/errors" + . "github.com/go-playground/pkg/v5/values/result" +) + +func TestRetryer_SuccessNoRetries(t *testing.T) { + ctx := context.Background() + + type Test struct { + Name string + } + tst := Test{Name: "test"} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = JSON(w, http.StatusOK, tst) + })) + defer server.Close() + + retryer := NewRetryer() + + result := retryer.DoResponse(ctx, func(ctx context.Context) Result[*http.Request, error] { + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + if err != nil { + return Err[*http.Request, error](err) + } + return Ok[*http.Request, error](req) + }, http.StatusOK) + Equal(t, result.IsOk(), true) + Equal(t, result.Unwrap().StatusCode, http.StatusOK) + defer result.Unwrap().Body.Close() + + var responseResult Test + err := retryer.Do(ctx, func(ctx context.Context) Result[*http.Request, error] { + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + if err != nil { + return Err[*http.Request, error](err) + } + return Ok[*http.Request, error](req) + }, &responseResult, http.StatusOK) + Equal(t, err, nil) + Equal(t, responseResult, tst) +} + +func TestRetryer_SuccessWithRetries(t *testing.T) { + ctx := context.Background() + var count int + + type Test struct { + Name string + } + tst := Test{Name: "test"} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if count < 2 { + w.WriteHeader(http.StatusServiceUnavailable) + count++ + return + } + _ = JSON(w, http.StatusOK, tst) + })) + defer server.Close() + + retryer := NewRetryer().Backoff(nil) + + result := retryer.DoResponse(ctx, func(ctx context.Context) Result[*http.Request, error] { + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + if err != nil { + return Err[*http.Request, error](err) + } + return Ok[*http.Request, error](req) + }, http.StatusOK) + Equal(t, result.IsOk(), true) + Equal(t, result.Unwrap().StatusCode, http.StatusOK) + defer result.Unwrap().Body.Close() + + count = 0 // reset count + + var responseResult Test + err := retryer.Do(ctx, func(ctx context.Context) Result[*http.Request, error] { + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + if err != nil { + return Err[*http.Request, error](err) + } + return Ok[*http.Request, error](req) + }, &responseResult, http.StatusOK) + Equal(t, err, nil) + Equal(t, responseResult, tst) +} + +func TestRetryer_FailureMaxRetries(t *testing.T) { + ctx := context.Background() + + type Test struct { + Name string + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + retryer := NewRetryer().Backoff(nil).MaxAttempts(errorsext.MaxAttempts, 2) + + result := retryer.DoResponse(ctx, func(ctx context.Context) Result[*http.Request, error] { + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + if err != nil { + return Err[*http.Request, error](err) + } + return Ok[*http.Request, error](req) + }, http.StatusOK) + Equal(t, result.IsErr(), true) + + var responseResult Test + err := retryer.Do(ctx, func(ctx context.Context) Result[*http.Request, error] { + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + if err != nil { + return Err[*http.Request, error](err) + } + return Ok[*http.Request, error](req) + }, &responseResult, http.StatusOK) + NotEqual(t, err, nil) +} + +func TestRetryer_ExtractStatusBody(t *testing.T) { + ctx := context.Background() + eStr := "nooooooooooooo!" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(eStr)) + })) + defer server.Close() + + retryer := NewRetryer().MaxAttempts(errorsext.MaxAttempts, 3) + + result := retryer.DoResponse(ctx, func(ctx context.Context) Result[*http.Request, error] { + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + if err != nil { + return Err[*http.Request, error](err) + } + return Ok[*http.Request, error](req) + }, http.StatusOK) + Equal(t, result.IsErr(), true) + var esc ErrStatusCode + Equal(t, errors.As(result.Err(), &esc), true) + Equal(t, esc.IsRetryableStatusCode, false) + // check the ultimate failed response body is intact + Equal(t, string(esc.Body), eStr) +} + +func TestRetryer_ExtractStatusBodyEarlyReturn(t *testing.T) { + ctx := context.Background() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(http.StatusText(http.StatusUnauthorized))) + })) + defer server.Close() + + var count int + + retryer := NewRetryer().Backoff(func(_ context.Context, _ int, _ error) { + count++ + }).MaxAttempts(errorsext.MaxAttempts, 2) + + result := retryer.DoResponse(ctx, func(ctx context.Context) Result[*http.Request, error] { + req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil) + if err != nil { + return Err[*http.Request, error](err) + } + return Ok[*http.Request, error](req) + }, http.StatusOK) + Equal(t, result.IsErr(), true) + var esc ErrStatusCode + Equal(t, errors.As(result.Err(), &esc), true) + Equal(t, esc.IsRetryableStatusCode, false) + // check the ultimate failed response body is intact + Equal(t, string(esc.Body), http.StatusText(http.StatusUnauthorized)) + Equal(t, count, 0) +} diff --git a/net/http/retryable.go b/net/http/retryable.go index 6162dde..aaef13d 100644 --- a/net/http/retryable.go +++ b/net/http/retryable.go @@ -28,6 +28,34 @@ var ( // https://support.cloudflare.com/hc/en-us/articles/115003011431-Error-524-A-timeout-occurred#524error 524: true, } + // nonRetryableStatusCodes defines common HTTP responses that are not considered never to be retryable. + nonRetryableStatusCodes = map[int]bool{ + http.StatusBadRequest: true, + http.StatusUnauthorized: true, + http.StatusForbidden: true, + http.StatusMethodNotAllowed: true, + http.StatusNotAcceptable: true, + http.StatusProxyAuthRequired: true, + http.StatusConflict: true, + http.StatusLengthRequired: true, + http.StatusPreconditionFailed: true, + http.StatusRequestEntityTooLarge: true, + http.StatusRequestURITooLong: true, + http.StatusUnsupportedMediaType: true, + http.StatusRequestedRangeNotSatisfiable: true, + http.StatusExpectationFailed: true, + http.StatusTeapot: true, + http.StatusMisdirectedRequest: true, + http.StatusUnprocessableEntity: true, + http.StatusPreconditionRequired: true, + http.StatusRequestHeaderFieldsTooLarge: true, + http.StatusUnavailableForLegalReasons: true, + http.StatusNotImplemented: true, + http.StatusHTTPVersionNotSupported: true, + http.StatusLoopDetected: true, + http.StatusNotExtended: true, + http.StatusNetworkAuthenticationRequired: true, + } ) // ErrRetryableStatusCode can be used to indicate a retryable HTTP status code was encountered as an error. @@ -48,11 +76,16 @@ func (e ErrUnexpectedResponse) Error() string { return "unexpected response encountered" } -// IsRetryableStatusCode returns if the provided status code is considered retryable. +// IsRetryableStatusCode returns true if the provided status code is considered retryable. func IsRetryableStatusCode(code int) bool { return retryableStatusCodes[code] } +// IsNonRetryableStatusCode returns true if the provided status code should generally not be retryable. +func IsNonRetryableStatusCode(code int) bool { + return nonRetryableStatusCodes[code] +} + // BuildRequestFn is a function used to rebuild an HTTP request for use in retryable code. type BuildRequestFn func(ctx context.Context) (*http.Request, error) @@ -60,6 +93,8 @@ type BuildRequestFn func(ctx context.Context) (*http.Request, error) type IsRetryableStatusCodeFn func(code int) bool // DoRetryableResponse will execute the provided functions code and automatically retry before returning the *http.Response. +// +// Deprecated: use `httpext.Retrier` instead which corrects design issues with the current implementation. func DoRetryableResponse(ctx context.Context, onRetryFn errorsext.OnRetryFn[error], isRetryableStatusCode IsRetryableStatusCodeFn, client *http.Client, buildFn BuildRequestFn) Result[*http.Response, error] { if client == nil { client = http.DefaultClient @@ -102,6 +137,8 @@ func DoRetryableResponse(ctx context.Context, onRetryFn errorsext.OnRetryFn[erro // Gzip supported: // - JSON // - XML +// +// Deprecated: use `httpext.Retrier` instead which corrects design issues with the current implementation. func DoRetryable[T any](ctx context.Context, isRetryableFn errorsext.IsRetryableFn[error], onRetryFn errorsext.OnRetryFn[error], isRetryableStatusCode IsRetryableStatusCodeFn, client *http.Client, expectedResponseCode int, maxMemory bytesext.Bytes, buildFn BuildRequestFn) Result[T, error] { return errorsext.DoRetryable(ctx, isRetryableFn, onRetryFn, func(ctx context.Context) Result[T, error] {