diff --git a/app.go b/app.go index fcc84effa7..ee695a1223 100644 --- a/app.go +++ b/app.go @@ -569,12 +569,13 @@ func (app *App) Mount(prefix string, fiber *App) Router { // Assign name to specific route. func (app *App) Name(name string) Router { + latestRoute.mu.Lock() if strings.HasPrefix(latestRoute.route.path, latestGroup.prefix) { latestRoute.route.Name = latestGroup.name + name } else { latestRoute.route.Name = name } - + latestRoute.mu.Unlock() return app } diff --git a/client_test.go b/client_test.go index 81480b29ec..f9ab5fdae1 100644 --- a/client_test.go +++ b/client_test.go @@ -953,7 +953,7 @@ func Test_Client_Agent_Timeout(t *testing.T) { go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() a := Get("http://example.com"). - Timeout(time.Millisecond * 100) + Timeout(time.Millisecond * 50) a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } diff --git a/ctx.go b/ctx.go index 96380b3b95..efb9b312f1 100644 --- a/ctx.go +++ b/ctx.go @@ -1060,6 +1060,51 @@ func (c *Ctx) Redirect(location string, status ...int) error { return nil } +// get URL location from route using parameters +func (c *Ctx) getLocationFromRoute(route Route, params Map) (string, error) { + buf := bytebufferpool.Get() + for _, segment := range route.routeParser.segs { + if segment.IsParam { + for key, val := range params { + if key == segment.ParamName || segment.IsGreedy { + _, err := buf.WriteString(utils.ToString(val)) + if err != nil { + return "", err + } + } + } + } else { + _, err := buf.WriteString(segment.Const) + if err != nil { + return "", err + } + } + } + location := buf.String() + bytebufferpool.Put(buf) + return location, nil +} + +// RedirectToRoute to the Route registered in the app with appropriate parameters +// If status is not specified, status defaults to 302 Found. +func (c *Ctx) RedirectToRoute(routeName string, params Map, status ...int) error { + location, err := c.getLocationFromRoute(c.App().GetRoute(routeName), params) + if err != nil { + return err + } + return c.Redirect(location, status...) +} + +// RedirectBack to the URL to referer +// If status is not specified, status defaults to 302 Found. +func (c *Ctx) RedirectBack(fallback string, status ...int) error { + location := c.Get(HeaderReferer) + if location == "" { + location = fallback + } + return c.Redirect(location, status...) +} + // Render a template with data and sends a text/html response. // We support the following engines: html, amber, handlebars, mustache, pug func (c *Ctx) Render(name string, bind interface{}, layouts ...string) error { diff --git a/ctx_test.go b/ctx_test.go index d016d61ba8..333bdb6b64 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -2035,6 +2035,105 @@ func Test_Ctx_Redirect(t *testing.T) { utils.AssertEqual(t, "http://example.com", string(c.Response().Header.Peek(HeaderLocation))) } +// go test -run Test_Ctx_RedirectToRouteWithParams +func Test_Ctx_RedirectToRouteWithParams(t *testing.T) { + t.Parallel() + app := New() + app.Get("/user/:name", func(c *Ctx) error { + return c.JSON(c.Params("name")) + }).Name("user") + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + + c.RedirectToRoute("user", Map{ + "name": "fiber", + }) + utils.AssertEqual(t, 302, c.Response().StatusCode()) + utils.AssertEqual(t, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation))) +} + +// go test -run Test_Ctx_RedirectToRouteWithOptionalParams +func Test_Ctx_RedirectToRouteWithOptionalParams(t *testing.T) { + t.Parallel() + app := New() + app.Get("/user/:name?", func(c *Ctx) error { + return c.JSON(c.Params("name")) + }).Name("user") + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + + c.RedirectToRoute("user", Map{ + "name": "fiber", + }) + utils.AssertEqual(t, 302, c.Response().StatusCode()) + utils.AssertEqual(t, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation))) +} + +// go test -run Test_Ctx_RedirectToRouteWithOptionalParamsWithoutValue +func Test_Ctx_RedirectToRouteWithOptionalParamsWithoutValue(t *testing.T) { + t.Parallel() + app := New() + app.Get("/user/:name?", func(c *Ctx) error { + return c.JSON(c.Params("name")) + }).Name("user") + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + + c.RedirectToRoute("user", Map{}) + utils.AssertEqual(t, 302, c.Response().StatusCode()) + utils.AssertEqual(t, "/user/", string(c.Response().Header.Peek(HeaderLocation))) +} + +// go test -run Test_Ctx_RedirectToRouteWithGreedyParameters +func Test_Ctx_RedirectToRouteWithGreedyParameters(t *testing.T) { + t.Parallel() + app := New() + app.Get("/user/+", func(c *Ctx) error { + return c.JSON(c.Params("+")) + }).Name("user") + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + + c.RedirectToRoute("user", Map{ + "+": "test/routes", + }) + utils.AssertEqual(t, 302, c.Response().StatusCode()) + utils.AssertEqual(t, "/user/test/routes", string(c.Response().Header.Peek(HeaderLocation))) +} + +// go test -run Test_Ctx_RedirectBack +func Test_Ctx_RedirectBack(t *testing.T) { + t.Parallel() + app := New() + app.Get("/", func(c *Ctx) error { + return c.JSON("Home") + }).Name("home") + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.RedirectBack("/") + utils.AssertEqual(t, 302, c.Response().StatusCode()) + utils.AssertEqual(t, "/", string(c.Response().Header.Peek(HeaderLocation))) +} + +// go test -run Test_Ctx_RedirectBackWithReferer +func Test_Ctx_RedirectBackWithReferer(t *testing.T) { + t.Parallel() + app := New() + app.Get("/", func(c *Ctx) error { + return c.JSON("Home") + }).Name("home") + app.Get("/back", func(c *Ctx) error { + return c.JSON("Back") + }).Name("back") + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.Request().Header.Set(HeaderReferer, "/back") + c.RedirectBack("/") + utils.AssertEqual(t, 302, c.Response().StatusCode()) + utils.AssertEqual(t, "/back", c.Get(HeaderReferer)) + utils.AssertEqual(t, "/back", string(c.Response().Header.Peek(HeaderLocation))) +} + // go test -run Test_Ctx_Render func Test_Ctx_Render(t *testing.T) { t.Parallel() @@ -2300,6 +2399,19 @@ func Benchmark_Ctx_Render_Engine(b *testing.B) { utils.AssertEqual(b, "

Hello, World!

", string(c.Response().Body())) } +// go test -v -run=^$ -bench=Benchmark_Ctx_Get_Location_From_Route -benchmem -count=4 +func Benchmark_Ctx_Get_Location_From_Route(b *testing.B) { + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + app.Get("/user/:name", func(c *Ctx) error { + return c.SendString(c.Params("name")) + }).Name("User") + for n := 0; n < b.N; n++ { + c.getLocationFromRoute(app.GetRoute("User"), Map{"name": "fiber"}) + } +} + type errorTemplateEngine struct{} func (t errorTemplateEngine) Render(w io.Writer, name string, bind interface{}, layout ...string) error { diff --git a/utils/convert.go b/utils/convert.go index ae99723438..32b04b7cda 100644 --- a/utils/convert.go +++ b/utils/convert.go @@ -5,9 +5,11 @@ package utils import ( + "fmt" "reflect" "strconv" "strings" + "time" "unsafe" ) @@ -83,3 +85,51 @@ func ByteSize(bytes uint64) string { result = strings.TrimSuffix(result, ".0") return result + unit } + +// ToString Change arg to string +func ToString(arg interface{}, timeFormat ...string) string { + var tmp = reflect.Indirect(reflect.ValueOf(arg)).Interface() + switch v := tmp.(type) { + case int: + return strconv.Itoa(v) + case int8: + return strconv.FormatInt(int64(v), 10) + case int16: + return strconv.FormatInt(int64(v), 10) + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case uint: + return strconv.Itoa(int(v)) + case uint8: + return strconv.FormatInt(int64(v), 10) + case uint16: + return strconv.FormatInt(int64(v), 10) + case uint32: + return strconv.FormatInt(int64(v), 10) + case uint64: + return strconv.FormatInt(int64(v), 10) + case string: + return v + case []byte: + return string(v) + case bool: + return strconv.FormatBool(v) + case float32: + return strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + return strconv.FormatFloat(v, 'f', -1, 64) + case time.Time: + if len(timeFormat) > 0 { + return v.Format(timeFormat[0]) + } + return v.Format("2006-01-02 15:04:05") + case reflect.Value: + return ToString(v.Interface(), timeFormat...) + case fmt.Stringer: + return v.String() + default: + return "" + } +} diff --git a/utils/convert_test.go b/utils/convert_test.go index d788dead42..59ce625e53 100644 --- a/utils/convert_test.go +++ b/utils/convert_test.go @@ -61,3 +61,21 @@ func Test_CopyString(t *testing.T) { res := CopyString("Hello, World!") AssertEqual(t, "Hello, World!", res) } + +func Test_ToString(t *testing.T) { + t.Parallel() + res := ToString([]byte("Hello, World!")) + AssertEqual(t, "Hello, World!", res) + res = ToString(true) + AssertEqual(t, "true", res) + res = ToString(uint(100)) + AssertEqual(t, "100", res) +} + +// go test -v -run=^$ -bench=ToString -benchmem -count=2 +func Benchmark_ToString(b *testing.B) { + hello := []byte("Hello, World!") + for n := 0; n < b.N; n++ { + ToString(hello) + } +}