diff --git a/app.go b/app.go index ebb8418c55..f97564733f 100644 --- a/app.go +++ b/app.go @@ -350,8 +350,9 @@ type Config struct { // Read EnableTrustedProxyCheck doc. // // Default: []string - TrustedProxies []string `json:"trusted_proxies"` - trustedProxiesMap map[string]struct{} + TrustedProxies []string `json:"trusted_proxies"` + trustedProxiesMap map[string]struct{} + trustedProxyRanges []*net.IPNet //If set to true, will print all routes with their method, path and handler. // Default: false @@ -501,8 +502,8 @@ func New(config ...Config) *App { } app.config.trustedProxiesMap = make(map[string]struct{}, len(app.config.TrustedProxies)) - for _, ip := range app.config.TrustedProxies { - app.handleTrustedProxy(ip) + for _, ipAddress := range app.config.TrustedProxies { + app.handleTrustedProxy(ipAddress) } // Init app @@ -512,23 +513,19 @@ func New(config ...Config) *App { return app } -// Checks if the given IP address is a range whether or not, adds it to the trustedProxiesMap +// Adds an ip address to trustedProxyRanges or trustedProxiesMap based on whether it is an IP range or not func (app *App) handleTrustedProxy(ipAddress string) { - // Detects IP address is range whether or not if strings.Contains(ipAddress, "/") { - // Parsing IP address - ip, ipnet, err := net.ParseCIDR(ipAddress) + _, ipNet, err := net.ParseCIDR(ipAddress) + if err != nil { fmt.Printf("[Warning] IP range `%s` could not be parsed. \n", ipAddress) - return - } - // Iterates IP address which is between range - for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); utils.IncrementIPRange(ip) { - app.config.trustedProxiesMap[ip.String()] = struct{}{} } - return + + app.config.trustedProxyRanges = append(app.config.trustedProxyRanges, ipNet) + } else { + app.config.trustedProxiesMap[ipAddress] = struct{}{} } - app.config.trustedProxiesMap[ipAddress] = struct{}{} } // Mount attaches another app instance as a sub-router along a routing path. diff --git a/ctx.go b/ctx.go index df2b754f86..9f06c3df69 100644 --- a/ctx.go +++ b/ctx.go @@ -1310,8 +1310,18 @@ func (c *Ctx) IsProxyTrusted() bool { return true } - _, trustProxy := c.app.config.trustedProxiesMap[c.fasthttp.RemoteIP().String()] - return trustProxy + _, trusted := c.app.config.trustedProxiesMap[c.fasthttp.RemoteIP().String()] + if trusted { + return trusted + } + + for _, ipNet := range c.app.config.trustedProxyRanges { + if ipNet.Contains(c.fasthttp.RemoteIP()) { + return true + } + } + + return false } // IsLocalHost will return true if address is a localhost address. diff --git a/ctx_test.go b/ctx_test.go index 7591372d33..0406ada7d4 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -985,6 +985,30 @@ func Test_Ctx_Hostname_TrustedProxy(t *testing.T) { } } +// go test -run Test_Ctx_Hostname_UntrustedProxyRange +func Test_Ctx_Hostname_TrustedProxyRange(t *testing.T) { + t.Parallel() + + app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.0.0.0/30"}}) + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().SetRequestURI("http://google.com/test") + c.Request().Header.Set(HeaderXForwardedHost, "google1.com") + utils.AssertEqual(t, "google1.com", c.Hostname()) + app.ReleaseCtx(c) +} + +// go test -run Test_Ctx_Hostname_UntrustedProxyRange +func Test_Ctx_Hostname_UntrustedProxyRange(t *testing.T) { + t.Parallel() + + app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"1.0.0.0/30"}}) + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().SetRequestURI("http://google.com/test") + c.Request().Header.Set(HeaderXForwardedHost, "google1.com") + utils.AssertEqual(t, "google.com", c.Hostname()) + app.ReleaseCtx(c) +} + // go test -run Test_Ctx_Port func Test_Ctx_Port(t *testing.T) { t.Parallel() @@ -1032,22 +1056,6 @@ func Test_Ctx_IP_TrustedProxy(t *testing.T) { utils.AssertEqual(t, "0.0.0.1", c.IP()) } -// go test -run Test_Ctx_IP_Range_TrustedProxy -func Test_Ctx_IP_Range_TrustedProxy(t *testing.T) { - t.Parallel() - app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.0.0.0", "1.1.1.1/30", "1.1.1.1/100"}, ProxyHeader: HeaderXForwardedFor}) - c := app.AcquireCtx(&fasthttp.RequestCtx{}) - defer app.ReleaseCtx(c) - expected := map[string]struct{}{ - "0.0.0.0": {}, - "1.1.1.0": {}, - "1.1.1.1": {}, - "1.1.1.2": {}, - "1.1.1.3": {}, - } - utils.AssertEqual(t, expected, app.config.trustedProxiesMap) -} - // go test -run Test_Ctx_IPs -parallel func Test_Ctx_IPs(t *testing.T) { t.Parallel() @@ -1399,6 +1407,58 @@ func Test_Ctx_Protocol_TrustedProxy(t *testing.T) { utils.AssertEqual(t, "http", c.Protocol()) } +// go test -run Test_Ctx_Protocol_TrustedProxyRange +func Test_Ctx_Protocol_TrustedProxyRange(t *testing.T) { + t.Parallel() + app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.0.0.0/30"}}) + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + + c.Request().Header.Set(HeaderXForwardedProto, "https") + utils.AssertEqual(t, "https", c.Protocol()) + c.Request().Header.Reset() + + c.Request().Header.Set(HeaderXForwardedProtocol, "https") + utils.AssertEqual(t, "https", c.Protocol()) + c.Request().Header.Reset() + + c.Request().Header.Set(HeaderXForwardedSsl, "on") + utils.AssertEqual(t, "https", c.Protocol()) + c.Request().Header.Reset() + + c.Request().Header.Set(HeaderXUrlScheme, "https") + utils.AssertEqual(t, "https", c.Protocol()) + c.Request().Header.Reset() + + utils.AssertEqual(t, "http", c.Protocol()) +} + +// go test -run Test_Ctx_Protocol_UntrustedProxyRange +func Test_Ctx_Protocol_UntrustedProxyRange(t *testing.T) { + t.Parallel() + app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"1.1.1.1/30"}}) + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + + c.Request().Header.Set(HeaderXForwardedProto, "https") + utils.AssertEqual(t, "http", c.Protocol()) + c.Request().Header.Reset() + + c.Request().Header.Set(HeaderXForwardedProtocol, "https") + utils.AssertEqual(t, "http", c.Protocol()) + c.Request().Header.Reset() + + c.Request().Header.Set(HeaderXForwardedSsl, "on") + utils.AssertEqual(t, "http", c.Protocol()) + c.Request().Header.Reset() + + c.Request().Header.Set(HeaderXUrlScheme, "https") + utils.AssertEqual(t, "http", c.Protocol()) + c.Request().Header.Reset() + + utils.AssertEqual(t, "http", c.Protocol()) +} + // go test -run Test_Ctx_Protocol_UnTrustedProxy func Test_Ctx_Protocol_UnTrustedProxy(t *testing.T) { t.Parallel()