Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rate limiter to client #715

Merged
merged 5 commits into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 16 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ type Client struct {
errorHooks []ErrorHook
invalidHooks []ErrorHook
panicHooks []ErrorHook
rateLimiter RateLimiter
}

// User type is to hold an username and password information
Expand Down Expand Up @@ -920,6 +921,13 @@ func (c *Client) SetOutputDirectory(dirPath string) *Client {
return c
}

// SetRateLimiter sets an optional `RateLimiter`. If set the rate limiter will control
// all requests made with this client.
func (c *Client) SetRateLimiter(rl RateLimiter) *Client {
c.rateLimiter = rl
return c
}

// SetTransport method sets custom `*http.Transport` or any `http.RoundTripper`
// compatible interface implementation in the resty client.
//
Expand Down Expand Up @@ -1141,6 +1149,14 @@ func (c *Client) execute(req *Request) (*Response, error) {
}
}

// If there is a rate limiter set for this client, the Execute call
// will return an error if the rate limit is exceeded.
if req.client.rateLimiter != nil {
if !req.client.rateLimiter.Allow() {
return nil, wrapNoRetryErr(ErrRateLimitExceeded)
}
}

// resty middlewares
for _, f := range c.beforeRequest {
if err = f(c, req); err != nil {
Expand Down
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ module github.com/go-resty/resty/v2

go 1.16

require golang.org/x/net v0.15.0
require (
golang.org/x/net v0.15.0
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
Expand Down
40 changes: 40 additions & 0 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"strings"
"testing"
"time"

"golang.org/x/time/rate"
)

type AuthSuccess struct {
Expand Down Expand Up @@ -66,6 +68,44 @@ func TestGetGH524(t *testing.T) {
assertEqual(t, resp.Request.Header.Get("Content-Type"), "") // unable to reproduce reported issue
}

func TestRateLimiter(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

// Test a burst with a valid capacity and then a consecutive request that must fail.

// Allow a rate of 1 every 100 ms but also allow bursts of 10 requests.
client := dc().SetRateLimiter(rate.NewLimiter(rate.Every(100*time.Millisecond), 10))

// Execute a burst of 10 requests.
for i := 0; i < 10; i++ {
resp, err := client.R().
SetQueryParam("request_no", strconv.Itoa(i)).Get(ts.URL + "/")
assertError(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())
}
// Next request issued directly should fail because burst of 10 has been consumed.
{
_, err := client.R().
SetQueryParam("request_no", strconv.Itoa(11)).Get(ts.URL + "/")
assertErrorIs(t, ErrRateLimitExceeded, err)
}

// Test continues request at a valid rate

// Allow a rate of 1 every ms with no burst.
client = dc().SetRateLimiter(rate.NewLimiter(rate.Every(1*time.Millisecond), 1))

// Sending requests every ms+tiny delta must succeed.
for i := 0; i < 100; i++ {
resp, err := client.R().
SetQueryParam("request_no", strconv.Itoa(i)).Get(ts.URL + "/")
assertError(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())
time.Sleep(1*time.Millisecond + 100*time.Microsecond)
}
}

func TestIllegalRetryCount(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()
Expand Down
11 changes: 11 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package resty

import (
"bytes"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -64,6 +65,16 @@ func (l *logger) output(format string, v ...interface{}) {
l.l.Printf(format, v...)
}

//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
// Rate Limiter interface
//_______________________________________________________________________

type RateLimiter interface {
Allow() bool
}

var ErrRateLimitExceeded = errors.New("rate limit exceeded")

//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
// Package Helper methods
//_______________________________________________________________________
Expand Down