Skip to content

Commit

Permalink
Fix using IP ranges in config.TrustedProxies (#1607) (#1614)
Browse files Browse the repository at this point in the history
* Fix using IP ranges in config.TrustedProxies (#1607)

* Add tests

* Remove debugging var

* Remove tests

* Update test

Co-authored-by: RW <rene@gofiber.io>
  • Loading branch information
hi019 and ReneWerner87 committed Dec 31, 2021
1 parent 59240b5 commit eee279b
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 33 deletions.
27 changes: 12 additions & 15 deletions app.go
Expand Up @@ -350,8 +350,9 @@ type Config struct {
// Read EnableTrustedProxyCheck doc.
//
// Default: []string
TrustedProxies []string `json:"trusted_proxies"`
trustedProxiesMap map[string]struct{}
TrustedProxies []string `json:"trusted_proxies"`
trustedProxiesMap map[string]struct{}
trustedProxyRanges []*net.IPNet

//If set to true, will print all routes with their method, path and handler.
// Default: false
Expand Down Expand Up @@ -501,8 +502,8 @@ func New(config ...Config) *App {
}

app.config.trustedProxiesMap = make(map[string]struct{}, len(app.config.TrustedProxies))
for _, ip := range app.config.TrustedProxies {
app.handleTrustedProxy(ip)
for _, ipAddress := range app.config.TrustedProxies {
app.handleTrustedProxy(ipAddress)
}

// Init app
Expand All @@ -512,23 +513,19 @@ func New(config ...Config) *App {
return app
}

// Checks if the given IP address is a range whether or not, adds it to the trustedProxiesMap
// Adds an ip address to trustedProxyRanges or trustedProxiesMap based on whether it is an IP range or not
func (app *App) handleTrustedProxy(ipAddress string) {
// Detects IP address is range whether or not
if strings.Contains(ipAddress, "/") {
// Parsing IP address
ip, ipnet, err := net.ParseCIDR(ipAddress)
_, ipNet, err := net.ParseCIDR(ipAddress)

if err != nil {
fmt.Printf("[Warning] IP range `%s` could not be parsed. \n", ipAddress)
return
}
// Iterates IP address which is between range
for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); utils.IncrementIPRange(ip) {
app.config.trustedProxiesMap[ip.String()] = struct{}{}
}
return

app.config.trustedProxyRanges = append(app.config.trustedProxyRanges, ipNet)
} else {
app.config.trustedProxiesMap[ipAddress] = struct{}{}
}
app.config.trustedProxiesMap[ipAddress] = struct{}{}
}

// Mount attaches another app instance as a sub-router along a routing path.
Expand Down
14 changes: 12 additions & 2 deletions ctx.go
Expand Up @@ -1310,8 +1310,18 @@ func (c *Ctx) IsProxyTrusted() bool {
return true
}

_, trustProxy := c.app.config.trustedProxiesMap[c.fasthttp.RemoteIP().String()]
return trustProxy
_, trusted := c.app.config.trustedProxiesMap[c.fasthttp.RemoteIP().String()]
if trusted {
return trusted
}

for _, ipNet := range c.app.config.trustedProxyRanges {
if ipNet.Contains(c.fasthttp.RemoteIP()) {
return true
}
}

return false
}

// IsLocalHost will return true if address is a localhost address.
Expand Down
92 changes: 76 additions & 16 deletions ctx_test.go
Expand Up @@ -985,6 +985,30 @@ func Test_Ctx_Hostname_TrustedProxy(t *testing.T) {
}
}

// go test -run Test_Ctx_Hostname_UntrustedProxyRange
func Test_Ctx_Hostname_TrustedProxyRange(t *testing.T) {
t.Parallel()

app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.0.0.0/30"}})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
utils.AssertEqual(t, "google1.com", c.Hostname())
app.ReleaseCtx(c)
}

// go test -run Test_Ctx_Hostname_UntrustedProxyRange
func Test_Ctx_Hostname_UntrustedProxyRange(t *testing.T) {
t.Parallel()

app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"1.0.0.0/30"}})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
utils.AssertEqual(t, "google.com", c.Hostname())
app.ReleaseCtx(c)
}

// go test -run Test_Ctx_Port
func Test_Ctx_Port(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -1032,22 +1056,6 @@ func Test_Ctx_IP_TrustedProxy(t *testing.T) {
utils.AssertEqual(t, "0.0.0.1", c.IP())
}

// go test -run Test_Ctx_IP_Range_TrustedProxy
func Test_Ctx_IP_Range_TrustedProxy(t *testing.T) {
t.Parallel()
app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.0.0.0", "1.1.1.1/30", "1.1.1.1/100"}, ProxyHeader: HeaderXForwardedFor})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
expected := map[string]struct{}{
"0.0.0.0": {},
"1.1.1.0": {},
"1.1.1.1": {},
"1.1.1.2": {},
"1.1.1.3": {},
}
utils.AssertEqual(t, expected, app.config.trustedProxiesMap)
}

// go test -run Test_Ctx_IPs -parallel
func Test_Ctx_IPs(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -1399,6 +1407,58 @@ func Test_Ctx_Protocol_TrustedProxy(t *testing.T) {
utils.AssertEqual(t, "http", c.Protocol())
}

// go test -run Test_Ctx_Protocol_TrustedProxyRange
func Test_Ctx_Protocol_TrustedProxyRange(t *testing.T) {
t.Parallel()
app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.0.0.0/30"}})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)

c.Request().Header.Set(HeaderXForwardedProto, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Reset()

c.Request().Header.Set(HeaderXForwardedProtocol, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Reset()

c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Reset()

c.Request().Header.Set(HeaderXUrlScheme, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Reset()

utils.AssertEqual(t, "http", c.Protocol())
}

// go test -run Test_Ctx_Protocol_UntrustedProxyRange
func Test_Ctx_Protocol_UntrustedProxyRange(t *testing.T) {
t.Parallel()
app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"1.1.1.1/30"}})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)

c.Request().Header.Set(HeaderXForwardedProto, "https")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Reset()

c.Request().Header.Set(HeaderXForwardedProtocol, "https")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Reset()

c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Reset()

c.Request().Header.Set(HeaderXUrlScheme, "https")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Reset()

utils.AssertEqual(t, "http", c.Protocol())
}

// go test -run Test_Ctx_Protocol_UnTrustedProxy
func Test_Ctx_Protocol_UnTrustedProxy(t *testing.T) {
t.Parallel()
Expand Down

0 comments on commit eee279b

Please sign in to comment.