diff --git a/middleware/timeout/README.md b/middleware/timeout/README.md index 4406d1ffac..4cc68d0710 100644 --- a/middleware/timeout/README.md +++ b/middleware/timeout/README.md @@ -1,5 +1,9 @@ # Timeout -Timeout middleware for [Fiber](https://github.com/gofiber/fiber) wraps a `fiber.Handler` with a timeout. If the handler takes longer than the given duration to return, the timeout error is set and forwarded to the centralized [ErrorHandler](https://docs.gofiber.io/error-handling). +Timeout middleware for Fiber. As a `fiber.Handler` wrapper, it creates a context with `context.WithTimeout` and pass it in `UserContext`. + +If the context passed executions (eg. DB ops, Http calls) takes longer than the given duration to return, the timeout error is set and forwarded to the centralized `ErrorHandler`. + +It has no race conditions, ready to use on production. ### Table of Contents - [Signatures](#signatures) @@ -8,7 +12,7 @@ Timeout middleware for [Fiber](https://github.com/gofiber/fiber) wraps a `fiber. ### Signatures ```go -func New(h fiber.Handler, t time.Duration) fiber.Handler +func New(handler fiber.Handler, timeout time.Duration, timeoutErrors ...error) fiber.Handler ``` ### Examples @@ -20,15 +24,76 @@ import ( ) ``` -After you initiate your Fiber app, you can use the following possibilities: +Sample timeout middleware usage ```go -handler := func(ctx *fiber.Ctx) error { - err := ctx.SendString("Hello, World 👋!") - if err != nil { - return err +func main() { + app := fiber.New() + h := func(c *fiber.Ctx) error { + sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms") + if err := sleepWithContext(c.UserContext(), sleepTime); err != nil { + return fmt.Errorf("%w: execution error", err) + } + return nil + } + + app.Get("/foo/:sleepTime", timeout.New(h, 2*time.Second)) + _ = app.Listen(":3000") +} + +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return context.DeadlineExceeded + case <-timer.C: } return nil } +``` + +Test http 200 with curl: +```bash +curl --location -I --request GET 'http://localhost:3000/foo/1000' +``` + +Test http 408 with curl: +```bash +curl --location -I --request GET 'http://localhost:3000/foo/3000' +``` + + +When using with custom error: +```go +var ErrFooTimeOut = errors.New("foo context canceled") + +func main() { + app := fiber.New() + h := func(c *fiber.Ctx) error { + sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms") + if err := sleepWithContextWithCustomError(c.UserContext(), sleepTime); err != nil { + return fmt.Errorf("%w: execution error", err) + } + return nil + } + + app.Get("/foo/:sleepTime", timeout.New(h, 2*time.Second), ErrFooTimeOut) + _ = app.Listen(":3000") +} -app.Get("/foo", timeout.New(handler, 5 * time.Second)) +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return ErrFooTimeOut + case <-timer.C: + } + return nil +} ``` diff --git a/middleware/timeout/timeout.go b/middleware/timeout/timeout.go index 6cb980aa01..9f1fd21997 100644 --- a/middleware/timeout/timeout.go +++ b/middleware/timeout/timeout.go @@ -1,43 +1,30 @@ package timeout import ( - "fmt" - "sync" + "context" + "errors" "time" "github.com/gofiber/fiber/v2" ) -var once sync.Once - -// New wraps a handler and aborts the process of the handler if the timeout is reached -func New(handler fiber.Handler, timeout time.Duration) fiber.Handler { - once.Do(func() { - fmt.Println("[Warning] timeout contains data race issues, not ready for production!") - }) - - if timeout <= 0 { - return handler - } - - // logic is from fasthttp.TimeoutWithCodeHandler https://github.com/valyala/fasthttp/blob/master/server.go#L418 +// New implementation of timeout middleware. Set custom errors(context.DeadlineExceeded vs) for get fiber.ErrRequestTimeout response. +func New(h fiber.Handler, t time.Duration, tErrs ...error) fiber.Handler { return func(ctx *fiber.Ctx) error { - ch := make(chan struct{}, 1) - - go func() { - defer func() { - _ = recover() - }() - _ = handler(ctx) - ch <- struct{}{} - }() - - select { - case <-ch: - case <-time.After(timeout): - return fiber.ErrRequestTimeout + timeoutContext, cancel := context.WithTimeout(ctx.UserContext(), t) + defer cancel() + ctx.SetUserContext(timeoutContext) + if err := h(ctx); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return fiber.ErrRequestTimeout + } + for i := range tErrs { + if errors.Is(err, tErrs[i]) { + return fiber.ErrRequestTimeout + } + } + return err } - return nil } } diff --git a/middleware/timeout/timeout_test.go b/middleware/timeout/timeout_test.go index dc60e2eba5..225eabc152 100644 --- a/middleware/timeout/timeout_test.go +++ b/middleware/timeout/timeout_test.go @@ -1,55 +1,84 @@ package timeout -// // go test -run Test_Middleware_Timeout -// func Test_Middleware_Timeout(t *testing.T) { -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) +import ( + "context" + "errors" + "fmt" + "net/http/httptest" + "testing" + "time" -// h := New(func(c *fiber.Ctx) error { -// sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms") -// time.Sleep(sleepTime) -// return c.SendString("After " + c.Params("sleepTime") + "ms sleeping") -// }, 5*time.Millisecond) -// app.Get("/test/:sleepTime", h) + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/utils" +) -// testTimeout := func(timeoutStr string) { -// resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) -// utils.AssertEqual(t, nil, err, "app.Test(req)") -// utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") +// go test -run Test_Timeout +func Test_Timeout(t *testing.T) { + // fiber instance + app := fiber.New() + h := New(func(c *fiber.Ctx) error { + sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms") + if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil { + return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err)) + } + return nil + }, 100*time.Millisecond) + app.Get("/test/:sleepTime", h) + testTimeout := func(timeoutStr string) { + resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") + } + testSucces := func(timeoutStr string) { + resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") + } + testTimeout("300") + testTimeout("500") + testSucces("50") + testSucces("30") +} -// body, err := ioutil.ReadAll(resp.Body) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, "Request Timeout", string(body)) -// } -// testSucces := func(timeoutStr string) { -// resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) -// utils.AssertEqual(t, nil, err, "app.Test(req)") -// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") +var ErrFooTimeOut = errors.New("foo context canceled") -// body, err := ioutil.ReadAll(resp.Body) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, "After "+timeoutStr+"ms sleeping", string(body)) -// } +// go test -run Test_TimeoutWithCustomError +func Test_TimeoutWithCustomError(t *testing.T) { + // fiber instance + app := fiber.New() + h := New(func(c *fiber.Ctx) error { + sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms") + if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil { + return fmt.Errorf("%w: execution error", err) + } + return nil + }, 100*time.Millisecond, ErrFooTimeOut) + app.Get("/test/:sleepTime", h) + testTimeout := func(timeoutStr string) { + resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") + } + testSucces := func(timeoutStr string) { + resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") + } + testTimeout("300") + testTimeout("500") + testSucces("50") + testSucces("30") +} -// testTimeout("15") -// testSucces("2") -// testTimeout("30") -// testSucces("3") -// } - -// // go test -run -v Test_Timeout_Panic -// func Test_Timeout_Panic(t *testing.T) { -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/panic", recover.New(), New(func(c *fiber.Ctx) error { -// c.Set("dummy", "this should not be here") -// panic("panic in timeout handler") -// }, 5*time.Millisecond)) - -// resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil)) -// utils.AssertEqual(t, nil, err, "app.Test(req)") -// utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") - -// body, err := ioutil.ReadAll(resp.Body) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, "Request Timeout", string(body)) -// } +func sleepWithContext(ctx context.Context, d time.Duration, te error) error { + timer := time.NewTimer(d) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return te + case <-timer.C: + } + return nil +}