diff --git a/v2/protocol/http/protocol_retry.go b/v2/protocol/http/protocol_retry.go index fb7bcd27e..71e7346f3 100644 --- a/v2/protocol/http/protocol_retry.go +++ b/v2/protocol/http/protocol_retry.go @@ -6,8 +6,11 @@ package http import ( + "bytes" "context" "errors" + "io" + "io/ioutil" "net/http" "net/url" "time" @@ -53,6 +56,24 @@ func (p *Protocol) doWithRetry(ctx context.Context, params *cecontext.RetryParam retry := 0 results := make([]protocol.Result, 0) + var ( + body []byte + err error + ) + + if req != nil && req.Body != nil { + defer func() { + if err = req.Body.Close(); err != nil { + cecontext.LoggerFrom(ctx).Warnw("could not close request body", zap.Error(err)) + } + }() + body, err = ioutil.ReadAll(req.Body) + if err != nil { + panic(err) + } + resetBody(req, body) + } + for { msg, result := p.doOnce(req) @@ -90,6 +111,8 @@ func (p *Protocol) doWithRetry(ctx context.Context, params *cecontext.RetryParam } DoBackoff: + resetBody(req, body) + // Wait for the correct amount of backoff time. // total tries = retry + 1 @@ -103,3 +126,20 @@ func (p *Protocol) doWithRetry(ctx context.Context, params *cecontext.RetryParam results = append(results, result) } } + +// reset body to allow it to be read multiple times, e.g. when retrying http +// requests +func resetBody(req *http.Request, body []byte) { + if req == nil || req.Body == nil { + return + } + + req.Body = ioutil.NopCloser(bytes.NewReader(body)) + + // do not modify existing GetBody function + if req.GetBody == nil { + req.GetBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewReader(body)), nil + } + } +} diff --git a/v2/protocol/http/protocol_retry_test.go b/v2/protocol/http/protocol_retry_test.go index 06a7bd504..425b98add 100644 --- a/v2/protocol/http/protocol_retry_test.go +++ b/v2/protocol/http/protocol_retry_test.go @@ -7,13 +7,14 @@ package http import ( "context" - - "github.com/stretchr/testify/require" - "net/http" + "net/http/httptest" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/cloudevents/sdk-go/v2/binding" cecontext "github.com/cloudevents/sdk-go/v2/context" "github.com/cloudevents/sdk-go/v2/event" @@ -21,17 +22,18 @@ import ( ) func TestRequestWithRetries_linear(t *testing.T) { - dummyEvent := event.New() - dummyMsg := binding.ToMessage(&dummyEvent) - ctx := cecontext.WithTarget(context.Background(), "http://test") testCases := map[string]struct { + event event.Event // roundTripperTest statusCodes []int // -1 = timeout - // Linear Backoff + // retry configuration delay time.Duration retries int + // http server + respDelay []time.Duration // slice maps to []statusCodes, 0 for no delay + // Wants wantResult protocol.Result wantRequestCount int @@ -41,7 +43,8 @@ func TestRequestWithRetries_linear(t *testing.T) { // Custom IsRetriable handler isRetriableFunc IsRetriable }{ - "no retries, ACK": { + "no retries, no event body, ACK": { + event: newEvent(t, "", nil), statusCodes: []int{200}, retries: 0, wantResult: &RetriesResult{ @@ -50,7 +53,8 @@ func TestRequestWithRetries_linear(t *testing.T) { }, wantRequestCount: 1, }, - "no retries, NACK": { + "no retries, no event body, NACK": { + event: newEvent(t, "", nil), statusCodes: []int{404}, retries: 0, wantResult: &RetriesResult{ @@ -59,16 +63,20 @@ func TestRequestWithRetries_linear(t *testing.T) { }, wantRequestCount: 1, }, - "retries, no NACK": { - statusCodes: []int{200}, + "no retries, with default handler, with event body, 500, 200, ACK": { + event: newEvent(t, event.ApplicationJSON, "hello world"), + statusCodes: []int{500, 200}, delay: time.Nanosecond, retries: 3, wantResult: &RetriesResult{ - Result: NewResult(200, "%w", protocol.ResultACK), + Result: NewResult(500, "%w", protocol.ResultNACK), + Duration: time.Nanosecond, + Attempts: []protocol.Result{NewResult(500, "%w", protocol.ResultNACK)}, }, wantRequestCount: 1, }, - "3 retries, 425, 200, ACK": { + "3 retries, no event body, 425, 200, ACK": { + event: newEvent(t, "", nil), statusCodes: []int{425, 200}, delay: time.Nanosecond, retries: 3, @@ -80,18 +88,26 @@ func TestRequestWithRetries_linear(t *testing.T) { }, wantRequestCount: 2, }, - "no retries with default handler, 500, 200, ACK": { - statusCodes: []int{500, 200}, + "3 retries, with event body, 503, 503, 503, NACK": { + event: newEvent(t, event.ApplicationJSON, map[string]string{"hello": "world"}), delay: time.Nanosecond, + statusCodes: []int{503, 503, 503, 503}, retries: 3, wantResult: &RetriesResult{ - Result: NewResult(500, "%w", protocol.ResultNACK), + Result: NewResult(503, "%w", protocol.ResultNACK), + Retries: 3, Duration: time.Nanosecond, - Attempts: []protocol.Result{NewResult(500, "%w", protocol.ResultNACK)}, + Attempts: []protocol.Result{ + NewResult(503, "%w", protocol.ResultNACK), + NewResult(503, "%w", protocol.ResultNACK), + NewResult(503, "%w", protocol.ResultNACK), + }, }, - wantRequestCount: 1, + wantRequestCount: 4, + skipResults: true, }, - "3 retry with custom handler, 500, 500, 200, ACK": { + "3 retries, with custom handler, with event body, 500, 500, 200, ACK": { + event: newEvent(t, event.ApplicationJSON, map[string]string{"hello": "world"}), statusCodes: []int{500, 500, 200}, delay: time.Nanosecond, retries: 3, @@ -107,7 +123,8 @@ func TestRequestWithRetries_linear(t *testing.T) { wantRequestCount: 3, isRetriableFunc: func(sc int) bool { return sc == 500 }, }, - "1 retry, 425, 429, 200, NACK": { + "1 retry, no event body, 425, 429, 200, NACK": { + event: newEvent(t, "", nil), statusCodes: []int{425, 429, 200}, delay: time.Nanosecond, retries: 1, @@ -119,7 +136,8 @@ func TestRequestWithRetries_linear(t *testing.T) { }, wantRequestCount: 2, }, - "10 retries, 425, 429, 502, 503, 504, 200, ACK": { + "10 retries, with event body, 425, 429, 502, 503, 504, 200, ACK": { + event: newEvent(t, event.ApplicationJSON, map[string]string{"hello": "world"}), statusCodes: []int{425, 429, 502, 503, 504, 200}, delay: time.Nanosecond, retries: 10, @@ -136,14 +154,16 @@ func TestRequestWithRetries_linear(t *testing.T) { }, wantRequestCount: 6, }, - "retries, timeout, 200, ACK": { - delay: time.Nanosecond, - statusCodes: []int{-1, 200}, + "5 retries, with event body, timeout, 200, ACK": { + event: newEvent(t, event.ApplicationJSON, map[string]string{"hello": "world"}), + delay: time.Millisecond * 500, + statusCodes: []int{200, 200}, // client will time out before first 200 retries: 5, + respDelay: []time.Duration{time.Second, 0}, wantResult: &RetriesResult{ Result: NewResult(200, "%w", protocol.ResultACK), Retries: 1, - Duration: time.Nanosecond, + Duration: time.Millisecond * 500, Attempts: nil, // skipping test as it contains internal http errors }, wantRequestCount: 2, @@ -152,10 +172,15 @@ func TestRequestWithRetries_linear(t *testing.T) { } for n, tc := range testCases { t.Run(n, func(t *testing.T) { - roundTripper := roundTripperTest{statusCodes: tc.statusCodes} + mockSrv := &roundTripperTest{ + statusCodes: tc.statusCodes, + delays: tc.respDelay, + } + srv := httptest.NewServer(mockSrv) + defer srv.Close() + opts := []Option{ WithClient(http.Client{Timeout: time.Second}), - WithRoundTripper(&roundTripper), } if tc.isRetriableFunc != nil { opts = append(opts, WithIsRetriableFunc(tc.isRetriableFunc)) @@ -165,12 +190,19 @@ func TestRequestWithRetries_linear(t *testing.T) { if err != nil { t.Fatalf("no protocol") } + + ctx := cecontext.WithTarget(context.Background(), srv.URL) ctxWithRetries := cecontext.WithRetriesLinearBackoff(ctx, tc.delay, tc.retries) + + dummyMsg := binding.ToMessage(&tc.event) _, got := p.Request(ctxWithRetries, dummyMsg) - if roundTripper.requestCount != tc.wantRequestCount { - t.Errorf("expected %d requests, got %d", tc.wantRequestCount, roundTripper.requestCount) + srvCount := func() int { + mockSrv.Lock() + defer mockSrv.Unlock() + return mockSrv.requestCount } + assert.Equal(t, tc.wantRequestCount, srvCount()) if tc.skipResults { got.(*RetriesResult).Attempts = nil @@ -180,3 +212,13 @@ func TestRequestWithRetries_linear(t *testing.T) { }) } } + +func newEvent(t *testing.T, encoding string, body interface{}) event.Event { + e := event.New() + if body != nil { + err := e.SetData(encoding, body) + require.NoError(t, err) + } + + return e +} diff --git a/v2/protocol/http/protocol_test.go b/v2/protocol/http/protocol_test.go index 0624f333b..818ef60c2 100644 --- a/v2/protocol/http/protocol_test.go +++ b/v2/protocol/http/protocol_test.go @@ -7,20 +7,21 @@ package http import ( "context" - "errors" "net/http" "net/http/httptest" "strconv" + "sync" "testing" "time" - "github.com/cloudevents/sdk-go/v2/binding" - "github.com/cloudevents/sdk-go/v2/protocol" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/time/rate" + + "github.com/cloudevents/sdk-go/v2/binding" + "github.com/cloudevents/sdk-go/v2/protocol" ) func TestNew(t *testing.T) { @@ -265,7 +266,7 @@ func ReceiveTest(t *testing.T, p *Protocol, ctx context.Context, rec *httptest.R } func TestServeHTTP_ReceiveWithLimiter(t *testing.T) { - var testCases = map[string]struct { + testCases := map[string]struct { limiter RateLimiter delay time.Duration // client send @@ -354,19 +355,31 @@ func (rl *rateLimiterTest) Close(_ context.Context) error { } type roundTripperTest struct { + sync.Mutex statusCodes []int requestCount int + delays []time.Duration } -func (r *roundTripperTest) RoundTrip(req *http.Request) (*http.Response, error) { +func (r *roundTripperTest) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + defer func() { + r.Lock() + r.requestCount++ + r.Unlock() + }() + + r.Lock() code := r.statusCodes[r.requestCount] - r.requestCount++ - if code == -1 { - time.Sleep(2 * time.Second) - return nil, errors.New("timeout") + delay := time.Duration(0) + if r.delays != nil { + delay = r.delays[r.requestCount] } + r.Unlock() - return &http.Response{StatusCode: code}, nil + time.Sleep(delay) + if code != 200 { + http.Error(w, http.StatusText(code), code) + } } func newDoneContext() context.Context {