Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

馃敟 Update: add timeout context middleware #2090

Merged
merged 9 commits into from Sep 16, 2022
81 changes: 73 additions & 8 deletions middleware/timeout/README.md
@@ -1,5 +1,9 @@
# Timeout
Timeout middleware for [Fiber](https://github.com/gofiber/fiber) wraps a `fiber.Handler` with a timeout. If the handler takes longer than the given duration to return, the timeout error is set and forwarded to the centralized [ErrorHandler](https://docs.gofiber.io/error-handling).
Timeout middleware for Fiber. As a `fiber.Handler` wrapper, it creates a context with `context.WithTimeout` and pass it in `UserContext`.

If the context passed executions (eg. DB ops, Http calls) takes longer than the given duration to return, the timeout error is set and forwarded to the centralized `ErrorHandler`.

It has no race conditions, ready to use on production.

### Table of Contents
- [Signatures](#signatures)
Expand All @@ -8,7 +12,7 @@ Timeout middleware for [Fiber](https://github.com/gofiber/fiber) wraps a `fiber.

### Signatures
```go
func New(h fiber.Handler, t time.Duration) fiber.Handler
func New(handler fiber.Handler, timeout time.Duration, timeoutErrors ...error) fiber.Handler
```

### Examples
Expand All @@ -20,15 +24,76 @@ import (
)
```

After you initiate your Fiber app, you can use the following possibilities:
Sample timeout middleware usage
```go
handler := func(ctx *fiber.Ctx) error {
err := ctx.SendString("Hello, World 馃憢!")
if err != nil {
return err
func main() {
app := fiber.New()
h := func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContext(c.UserContext(), sleepTime); err != nil {
return fmt.Errorf("%w: execution error", err)
}
return nil
}

app.Get("/foo/:sleepTime", timeout.New(h, 2*time.Second))
_ = app.Listen(":3000")
}

func sleepWithContext(ctx context.Context, d time.Duration) error {
timer := time.NewTimer(d)

select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return context.DeadlineExceeded
case <-timer.C:
}
return nil
}
```

Test http 200 with curl:
```bash
curl --location -I --request GET 'http://localhost:3000/foo/1000'
```

Test http 408 with curl:
```bash
curl --location -I --request GET 'http://localhost:3000/foo/3000'
```


When using with custom error:
```go
var ErrFooTimeOut = errors.New("foo context canceled")

func main() {
app := fiber.New()
h := func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContextWithCustomError(c.UserContext(), sleepTime); err != nil {
return fmt.Errorf("%w: execution error", err)
}
return nil
}

app.Get("/foo/:sleepTime", timeout.New(h, 2*time.Second), ErrFooTimeOut)
_ = app.Listen(":3000")
}

app.Get("/foo", timeout.New(handler, 5 * time.Second))
func sleepWithContext(ctx context.Context, d time.Duration) error {
timer := time.NewTimer(d)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return ErrFooTimeOut
case <-timer.C:
}
return nil
}
```
47 changes: 17 additions & 30 deletions middleware/timeout/timeout.go
@@ -1,43 +1,30 @@
package timeout

import (
"fmt"
"sync"
"context"
"errors"
"time"

"github.com/gofiber/fiber/v2"
)

var once sync.Once

// New wraps a handler and aborts the process of the handler if the timeout is reached
func New(handler fiber.Handler, timeout time.Duration) fiber.Handler {
once.Do(func() {
fmt.Println("[Warning] timeout contains data race issues, not ready for production!")
})

if timeout <= 0 {
return handler
}

// logic is from fasthttp.TimeoutWithCodeHandler https://github.com/valyala/fasthttp/blob/master/server.go#L418
// New implementation of timeout middleware. Set custom errors(context.DeadlineExceeded vs) for get fiber.ErrRequestTimeout response.
func New(h fiber.Handler, t time.Duration, tErrs ...error) fiber.Handler {
return func(ctx *fiber.Ctx) error {
ch := make(chan struct{}, 1)

go func() {
defer func() {
_ = recover()
}()
_ = handler(ctx)
ch <- struct{}{}
}()

select {
case <-ch:
case <-time.After(timeout):
return fiber.ErrRequestTimeout
timeoutContext, cancel := context.WithTimeout(ctx.UserContext(), t)
defer cancel()
ctx.SetUserContext(timeoutContext)
if err := h(ctx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return fiber.ErrRequestTimeout
}
for i := range tErrs {
if errors.Is(err, tErrs[i]) {
return fiber.ErrRequestTimeout
}
}
return err
}

return nil
}
}
125 changes: 77 additions & 48 deletions middleware/timeout/timeout_test.go
@@ -1,55 +1,84 @@
package timeout

// // go test -run Test_Middleware_Timeout
// func Test_Middleware_Timeout(t *testing.T) {
// app := fiber.New(fiber.Config{DisableStartupMessage: true})
import (
"context"
"errors"
"fmt"
"net/http/httptest"
"testing"
"time"

// h := New(func(c *fiber.Ctx) error {
// sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
// time.Sleep(sleepTime)
// return c.SendString("After " + c.Params("sleepTime") + "ms sleeping")
// }, 5*time.Millisecond)
// app.Get("/test/:sleepTime", h)
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)

// testTimeout := func(timeoutStr string) {
// resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
// utils.AssertEqual(t, nil, err, "app.Test(req)")
// utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
// go test -run Test_Timeout
func Test_Timeout(t *testing.T) {
// fiber instance
app := fiber.New()
h := New(func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil {
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
}
return nil
}, 100*time.Millisecond)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
testTimeout("300")
testTimeout("500")
testSucces("50")
testSucces("30")
}

// body, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, "Request Timeout", string(body))
// }
// testSucces := func(timeoutStr string) {
// resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
// utils.AssertEqual(t, nil, err, "app.Test(req)")
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
var ErrFooTimeOut = errors.New("foo context canceled")

// body, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, "After "+timeoutStr+"ms sleeping", string(body))
// }
// go test -run Test_TimeoutWithCustomError
func Test_TimeoutWithCustomError(t *testing.T) {
// fiber instance
app := fiber.New()
h := New(func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil {
return fmt.Errorf("%w: execution error", err)
}
return nil
}, 100*time.Millisecond, ErrFooTimeOut)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
testTimeout("300")
testTimeout("500")
testSucces("50")
testSucces("30")
}

// testTimeout("15")
// testSucces("2")
// testTimeout("30")
// testSucces("3")
// }

// // go test -run -v Test_Timeout_Panic
// func Test_Timeout_Panic(t *testing.T) {
// app := fiber.New(fiber.Config{DisableStartupMessage: true})

// app.Get("/panic", recover.New(), New(func(c *fiber.Ctx) error {
// c.Set("dummy", "this should not be here")
// panic("panic in timeout handler")
// }, 5*time.Millisecond))

// resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil))
// utils.AssertEqual(t, nil, err, "app.Test(req)")
// utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")

// body, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, "Request Timeout", string(body))
// }
func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
timer := time.NewTimer(d)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return te
case <-timer.C:
}
return nil
}