Skip to content

Commit

Permalink
fix: Reset request body on retry (#774) (#776)
Browse files Browse the repository at this point in the history
Closes: #773
Signed-off-by: Michael Gasch <mgasch@vmware.com>

Co-authored-by: Michael Gasch <mgasch@vmware.com>
  • Loading branch information
n3wscott and Michael Gasch committed Jun 10, 2022
1 parent 448331c commit 11daec8
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 39 deletions.
40 changes: 40 additions & 0 deletions v2/protocol/http/protocol_retry.go
Expand Up @@ -6,8 +6,11 @@
package http

import (
"bytes"
"context"
"errors"
"io"
"io/ioutil"
"net/http"
"net/url"
"time"
Expand Down Expand Up @@ -53,6 +56,24 @@ func (p *Protocol) doWithRetry(ctx context.Context, params *cecontext.RetryParam
retry := 0
results := make([]protocol.Result, 0)

var (
body []byte
err error
)

if req != nil && req.Body != nil {
defer func() {
if err = req.Body.Close(); err != nil {
cecontext.LoggerFrom(ctx).Warnw("could not close request body", zap.Error(err))
}
}()
body, err = ioutil.ReadAll(req.Body)
if err != nil {
panic(err)
}
resetBody(req, body)
}

for {
msg, result := p.doOnce(req)

Expand Down Expand Up @@ -90,6 +111,8 @@ func (p *Protocol) doWithRetry(ctx context.Context, params *cecontext.RetryParam
}

DoBackoff:
resetBody(req, body)

// Wait for the correct amount of backoff time.

// total tries = retry + 1
Expand All @@ -103,3 +126,20 @@ func (p *Protocol) doWithRetry(ctx context.Context, params *cecontext.RetryParam
results = append(results, result)
}
}

// reset body to allow it to be read multiple times, e.g. when retrying http
// requests
func resetBody(req *http.Request, body []byte) {
if req == nil || req.Body == nil {
return
}

req.Body = ioutil.NopCloser(bytes.NewReader(body))

// do not modify existing GetBody function
if req.GetBody == nil {
req.GetBody = func() (io.ReadCloser, error) {
return ioutil.NopCloser(bytes.NewReader(body)), nil
}
}
}
100 changes: 71 additions & 29 deletions v2/protocol/http/protocol_retry_test.go
Expand Up @@ -7,31 +7,33 @@ package http

import (
"context"

"github.com/stretchr/testify/require"

"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/cloudevents/sdk-go/v2/binding"
cecontext "github.com/cloudevents/sdk-go/v2/context"
"github.com/cloudevents/sdk-go/v2/event"
"github.com/cloudevents/sdk-go/v2/protocol"
)

func TestRequestWithRetries_linear(t *testing.T) {
dummyEvent := event.New()
dummyMsg := binding.ToMessage(&dummyEvent)
ctx := cecontext.WithTarget(context.Background(), "http://test")
testCases := map[string]struct {
event event.Event
// roundTripperTest
statusCodes []int // -1 = timeout

// Linear Backoff
// retry configuration
delay time.Duration
retries int

// http server
respDelay []time.Duration // slice maps to []statusCodes, 0 for no delay

// Wants
wantResult protocol.Result
wantRequestCount int
Expand All @@ -41,7 +43,8 @@ func TestRequestWithRetries_linear(t *testing.T) {
// Custom IsRetriable handler
isRetriableFunc IsRetriable
}{
"no retries, ACK": {
"no retries, no event body, ACK": {
event: newEvent(t, "", nil),
statusCodes: []int{200},
retries: 0,
wantResult: &RetriesResult{
Expand All @@ -50,7 +53,8 @@ func TestRequestWithRetries_linear(t *testing.T) {
},
wantRequestCount: 1,
},
"no retries, NACK": {
"no retries, no event body, NACK": {
event: newEvent(t, "", nil),
statusCodes: []int{404},
retries: 0,
wantResult: &RetriesResult{
Expand All @@ -59,16 +63,20 @@ func TestRequestWithRetries_linear(t *testing.T) {
},
wantRequestCount: 1,
},
"retries, no NACK": {
statusCodes: []int{200},
"no retries, with default handler, with event body, 500, 200, ACK": {
event: newEvent(t, event.ApplicationJSON, "hello world"),
statusCodes: []int{500, 200},
delay: time.Nanosecond,
retries: 3,
wantResult: &RetriesResult{
Result: NewResult(200, "%w", protocol.ResultACK),
Result: NewResult(500, "%w", protocol.ResultNACK),
Duration: time.Nanosecond,
Attempts: []protocol.Result{NewResult(500, "%w", protocol.ResultNACK)},
},
wantRequestCount: 1,
},
"3 retries, 425, 200, ACK": {
"3 retries, no event body, 425, 200, ACK": {
event: newEvent(t, "", nil),
statusCodes: []int{425, 200},
delay: time.Nanosecond,
retries: 3,
Expand All @@ -80,18 +88,26 @@ func TestRequestWithRetries_linear(t *testing.T) {
},
wantRequestCount: 2,
},
"no retries with default handler, 500, 200, ACK": {
statusCodes: []int{500, 200},
"3 retries, with event body, 503, 503, 503, NACK": {
event: newEvent(t, event.ApplicationJSON, map[string]string{"hello": "world"}),
delay: time.Nanosecond,
statusCodes: []int{503, 503, 503, 503},
retries: 3,
wantResult: &RetriesResult{
Result: NewResult(500, "%w", protocol.ResultNACK),
Result: NewResult(503, "%w", protocol.ResultNACK),
Retries: 3,
Duration: time.Nanosecond,
Attempts: []protocol.Result{NewResult(500, "%w", protocol.ResultNACK)},
Attempts: []protocol.Result{
NewResult(503, "%w", protocol.ResultNACK),
NewResult(503, "%w", protocol.ResultNACK),
NewResult(503, "%w", protocol.ResultNACK),
},
},
wantRequestCount: 1,
wantRequestCount: 4,
skipResults: true,
},
"3 retry with custom handler, 500, 500, 200, ACK": {
"3 retries, with custom handler, with event body, 500, 500, 200, ACK": {
event: newEvent(t, event.ApplicationJSON, map[string]string{"hello": "world"}),
statusCodes: []int{500, 500, 200},
delay: time.Nanosecond,
retries: 3,
Expand All @@ -107,7 +123,8 @@ func TestRequestWithRetries_linear(t *testing.T) {
wantRequestCount: 3,
isRetriableFunc: func(sc int) bool { return sc == 500 },
},
"1 retry, 425, 429, 200, NACK": {
"1 retry, no event body, 425, 429, 200, NACK": {
event: newEvent(t, "", nil),
statusCodes: []int{425, 429, 200},
delay: time.Nanosecond,
retries: 1,
Expand All @@ -119,7 +136,8 @@ func TestRequestWithRetries_linear(t *testing.T) {
},
wantRequestCount: 2,
},
"10 retries, 425, 429, 502, 503, 504, 200, ACK": {
"10 retries, with event body, 425, 429, 502, 503, 504, 200, ACK": {
event: newEvent(t, event.ApplicationJSON, map[string]string{"hello": "world"}),
statusCodes: []int{425, 429, 502, 503, 504, 200},
delay: time.Nanosecond,
retries: 10,
Expand All @@ -136,14 +154,16 @@ func TestRequestWithRetries_linear(t *testing.T) {
},
wantRequestCount: 6,
},
"retries, timeout, 200, ACK": {
delay: time.Nanosecond,
statusCodes: []int{-1, 200},
"5 retries, with event body, timeout, 200, ACK": {
event: newEvent(t, event.ApplicationJSON, map[string]string{"hello": "world"}),
delay: time.Millisecond * 500,
statusCodes: []int{200, 200}, // client will time out before first 200
retries: 5,
respDelay: []time.Duration{time.Second, 0},
wantResult: &RetriesResult{
Result: NewResult(200, "%w", protocol.ResultACK),
Retries: 1,
Duration: time.Nanosecond,
Duration: time.Millisecond * 500,
Attempts: nil, // skipping test as it contains internal http errors
},
wantRequestCount: 2,
Expand All @@ -152,10 +172,15 @@ func TestRequestWithRetries_linear(t *testing.T) {
}
for n, tc := range testCases {
t.Run(n, func(t *testing.T) {
roundTripper := roundTripperTest{statusCodes: tc.statusCodes}
mockSrv := &roundTripperTest{
statusCodes: tc.statusCodes,
delays: tc.respDelay,
}
srv := httptest.NewServer(mockSrv)
defer srv.Close()

opts := []Option{
WithClient(http.Client{Timeout: time.Second}),
WithRoundTripper(&roundTripper),
}
if tc.isRetriableFunc != nil {
opts = append(opts, WithIsRetriableFunc(tc.isRetriableFunc))
Expand All @@ -165,12 +190,19 @@ func TestRequestWithRetries_linear(t *testing.T) {
if err != nil {
t.Fatalf("no protocol")
}

ctx := cecontext.WithTarget(context.Background(), srv.URL)
ctxWithRetries := cecontext.WithRetriesLinearBackoff(ctx, tc.delay, tc.retries)

dummyMsg := binding.ToMessage(&tc.event)
_, got := p.Request(ctxWithRetries, dummyMsg)

if roundTripper.requestCount != tc.wantRequestCount {
t.Errorf("expected %d requests, got %d", tc.wantRequestCount, roundTripper.requestCount)
srvCount := func() int {
mockSrv.Lock()
defer mockSrv.Unlock()
return mockSrv.requestCount
}
assert.Equal(t, tc.wantRequestCount, srvCount())

if tc.skipResults {
got.(*RetriesResult).Attempts = nil
Expand All @@ -180,3 +212,13 @@ func TestRequestWithRetries_linear(t *testing.T) {
})
}
}

func newEvent(t *testing.T, encoding string, body interface{}) event.Event {
e := event.New()
if body != nil {
err := e.SetData(encoding, body)
require.NoError(t, err)
}

return e
}
33 changes: 23 additions & 10 deletions v2/protocol/http/protocol_test.go
Expand Up @@ -7,20 +7,21 @@ package http

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"

"github.com/cloudevents/sdk-go/v2/binding"
"github.com/cloudevents/sdk-go/v2/protocol"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"

"github.com/cloudevents/sdk-go/v2/binding"
"github.com/cloudevents/sdk-go/v2/protocol"
)

func TestNew(t *testing.T) {
Expand Down Expand Up @@ -265,7 +266,7 @@ func ReceiveTest(t *testing.T, p *Protocol, ctx context.Context, rec *httptest.R
}

func TestServeHTTP_ReceiveWithLimiter(t *testing.T) {
var testCases = map[string]struct {
testCases := map[string]struct {
limiter RateLimiter
delay time.Duration // client send

Expand Down Expand Up @@ -354,19 +355,31 @@ func (rl *rateLimiterTest) Close(_ context.Context) error {
}

type roundTripperTest struct {
sync.Mutex
statusCodes []int
requestCount int
delays []time.Duration
}

func (r *roundTripperTest) RoundTrip(req *http.Request) (*http.Response, error) {
func (r *roundTripperTest) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
defer func() {
r.Lock()
r.requestCount++
r.Unlock()
}()

r.Lock()
code := r.statusCodes[r.requestCount]
r.requestCount++
if code == -1 {
time.Sleep(2 * time.Second)
return nil, errors.New("timeout")
delay := time.Duration(0)
if r.delays != nil {
delay = r.delays[r.requestCount]
}
r.Unlock()

return &http.Response{StatusCode: code}, nil
time.Sleep(delay)
if code != 200 {
http.Error(w, http.StatusText(code), code)
}
}

func newDoneContext() context.Context {
Expand Down

0 comments on commit 11daec8

Please sign in to comment.