Skip to content

Commit

Permalink
Merge pull request #207 from hashicorp/sebasslash/handle-go-away
Browse files Browse the repository at this point in the history
Sets request's GetBody field on create wrapper
  • Loading branch information
sebasslash committed Nov 8, 2023
2 parents 571a88b + f95735f commit 309c58e
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 7 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/go-retryablehttp.yml
Expand Up @@ -10,7 +10,7 @@ jobs:
- name: Setup go
uses: actions/setup-go@4d34df0c2316fe8122ab82dc22947d607c0c91f9 # v4.0.0
with:
go-version: 1.14.2
go-version: 1.18
- uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0
- run: mkdir -p "$TEST_RESULTS"/go-retryablyhttp
- name: restore_cache
Expand All @@ -20,6 +20,7 @@ jobs:
restore-keys: go-mod-v1-{{ checksum "go.sum" }}
path: "/go/pkg/mod"
- run: go mod download
- run: go mod tidy
- name: Run go format
run: |-
files=$(go fmt ./...)
Expand All @@ -29,7 +30,7 @@ jobs:
exit 1
fi
- name: Install gotestsum
run: go get gotest.tools/gotestsum
run: go install gotest.tools/gotestsum@latest
- name: Run unit tests
run: |-
PACKAGE_NAMES=$(go list ./...)
Expand Down
25 changes: 20 additions & 5 deletions client.go
Expand Up @@ -160,6 +160,20 @@ func (r *Request) SetBody(rawBody interface{}) error {
}
r.body = bodyReader
r.ContentLength = contentLength
if bodyReader != nil {
r.GetBody = func() (io.ReadCloser, error) {
body, err := bodyReader()
if err != nil {
return nil, err
}
if rc, ok := body.(io.ReadCloser); ok {
return rc, nil
}
return io.NopCloser(body), nil
}
} else {
r.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil }
}
return nil
}

Expand Down Expand Up @@ -302,18 +316,19 @@ func NewRequest(method, url string, rawBody interface{}) (*Request, error) {
// The context controls the entire lifetime of a request and its response:
// obtaining a connection, sending the request, and reading the response headers and body.
func NewRequestWithContext(ctx context.Context, method, url string, rawBody interface{}) (*Request, error) {
bodyReader, contentLength, err := getBodyReaderAndContentLength(rawBody)
httpReq, err := http.NewRequestWithContext(ctx, method, url, nil)
if err != nil {
return nil, err
}

httpReq, err := http.NewRequestWithContext(ctx, method, url, nil)
if err != nil {
req := &Request{
Request: httpReq,
}
if err := req.SetBody(rawBody); err != nil {
return nil, err
}
httpReq.ContentLength = contentLength

return &Request{body: bodyReader, Request: httpReq}, nil
return req, nil
}

// Logger interface allows to use other loggers than
Expand Down
59 changes: 59 additions & 0 deletions client_test.go
Expand Up @@ -978,3 +978,62 @@ func TestClient_StandardClient(t *testing.T) {
t.Fatalf("expected %v, got %v", client, v)
}
}

func TestClient_RedirectWithBody(t *testing.T) {
var redirects int32
// Mock server which always responds 200.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.RequestURI {
case "/redirect":
w.Header().Set("Location", "/target")
w.WriteHeader(http.StatusTemporaryRedirect)
case "/target":
atomic.AddInt32(&redirects, 1)
w.WriteHeader(http.StatusCreated)
default:
t.Fatalf("bad uri: %s", r.RequestURI)
}
}))
defer ts.Close()

client := NewClient()
client.RequestLogHook = func(logger Logger, req *http.Request, retryNumber int) {
if _, err := req.GetBody(); err != nil {
t.Fatalf("unexpected error with GetBody: %v", err)
}
}
// create a request with a body
req, err := NewRequest(http.MethodPost, ts.URL+"/redirect", strings.NewReader(`{"foo":"bar"}`))
if err != nil {
t.Fatalf("err: %v", err)
}

resp, err := client.Do(req)
if err != nil {
t.Fatalf("err: %v", err)
}
resp.Body.Close()

if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected status code 201, got: %d", resp.StatusCode)
}

// now one without a body
if err := req.SetBody(nil); err != nil {
t.Fatalf("err: %v", err)
}

resp, err = client.Do(req)
if err != nil {
t.Fatalf("err: %v", err)
}
resp.Body.Close()

if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected status code 201, got: %d", resp.StatusCode)
}

if atomic.LoadInt32(&redirects) != 2 {
t.Fatalf("Expected the client to be redirected 2 times, got: %d", atomic.LoadInt32(&redirects))
}
}

0 comments on commit 309c58e

Please sign in to comment.