diff --git a/app.go b/app.go index 3091bc403c..a5f0b77de4 100644 --- a/app.go +++ b/app.go @@ -552,6 +552,10 @@ func (app *App) handleTrustedProxy(ipAddress string) { func (app *App) Mount(prefix string, fiber *App) Router { stack := fiber.Stack() prefix = strings.TrimRight(prefix, "/") + if prefix == "" { + prefix = "/" + } + for m := range stack { for r := range stack[m] { route := app.copyRoute(stack[m][r]) @@ -1075,7 +1079,7 @@ func (app *App) ErrorHandler(ctx *Ctx, err error) error { ) for prefix, subApp := range app.appList { - if strings.HasPrefix(ctx.path, prefix) && prefix != "" { + if prefix != "" && strings.HasPrefix(ctx.path, prefix) { parts := len(strings.Split(prefix, "/")) if mountedPrefixParts <= parts { mountedErrHandler = subApp.config.ErrorHandler diff --git a/app_test.go b/app_test.go index 178fdbea4e..d1a40ea684 100644 --- a/app_test.go +++ b/app_test.go @@ -43,6 +43,15 @@ func testStatus200(t *testing.T, app *App, url string, method string) { utils.AssertEqual(t, 200, resp.StatusCode, "Status code") } +func testErrorResponse(t *testing.T, err error, resp *http.Response, expectedBodyError string) { + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 500, resp.StatusCode, "Status code") + + body, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, expectedBodyError, string(body), "Response body") +} + func Test_App_MethodNotAllowed(t *testing.T) { app := New() @@ -256,12 +265,26 @@ func Test_App_ErrorHandler_GroupMount(t *testing.T) { v1.Mount("/john", micro) resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", nil)) - utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, 500, resp.StatusCode, "Status code") + testErrorResponse(t, err, resp, "1: custom error") +} - body, err := ioutil.ReadAll(resp.Body) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "1: custom error", string(body)) +func Test_App_ErrorHandler_GroupMountRootLevel(t *testing.T) { + micro := New(Config{ + ErrorHandler: func(c *Ctx, err error) error { + utils.AssertEqual(t, "0: GET error", err.Error()) + return c.Status(500).SendString("1: custom error") + }, + }) + micro.Get("/john/doe", func(c *Ctx) error { + return errors.New("0: GET error") + }) + + app := New() + v1 := app.Group("/v1") + v1.Mount("/", micro) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", nil)) + testErrorResponse(t, err, resp, "1: custom error") } func Test_App_Nested_Params(t *testing.T) { @@ -1603,7 +1626,7 @@ func Test_App_UseMountedErrorHandler(t *testing.T) { fiber := New(Config{ ErrorHandler: func(ctx *Ctx, err error) error { - return ctx.Status(200).SendString("hi, i'm a custom error") + return ctx.Status(500).SendString("hi, i'm a custom error") }, }) fiber.Get("/", func(c *Ctx) error { @@ -1613,12 +1636,25 @@ func Test_App_UseMountedErrorHandler(t *testing.T) { app.Mount("/api", fiber) resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", nil)) - utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + testErrorResponse(t, err, resp, "hi, i'm a custom error") +} - b, err := ioutil.ReadAll(resp.Body) - utils.AssertEqual(t, nil, err, "iotuil.ReadAll()") - utils.AssertEqual(t, "hi, i'm a custom error", string(b), "Response body") +func Test_App_UseMountedErrorHandlerRootLevel(t *testing.T) { + app := New() + + fiber := New(Config{ + ErrorHandler: func(ctx *Ctx, err error) error { + return ctx.Status(500).SendString("hi, i'm a custom error") + }, + }) + fiber.Get("/api", func(c *Ctx) error { + return errors.New("something happened") + }) + + app.Mount("/", fiber) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", nil)) + testErrorResponse(t, err, resp, "hi, i'm a custom error") } func Test_App_UseMountedErrorHandlerForBestPrefixMatch(t *testing.T) { diff --git a/group.go b/group.go index 20caf1fa68..65c14e0520 100644 --- a/group.go +++ b/group.go @@ -25,6 +25,10 @@ type Group struct { func (grp *Group) Mount(prefix string, fiber *App) Router { stack := fiber.Stack() groupPath := getGroupPath(grp.Prefix, prefix) + groupPath = strings.TrimRight(groupPath, "/") + if groupPath == "" { + groupPath = "/" + } for m := range stack { for r := range stack[m] { @@ -34,7 +38,6 @@ func (grp *Group) Mount(prefix string, fiber *App) Router { } // Support for configs of mounted-apps and sub-mounted-apps - groupPath = strings.TrimRight(groupPath, "/") for mountedPrefixes, subApp := range fiber.appList { grp.app.appList[groupPath+mountedPrefixes] = subApp subApp.init()