Skip to content

Commit

Permalink
add OnSuccess, OnPanic, and OnInvalid hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
muir committed Oct 1, 2022
1 parent 313f419 commit 3a6d1a8
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 4 deletions.
61 changes: 60 additions & 1 deletion client.go
Expand Up @@ -85,6 +85,9 @@ type (

// ErrorHook type is for reacting to request errors, called after all retries were attempted
ErrorHook func(*Request, error)

// SuccessHook type is for reacting to request success
SuccessHook func(*Client, *Response)
)

// Client struct is used to create Resty client with client level settings,
Expand Down Expand Up @@ -136,10 +139,13 @@ type Client struct {
beforeRequest []RequestMiddleware
udBeforeRequest []RequestMiddleware
preReqHook PreRequestHook
successHooks []SuccessHook
afterResponse []ResponseMiddleware
requestLog RequestLogCallback
responseLog ResponseLogCallback
errorHooks []ErrorHook
invalidHooks []ErrorHook
panicHooks []ErrorHook
}

// User type is to hold an username and password information
Expand Down Expand Up @@ -439,11 +445,46 @@ func (c *Client) OnAfterResponse(m ResponseMiddleware) *Client {
// }
// // Log the error, increment a metric, etc...
// })
//
// Out of the OnSuccess, OnError, OnInvalid, OnPanic callbacks, exactly one
// set will be invoked for each call to Request.Execute() that comletes.
func (c *Client) OnError(h ErrorHook) *Client {
c.errorHooks = append(c.errorHooks, h)
return c
}

// OnSuccess method adds a callback that will be run whenever a request execution
// succeeds. This is called after all retries have been attempted (if any).
//
// Out of the OnSuccess, OnError, OnInvalid, OnPanic callbacks, exactly one
// set will be invoked for each call to Request.Execute() that comletes.
func (c *Client) OnSuccess(h SuccessHook) *Client {
c.successHooks = append(c.successHooks, h)
return c
}

// OnInvalid method adds a callback that will be run whever a request execution
// fails before it starts because the request is invalid.
//
// Out of the OnSuccess, OnError, OnInvalid, OnPanic callbacks, exactly one
// set will be invoked for each call to Request.Execute() that comletes.
func (c *Client) OnInvalid(h ErrorHook) *Client {
c.invalidHooks = append(c.invalidHooks, h)
return c
}

// OnPanic method adds a callback that will be run whever a request execution
// panics.
//
// Out of the OnSuccess, OnError, OnInvalid, OnPanic callbacks, exactly one
// set will be invoked for each call to Request.Execute() that completes.
// If an OnSuccess, OnError, or OnInvalid callback panics, then the exactly
// one rule can be violated.
func (c *Client) OnPanic(h ErrorHook) *Client {
c.panicHooks = append(c.panicHooks, h)
return c
}

// SetPreRequestHook method sets the given pre-request function into resty client.
// It is called right before the request is fired.
//
Expand Down Expand Up @@ -1019,7 +1060,7 @@ func (e *ResponseError) Unwrap() error {
return e.Err
}

// Helper to run onErrorHooks hooks.
// Helper to run errorHooks hooks.
// It wraps the error in a ResponseError if the resp is not nil
// so hooks can access it.
func (c *Client) onErrorHooks(req *Request, resp *Response, err error) {
Expand All @@ -1030,6 +1071,24 @@ func (c *Client) onErrorHooks(req *Request, resp *Response, err error) {
for _, h := range c.errorHooks {
h(req, err)
}
} else {
for _, h := range c.successHooks {
h(c, resp)
}
}
}

// Helper to run panicHooks hooks.
func (c *Client) onPanicHooks(req *Request, err error) {
for _, h := range c.panicHooks {
h(req, err)
}
}

// Helper to run invalidHooks hooks.
func (c *Client) onInvalidHooks(req *Request, err error) {
for _, h := range c.invalidHooks {
h(req, err)
}
}

Expand Down
58 changes: 57 additions & 1 deletion client_test.go
Expand Up @@ -639,6 +639,7 @@ func TestClientOnResponseError(t *testing.T) {
setup func(*Client)
isError bool
hasResponse bool
panics bool
}{
{
name: "successful_request",
Expand Down Expand Up @@ -687,6 +688,28 @@ func TestClientOnResponseError(t *testing.T) {
isError: true,
hasResponse: true,
},
{
name: "panic with error",
setup: func(client *Client) {
client.OnBeforeRequest(func(client *Client, request *Request) error {
panic(fmt.Errorf("before request"))
})
},
isError: false,
hasResponse: false,
panics: true,
},
{
name: "panic with string",
setup: func(client *Client) {
client.OnBeforeRequest(func(client *Client, request *Request) error {
panic("before request")
})
},
isError: false,
hasResponse: false,
panics: true,
},
}

for _, test := range tests {
Expand All @@ -700,7 +723,16 @@ func TestClientOnResponseError(t *testing.T) {
assertNotNil(t, v.Err)
}
}
var hook1, hook2 int
var hook1, hook2, hook3, hook4, hook5, hook6 int
defer func() {
if rec := recover(); rec != nil {
assertEqual(t, true, test.panics)
assertEqual(t, 0, hook1)
assertEqual(t, 0, hook3)
assertEqual(t, 1, hook5)
assertEqual(t, 1, hook6)
}
}()
c := New().outputLogTo(ioutil.Discard).
SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}).
SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF").
Expand All @@ -719,6 +751,24 @@ func TestClientOnResponseError(t *testing.T) {
OnError(func(r *Request, err error) {
assertErrorHook(r, err)
hook2++
}).
OnPanic(func(r *Request, err error) {
assertErrorHook(r, err)
hook5++
}).
OnPanic(func(r *Request, err error) {
assertErrorHook(r, err)
hook6++
}).
OnSuccess(func(c *Client, resp *Response) {
assertNotNil(t, c)
assertNotNil(t, resp)
hook3++
}).
OnSuccess(func(c *Client, resp *Response) {
assertNotNil(t, c)
assertNotNil(t, resp)
hook4++
})
if test.setup != nil {
test.setup(c)
Expand All @@ -728,8 +778,14 @@ func TestClientOnResponseError(t *testing.T) {
assertNotNil(t, err)
assertEqual(t, 1, hook1)
assertEqual(t, 1, hook2)
assertEqual(t, 0, hook3)
assertEqual(t, 0, hook5)
} else {
assertError(t, err)
assertEqual(t, 0, hook1)
assertEqual(t, 1, hook3)
assertEqual(t, 1, hook4)
assertEqual(t, 0, hook5)
}
})
}
Expand Down
2 changes: 2 additions & 0 deletions context_test.go
Expand Up @@ -192,6 +192,8 @@ func TestClientRetryWithSetContext(t *testing.T) {
SetContext(context.Background()).
Get(ts.URL + "/")

assertNotNil(t, ts)
assertNotNil(t, err)
assertEqual(t, true, (strings.HasPrefix(err.Error(), "Get "+ts.URL+"/") ||
strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/\"")))
}
Expand Down
16 changes: 14 additions & 2 deletions request.go
Expand Up @@ -737,9 +737,22 @@ func (r *Request) Execute(method, url string) (*Response, error) {
var resp *Response
var err error

defer func() {
if rec := recover(); rec != nil {
if err, ok := rec.(error); ok {
r.client.onPanicHooks(r, err)
} else {
r.client.onPanicHooks(r, fmt.Errorf("panic %v", rec))
}
panic(rec)
}
}()

if r.isMultiPart && !(method == MethodPost || method == MethodPut || method == MethodPatch) {
// No OnError hook here since this is a request validation error
return nil, fmt.Errorf("multipart content is not allowed in HTTP verb [%v]", method)
err := fmt.Errorf("multipart content is not allowed in HTTP verb [%v]", method)
r.client.onInvalidHooks(r, err)
return nil, err
}

if r.SRV != nil {
Expand Down Expand Up @@ -781,7 +794,6 @@ func (r *Request) Execute(method, url string) (*Response, error) {
)

r.client.onErrorHooks(r, resp, unwrapNoRetryErr(err))

return resp, unwrapNoRetryErr(err)
}

Expand Down
21 changes: 21 additions & 0 deletions request_test.go
Expand Up @@ -779,6 +779,27 @@ func TestMultiPartUploadFileNotOnGetOrDelete(t *testing.T) {
Delete(ts.URL + "/upload")

assertEqual(t, "multipart content is not allowed in HTTP verb [DELETE]", err.Error())

var hook1Count int
var hook2Count int
_, err = dc().
OnInvalid(func(r *Request, err error) {
assertEqual(t, "multipart content is not allowed in HTTP verb [HEAD]", err.Error())
assertNotNil(t, r)
hook1Count++
}).
OnInvalid(func(r *Request, err error) {
assertEqual(t, "multipart content is not allowed in HTTP verb [HEAD]", err.Error())
assertNotNil(t, r)
hook2Count++
}).
R().
SetFile("profile_img", filepath.Join(basePath, "test-img.png")).
Head(ts.URL + "/upload")

assertEqual(t, "multipart content is not allowed in HTTP verb [HEAD]", err.Error())
assertEqual(t, 1, hook1Count)
assertEqual(t, 1, hook2Count)
}

func TestMultiPartFormData(t *testing.T) {
Expand Down

0 comments on commit 3a6d1a8

Please sign in to comment.