Skip to content

Commit

Permalink
Fix POST retries (#45)
Browse files Browse the repository at this point in the history
* Add test to verify backoff Context cancellation behaviour
* Verify that call will return error when Context-cancelled during backoff
* Verify that error is due to Context-cancellation

* Fix Context-cancellation during backoff
* Context cancellation would result in no result being sent, thus
  causing the operation to hang in many circumstances

* Add tests to verify retry behaviour with POST requests
* Add test for POST-requests provided by Do() with a body
* Add test for POST-requests provided by Post() with a body
* Add middlewareServer to support POST tests

* Fix retries for requests with bodies
* Provide each concurrent request with its own Request so they do not
  interfere with each other. Eg reading/closing each others bodies.
* Reset body before retrying when requests have bodies

* Fix first request with Do and a body failing due to copyBody closing it
  • Loading branch information
lvanoort authored Feb 9, 2022
1 parent c02ad50 commit 32a1beb
Show file tree
Hide file tree
Showing 2 changed files with 292 additions and 20 deletions.
88 changes: 68 additions & 20 deletions pester.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,16 @@ func (c *Client) copyBody(src io.ReadCloser) ([]byte, error) {
return b, nil
}

// resetBody resets the Body and GetBody fields of an http.Request to new Readers over
// the originalBody. This is used to refresh http.Requests that may have had their
// bodies closed already.
func resetBody(request *http.Request, originalBody []byte) {
request.Body = io.NopCloser(bytes.NewBuffer(originalBody))
request.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewBuffer(originalBody)), nil
}
}

// pester provides all the logic of retries, concurrency, backoff, and logging
func (c *Client) pester(p params) (*http.Response, error) {
resultCh := make(chan result)
Expand Down Expand Up @@ -242,7 +252,6 @@ func (c *Client) pester(p params) (*http.Response, error) {

// if we have a request body, we need to save it for later
var (
request *http.Request
originalBody []byte
err error
)
Expand All @@ -252,23 +261,52 @@ func (c *Client) pester(p params) (*http.Response, error) {
} else if p.body != nil {
originalBody, err = c.copyBody(p.body)
}
if err != nil {
return nil, err
}

// check to make sure that we aren't trying to use an unsupported method
switch p.method {
case methodDo:
request = p.req
case methodGet, methodHead:
request, err = http.NewRequest(p.verb, p.url, nil)
case methodPostForm, methodPost:
request, err = http.NewRequest(http.MethodPost, p.url, ioutil.NopCloser(bytes.NewBuffer(originalBody)))
case methodDo, methodGet, methodHead, methodPostForm, methodPost:
default:
err = ErrUnexpectedMethod
}
if err != nil {
return nil, err
return nil, ErrUnexpectedMethod
}

if len(p.bodyType) > 0 {
request.Header.Set(headerKeyContentType, p.bodyType)
// provideRequest returns an HTTP request to be use when retrying.
// if concurrency is 1, it will return the same request that was supplied to the Do() method
// for Do() calls, otherwise it will generate a Clone() of the request each time it is called.
// For non-Do() calls, it creates a new request each time it is called. This re-creation behaviour
// is because requests are not supposed to be used again until the RoundTripper is finished
// with them, which cannot be guaranteed with concurrent callers
// https://pkg.go.dev/net/http#RoundTripper
provideRequest := func() (request *http.Request, err error) {
switch p.method {
case methodDo:
if concurrency > 1 {
request = p.req.Clone(p.req.Context())
} else {
request = p.req
}
if request.Body != nil {
// reset the body since Clone() doesn't do that for us
// and we drained it earlier when performing the Copy
// ex: https://go.dev/play/p/jlc6A-fjaOi
resetBody(request, originalBody)
}
case methodGet, methodHead:
request, err = http.NewRequest(p.verb, p.url, nil)
case methodPostForm, methodPost:
request, err = http.NewRequest(http.MethodPost, p.url, bytes.NewBuffer(originalBody))
}
if err != nil {
return
}

if len(p.bodyType) > 0 {
request.Header.Set(headerKeyContentType, p.bodyType)
}

return
}

AttemptLimit := c.MaxRetries
Expand All @@ -279,9 +317,15 @@ func (c *Client) pester(p params) (*http.Response, error) {
for n := 0; n < concurrency; n++ {
c.wg.Add(1)
totalSentRequests.Add(1)
go func(n int, req *http.Request) {
go func(n int) {
defer c.wg.Done()
defer totalSentRequests.Done()
req, err := provideRequest()
// couldn't get a request to use, so don't proceed
if err != nil {
multiplexCh <- result{err: err, req: n}
return
}

for i := 1; i <= AttemptLimit; i++ {
c.wg.Add(1)
Expand Down Expand Up @@ -340,15 +384,19 @@ func (c *Client) pester(p params) (*http.Response, error) {
case <-time.After(c.Backoff(i) + 1*time.Microsecond):
// allow context cancellation to cancel during backoff
case <-req.Context().Done():
multiplexCh <- result{resp: resp, err: req.Context().Err()}
return
}
}
}(n, request)

// rehydrate the body (it is drained each read)
if request.Body != nil {
request.Body = ioutil.NopCloser(bytes.NewBuffer(originalBody))
}
// we are about to retry, if we had a Body, we will need to restore it
// to a non-closed one in order to work reliably. If you do not do this,
// there are a number of curious edge cases depending on the type of the
// underlying reader: https://go.dev/play/p/gZLVUe2EXSE
if req.Body != nil {
resetBody(req, originalBody)
}
}
}(n)
}

// spin off the go routine so it can continually listen in on late results and close the response bodies
Expand Down
224 changes: 224 additions & 0 deletions pester_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -743,6 +744,171 @@ func TestRetriesNotAttemptedIfContextIsCancelled(t *testing.T) {
}
}

type roundTripperFunc func(r *http.Request) (*http.Response, error)

func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}

func TestRetriesContextCancelledDuringWait(t *testing.T) {
t.Parallel()
// in order for this test to work we need to be able to reliably put the client in a
// waiting state. To achieve this, we create a client that will fail fast
// via a custom RoundTripper that always fails and pair it with a custom BackoffStrategy
// that waits for a long time. This results in a client that should spend
// almost all of its time waiting.

ctx, cancel := context.WithCancel(context.Background())

c := NewExtendedClient(&http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("always fail")
}),
Timeout: 5 * time.Second,
})
c.MaxRetries = 2
c.Backoff = func(retry int) time.Duration {
return 5 * time.Second
}
// req details don't really matter, round-tripper will fail it anyway
req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost", nil)
if err != nil {
t.Fatalf("unable to create request %v", err)
}

// we want to perform the call in a goroutine so we can explicitly check for indefinite
// blocking behaviour. Since you cannot use t.Fatal/t.Error/etc. in a goroutine, we
// create a channel to communicate back to our main goroutine what happened
errReturn := make(chan error)
go func() {
// perform call in goroutine to check for indefinite blocks
_, err := c.Do(req)
errReturn <- err
}()

// wait a hundred ms to let the client fail and get into a waiting state
<-time.After(100 * time.Millisecond)
// cancel our context
cancel()

// if all has gone well, we should have aborted our wait period and the
// err channel should contain a Context-cancellation error

select {
case recdErr := <-errReturn:
if recdErr == nil {
t.Fatal("nil error returned from Do(req) routine")
}
// check that it is the right error message
if context.Canceled != recdErr {
t.Fatalf("unexpected error returned: %v", recdErr)
}
case <-time.After(time.Second):
// give it a second, then treat this as failing to return
t.Fatal("failed to receive error return")
}
}

func TestRetriesWithBodies_Do(t *testing.T) {
t.Parallel()

const testContent = "TestRetriesWithBodies_Do"
// using a channel to route these errors back into this goroutine
// it is important that this channel have enough capacity to hold all
// of the errors that will be generated by the test so that we do not
// deadlock. Therefore, MaxAttempts must be the same size as the channel capacity
// and each execution must only put at most one error on the channel.
serverReqErrCh := make(chan error, 4)
port, closeFn, err := middlewareServer(
contentVerificationMiddleware(serverReqErrCh, testContent),
always500RequestMiddleware(),
)
if err != nil {
t.Fatal("unable to start timeout server", err)
}
defer closeFn()

<-time.After(2 * time.Second)

iseUrl := fmt.Sprintf("http://localhost:%d", port)

req, err := http.NewRequest("POST", iseUrl, strings.NewReader(testContent))
if err != nil {
t.Fatalf("unable to create request %v", err)
}

c := New()
c.MaxRetries = cap(serverReqErrCh)
c.KeepLog = true
c.Backoff = func(retry int) time.Duration {
// backoff isn't important for this test
return 0
}

resp, err := c.Do(req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if resp == nil {
t.Error("response was unexpectedly nil")
} else if resp.StatusCode != http.StatusInternalServerError {
t.Errorf("unexpected response StatusCode: %v", resp.StatusCode)
}
// we're done making requests, so close the return channel and drain it
close(serverReqErrCh)
for v := range serverReqErrCh {
if v != nil {
t.Errorf("unexpected error occurred when server processed request: %v", v)
}
}
}

func TestRetriesWithBodies_POST(t *testing.T) {
t.Parallel()

const testContent = "TestRetriesWithBodies_POST"
// using a channel to route these errors back into this goroutine
// it is important that this channel have enough capacity to hold all
// of the errors that will be generated by the test so that we do not
// deadlock. Therefore, MaxAttempts must be the same size as the channel capacity
// and each execution must only put at most one error on the channel.
serverReqErrCh := make(chan error, 4)
port, closeFn, err := middlewareServer(
contentVerificationMiddleware(serverReqErrCh, testContent),
always500RequestMiddleware(),
)
if err != nil {
t.Fatal("unable to start timeout server", err)
}
defer closeFn()

c := New()
c.MaxRetries = cap(serverReqErrCh)
c.KeepLog = true
c.Backoff = func(retry int) time.Duration {
// backoff isn't important for this test
return 0
}

iseUrl := fmt.Sprintf("http://localhost:%d", port)
resp, err := c.Post(iseUrl, "text/plain", strings.NewReader(testContent))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if resp == nil {
t.Error("response was unexpectedly nil")
} else if resp.StatusCode != http.StatusInternalServerError {
t.Errorf("unexpected response StatusCode: %v", resp.StatusCode)
}
// we're done making requests, so close the return channel and drain it
close(serverReqErrCh)
for v := range serverReqErrCh {
if v != nil {
t.Errorf("unexpected error occurred when server processed request: %v", v)
}
}
}

func withinEpsilon(got, want int64, epslion float64) bool {
if want <= int64(epslion*float64(got)) || want >= int64(epslion*float64(got)) {
return false
Expand Down Expand Up @@ -880,3 +1046,61 @@ func serverWith400() (int, error) {

return port, nil
}

func contentVerificationMiddleware(errorCh chan<- error, expectedContent string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
content, err := ioutil.ReadAll(r.Body)
defer r.Body.Close()
if err != nil {
errorCh <- err
} else if string(content) != expectedContent {
errorCh <- fmt.Errorf(
"unexpected body content: expected \"%v\", got \"%v\"",
expectedContent,
string(content),
)
}
})
}

func always500RequestMiddleware() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("500 Internal Server Error"))
})
}

// middlewareServer stands up a server that accepts varags of middleware that conforms to the
// http.Handler interface
func middlewareServer(requestMiddleware ...http.Handler) (int, func(), error) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
for _, v := range requestMiddleware {
v.ServeHTTP(w, r)
}
})
l, err := net.Listen("tcp", ":0")
if err != nil {
return -1, nil, fmt.Errorf("unable to secure listener %v", err)
}
server := &http.Server{
Handler: mux,
}
go func() {
if err := server.Serve(l); err != nil && err != http.ErrServerClosed {
log.Fatalf("middleware-server error %v", err)
}
}()

var port int
_, sport, err := net.SplitHostPort(l.Addr().String())
if err == nil {
port, err = strconv.Atoi(sport)
}

if err != nil {
return -1, nil, fmt.Errorf("unable to determine port %v", err)
}

return port, func() { server.Close() }, nil
}

0 comments on commit 32a1beb

Please sign in to comment.