Skip to content

Commit

Permalink
fixes #2016 - make IP() and IPs() more reliable
Browse files Browse the repository at this point in the history
  • Loading branch information
gbolo committed Aug 17, 2022
1 parent 648e662 commit 539d1e7
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 21 deletions.
50 changes: 33 additions & 17 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,33 +649,49 @@ func (c *Ctx) Port() string {
}

// IP returns the remote IP address of the request.
// If ProxyHeader is configured, it will parse that header and return the first valid IP address
// Please use Config.EnableTrustedProxyCheck to prevent header spoofing, in case when your app is behind the proxy.
func (c *Ctx) IP() string {
if c.IsProxyTrusted() && len(c.app.config.ProxyHeader) > 0 {
return c.Get(c.app.config.ProxyHeader)
return c.extractIPFromHeader(c.app.config.ProxyHeader)
}

return c.fasthttp.RemoteIP().String()
}

// IPs returns an string slice of IP addresses specified in the X-Forwarded-For request header.
func (c *Ctx) IPs() (ips []string) {
header := c.fasthttp.Request.Header.Peek(HeaderXForwardedFor)
if len(header) == 0 {
return
}
ips = make([]string, bytes.Count(header, []byte(","))+1)
var commaPos, i int
for {
commaPos = bytes.IndexByte(header, ',')
if commaPos != -1 {
ips[i] = utils.Trim(c.app.getString(header[:commaPos]), ' ')
header, i = header[commaPos+1:], i+1
} else {
ips[i] = utils.Trim(c.app.getString(header), ' ')
return
// extractValidIPs will return a slice of strings that represent valid IP addresses
// in the input string. The order is maintained. The separator is a comma
func extractValidIPs(input string) (validIPs []string) {
unvalidatedIps := strings.Split(input, ",")
for _, ip := range unvalidatedIps {
if parsedIp := net.ParseIP(strings.TrimSpace(ip)); parsedIp != nil {
validIPs = append(validIPs, parsedIp.String())
}
}
return
}

// extractIPFromHeader will attempt to pull the real client IP from the given header
// currently it will return the first valid IP address in header
func (c *Ctx) extractIPFromHeader(header string) string {
// extract only valid IPs from the header's value
validIps := extractValidIPs(c.Get(header))

// since X-Forwarded-For has no RFC, it's really up to the proxy to decide whether to append
// or prepend IPs to this list. For example, the AWS ALB will prepend but the F5 BIG-IP will append ;(
// for now lets just go with the first value in the list...
if len(validIps) > 0 {
return validIps[0]
}

// return the IP from the stack if we could not find any valid Ips
return c.fasthttp.RemoteIP().String()
}

// IPs returns a string slice of IP addresses specified in the X-Forwarded-For request header.
// Only valid IP addresses are returned
func (c *Ctx) IPs() (ips []string) {
return extractValidIPs(c.Get(HeaderXForwardedFor))
}

// Is returns the matching content type,
Expand Down
58 changes: 54 additions & 4 deletions ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1099,19 +1099,51 @@ func Test_Ctx_PortInHandler(t *testing.T) {
// go test -run Test_Ctx_IP
func Test_Ctx_IP(t *testing.T) {
t.Parallel()

app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)

// default behaviour will return the remote IP from the stack
utils.AssertEqual(t, "0.0.0.0", c.IP())

// X-Forwarded-For is set, but it is ignored because proxyHeader is not set
c.Request().Header.Set(HeaderXForwardedFor, "0.0.0.1")
utils.AssertEqual(t, "0.0.0.0", c.IP())
}

// go test -run Test_Ctx_IP_ProxyHeader
func Test_Ctx_IP_ProxyHeader(t *testing.T) {
t.Parallel()
app := New(Config{ProxyHeader: "Real-Ip"})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
utils.AssertEqual(t, "", c.IP())

// make sure that the same behaviour exists for different proxy header names
proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}

for _, proxyHeaderName := range proxyHeaderNames {
app := New(Config{ProxyHeader: proxyHeaderName})
c := app.AcquireCtx(&fasthttp.RequestCtx{})

// when proxy header is enabled and the value is a valid IP, we return it
c.Request().Header.Set(proxyHeaderName, "0.0.0.1")
utils.AssertEqual(t, "0.0.0.1", c.IP())

// when proxy header is enabled and the value is a list of IPs, we return the first valid IP
c.Request().Header.Set(proxyHeaderName, "0.0.0.1, 0.0.0.2")
utils.AssertEqual(t, "0.0.0.1", c.IP())

c.Request().Header.Set(proxyHeaderName, "invalid, 0.0.0.2, 0.0.0.3")
utils.AssertEqual(t, "0.0.0.2", c.IP())

// when proxy header is enabled but the value is empty, we will ignore the header
c.Request().Header.Set(proxyHeaderName, "")
utils.AssertEqual(t, "0.0.0.0", c.IP())

// when proxy header is enabled but the value is not an IP, we will ignore the header
c.Request().Header.Set(proxyHeaderName, "not-valid-ip")
utils.AssertEqual(t, "0.0.0.0", c.IP())

app.ReleaseCtx(c)
}
}

// go test -run Test_Ctx_IP_UntrustedProxy
Expand Down Expand Up @@ -1140,14 +1172,32 @@ func Test_Ctx_IPs(t *testing.T) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)

// normal happy path test case
c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1, 127.0.0.2, 127.0.0.3")
utils.AssertEqual(t, []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}, c.IPs())

// inconsistent space formatting
c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1,127.0.0.2 ,127.0.0.3")
utils.AssertEqual(t, []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}, c.IPs())

// invalid IPs are in the header
c.Request().Header.Set(HeaderXForwardedFor, "invalid, 127.0.0.1, 127.0.0.2")
utils.AssertEqual(t, []string{"127.0.0.1", "127.0.0.2"}, c.IPs())
c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1, invalid, 127.0.0.2")
utils.AssertEqual(t, []string{"127.0.0.1", "127.0.0.2"}, c.IPs())

// ensure that the ordering of IPs in the header is maintained
c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.3, 127.0.0.1, 127.0.0.2")
utils.AssertEqual(t, []string{"127.0.0.3", "127.0.0.1", "127.0.0.2"}, c.IPs())

// empty header
c.Request().Header.Set(HeaderXForwardedFor, "")
utils.AssertEqual(t, 0, len(c.IPs()))

// missing header
c.Request()
utils.AssertEqual(t, 0, len(c.IPs()))
}

// go test -v -run=^$ -bench=Benchmark_Ctx_IPs -benchmem -count=4
Expand Down

0 comments on commit 539d1e7

Please sign in to comment.