Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add OnSuccess, OnPanic, and OnInvalid hooks #586

Merged
merged 1 commit into from Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 60 additions & 1 deletion client.go
Expand Up @@ -85,6 +85,9 @@ type (

// ErrorHook type is for reacting to request errors, called after all retries were attempted
ErrorHook func(*Request, error)

// SuccessHook type is for reacting to request success
SuccessHook func(*Client, *Response)
)

// Client struct is used to create Resty client with client level settings,
Expand Down Expand Up @@ -136,10 +139,13 @@ type Client struct {
beforeRequest []RequestMiddleware
udBeforeRequest []RequestMiddleware
preReqHook PreRequestHook
successHooks []SuccessHook
afterResponse []ResponseMiddleware
requestLog RequestLogCallback
responseLog ResponseLogCallback
errorHooks []ErrorHook
invalidHooks []ErrorHook
panicHooks []ErrorHook
}

// User type is to hold an username and password information
Expand Down Expand Up @@ -440,11 +446,46 @@ func (c *Client) OnAfterResponse(m ResponseMiddleware) *Client {
// }
// // Log the error, increment a metric, etc...
// })
//
// Out of the OnSuccess, OnError, OnInvalid, OnPanic callbacks, exactly one
// set will be invoked for each call to Request.Execute() that comletes.
func (c *Client) OnError(h ErrorHook) *Client {
c.errorHooks = append(c.errorHooks, h)
return c
}

// OnSuccess method adds a callback that will be run whenever a request execution
// succeeds. This is called after all retries have been attempted (if any).
//
// Out of the OnSuccess, OnError, OnInvalid, OnPanic callbacks, exactly one
// set will be invoked for each call to Request.Execute() that comletes.
func (c *Client) OnSuccess(h SuccessHook) *Client {
c.successHooks = append(c.successHooks, h)
return c
}

// OnInvalid method adds a callback that will be run whever a request execution
// fails before it starts because the request is invalid.
//
// Out of the OnSuccess, OnError, OnInvalid, OnPanic callbacks, exactly one
// set will be invoked for each call to Request.Execute() that comletes.
func (c *Client) OnInvalid(h ErrorHook) *Client {
c.invalidHooks = append(c.invalidHooks, h)
return c
}

// OnPanic method adds a callback that will be run whever a request execution
// panics.
//
// Out of the OnSuccess, OnError, OnInvalid, OnPanic callbacks, exactly one
// set will be invoked for each call to Request.Execute() that completes.
// If an OnSuccess, OnError, or OnInvalid callback panics, then the exactly
// one rule can be violated.
func (c *Client) OnPanic(h ErrorHook) *Client {
c.panicHooks = append(c.panicHooks, h)
return c
}

// SetPreRequestHook method sets the given pre-request function into resty client.
// It is called right before the request is fired.
//
Expand Down Expand Up @@ -1020,7 +1061,7 @@ func (e *ResponseError) Unwrap() error {
return e.Err
}

// Helper to run onErrorHooks hooks.
// Helper to run errorHooks hooks.
// It wraps the error in a ResponseError if the resp is not nil
// so hooks can access it.
func (c *Client) onErrorHooks(req *Request, resp *Response, err error) {
Expand All @@ -1031,6 +1072,24 @@ func (c *Client) onErrorHooks(req *Request, resp *Response, err error) {
for _, h := range c.errorHooks {
h(req, err)
}
} else {
for _, h := range c.successHooks {
h(c, resp)
}
}
}

// Helper to run panicHooks hooks.
func (c *Client) onPanicHooks(req *Request, err error) {
for _, h := range c.panicHooks {
h(req, err)
}
}

// Helper to run invalidHooks hooks.
func (c *Client) onInvalidHooks(req *Request, err error) {
for _, h := range c.invalidHooks {
h(req, err)
}
}

Expand Down
58 changes: 57 additions & 1 deletion client_test.go
Expand Up @@ -654,6 +654,7 @@ func TestClientOnResponseError(t *testing.T) {
setup func(*Client)
isError bool
hasResponse bool
panics bool
}{
{
name: "successful_request",
Expand Down Expand Up @@ -702,6 +703,28 @@ func TestClientOnResponseError(t *testing.T) {
isError: true,
hasResponse: true,
},
{
name: "panic with error",
setup: func(client *Client) {
client.OnBeforeRequest(func(client *Client, request *Request) error {
panic(fmt.Errorf("before request"))
})
},
isError: false,
hasResponse: false,
panics: true,
},
{
name: "panic with string",
setup: func(client *Client) {
client.OnBeforeRequest(func(client *Client, request *Request) error {
panic("before request")
})
},
isError: false,
hasResponse: false,
panics: true,
},
}

for _, test := range tests {
Expand All @@ -715,7 +738,16 @@ func TestClientOnResponseError(t *testing.T) {
assertNotNil(t, v.Err)
}
}
var hook1, hook2 int
var hook1, hook2, hook3, hook4, hook5, hook6 int
defer func() {
if rec := recover(); rec != nil {
assertEqual(t, true, test.panics)
assertEqual(t, 0, hook1)
assertEqual(t, 0, hook3)
assertEqual(t, 1, hook5)
assertEqual(t, 1, hook6)
}
}()
c := New().outputLogTo(ioutil.Discard).
SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}).
SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF").
Expand All @@ -734,6 +766,24 @@ func TestClientOnResponseError(t *testing.T) {
OnError(func(r *Request, err error) {
assertErrorHook(r, err)
hook2++
}).
OnPanic(func(r *Request, err error) {
assertErrorHook(r, err)
hook5++
}).
OnPanic(func(r *Request, err error) {
assertErrorHook(r, err)
hook6++
}).
OnSuccess(func(c *Client, resp *Response) {
assertNotNil(t, c)
assertNotNil(t, resp)
hook3++
}).
OnSuccess(func(c *Client, resp *Response) {
assertNotNil(t, c)
assertNotNil(t, resp)
hook4++
})
if test.setup != nil {
test.setup(c)
Expand All @@ -743,8 +793,14 @@ func TestClientOnResponseError(t *testing.T) {
assertNotNil(t, err)
assertEqual(t, 1, hook1)
assertEqual(t, 1, hook2)
assertEqual(t, 0, hook3)
assertEqual(t, 0, hook5)
} else {
assertError(t, err)
assertEqual(t, 0, hook1)
assertEqual(t, 1, hook3)
assertEqual(t, 1, hook4)
assertEqual(t, 0, hook5)
}
})
}
Expand Down
2 changes: 2 additions & 0 deletions context_test.go
Expand Up @@ -192,6 +192,8 @@ func TestClientRetryWithSetContext(t *testing.T) {
SetContext(context.Background()).
Get(ts.URL + "/")

assertNotNil(t, ts)
assertNotNil(t, err)
assertEqual(t, true, (strings.HasPrefix(err.Error(), "Get "+ts.URL+"/") ||
strings.HasPrefix(err.Error(), "Get \""+ts.URL+"/\"")))
}
Expand Down
16 changes: 14 additions & 2 deletions request.go
Expand Up @@ -749,9 +749,22 @@ func (r *Request) Execute(method, url string) (*Response, error) {
var resp *Response
var err error

defer func() {
if rec := recover(); rec != nil {
if err, ok := rec.(error); ok {
r.client.onPanicHooks(r, err)
} else {
r.client.onPanicHooks(r, fmt.Errorf("panic %v", rec))
}
panic(rec)
}
}()

if r.isMultiPart && !(method == MethodPost || method == MethodPut || method == MethodPatch) {
// No OnError hook here since this is a request validation error
return nil, fmt.Errorf("multipart content is not allowed in HTTP verb [%v]", method)
err := fmt.Errorf("multipart content is not allowed in HTTP verb [%v]", method)
r.client.onInvalidHooks(r, err)
return nil, err
}

if r.SRV != nil {
Expand Down Expand Up @@ -793,7 +806,6 @@ func (r *Request) Execute(method, url string) (*Response, error) {
)

r.client.onErrorHooks(r, resp, unwrapNoRetryErr(err))

return resp, unwrapNoRetryErr(err)
}

Expand Down
21 changes: 21 additions & 0 deletions request_test.go
Expand Up @@ -824,6 +824,27 @@ func TestMultiPartUploadFileNotOnGetOrDelete(t *testing.T) {
Delete(ts.URL + "/upload")

assertEqual(t, "multipart content is not allowed in HTTP verb [DELETE]", err.Error())

var hook1Count int
var hook2Count int
_, err = dc().
OnInvalid(func(r *Request, err error) {
assertEqual(t, "multipart content is not allowed in HTTP verb [HEAD]", err.Error())
assertNotNil(t, r)
hook1Count++
}).
OnInvalid(func(r *Request, err error) {
assertEqual(t, "multipart content is not allowed in HTTP verb [HEAD]", err.Error())
assertNotNil(t, r)
hook2Count++
}).
R().
SetFile("profile_img", filepath.Join(basePath, "test-img.png")).
Head(ts.URL + "/upload")

assertEqual(t, "multipart content is not allowed in HTTP verb [HEAD]", err.Error())
assertEqual(t, 1, hook1Count)
assertEqual(t, 1, hook2Count)
}

func TestMultiPartFormData(t *testing.T) {
Expand Down