diff --git a/ctx.go b/ctx.go index acd18494d5..68536c9a24 100644 --- a/ctx.go +++ b/ctx.go @@ -1289,3 +1289,23 @@ func (c *Ctx) IsProxyTrusted() bool { _, trustProxy := c.app.config.trustedProxiesMap[c.fasthttp.RemoteIP().String()] return trustProxy } + +// IsLocalHost will return true if address is a localhost address. +func (c *Ctx) isLocalHost(address string) bool { + localHosts := []string{"127.0.0.1", "0.0.0.0", "::1"} + for _, h := range localHosts { + if strings.Contains(address, h) { + return true + } + } + return false +} + +// IsFromLocal will return true if request came from local. +func (c *Ctx) IsFromLocal() bool { + ips := c.IPs() + if len(ips) == 0 { + ips = append(ips, c.IP()) + } + return c.isLocalHost(ips[0]) +} diff --git a/ctx_test.go b/ctx_test.go index e1bec56646..28c0894d6b 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -2751,3 +2751,58 @@ func Test_Ctx_GetRespHeader(t *testing.T) { utils.AssertEqual(t, c.GetRespHeader("test"), "Hello, World 👋!") utils.AssertEqual(t, c.GetRespHeader(HeaderContentType), "application/json") } + +// go test -run Test_Ctx_IsFromLocal +func Test_Ctx_IsFromLocal(t *testing.T) { + t.Parallel() + // Test "0.0.0.0", "127.0.0.1" and "::1". + { + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + utils.AssertEqual(t, true, c.IsFromLocal()) + } + // This is a test for "0.0.0.0" + { + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Header.Set(HeaderXForwardedFor, "0.0.0.0") + defer app.ReleaseCtx(c) + utils.AssertEqual(t, true, c.IsFromLocal()) + } + + // This is a test for "127.0.0.1" + { + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1") + defer app.ReleaseCtx(c) + utils.AssertEqual(t, true, c.IsFromLocal()) + } + + // This is a test for "localhost" + { + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + utils.AssertEqual(t, true, c.IsFromLocal()) + } + + // This is testing "::1", it is the compressed format IPV6 loopback address 0:0:0:0:0:0:0:1. + // It is the equivalent of the IPV4 address 127.0.0.1. + { + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Header.Set(HeaderXForwardedFor, "::1") + defer app.ReleaseCtx(c) + utils.AssertEqual(t, true, c.IsFromLocal()) + } + + { + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Header.Set(HeaderXForwardedFor, "93.46.8.90") + defer app.ReleaseCtx(c) + utils.AssertEqual(t, false, c.IsFromLocal()) + } +}