Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

馃殌 Make IP validation 2x faster #2158

Merged
merged 4 commits into from Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
128 changes: 82 additions & 46 deletions ctx.go
Expand Up @@ -684,72 +684,108 @@ func (c *Ctx) IP() string {
return c.fasthttp.RemoteIP().String()
}

// validateIPIfEnabled will return the input IP when validation is disabled.
// when validation is enabled, it will return an empty string if the input is not a valid IP.
func (c *Ctx) validateIPIfEnabled(ip string) string {
if c.app.config.EnableIPValidation && net.ParseIP(ip) == nil {
return ""
}
return ip
}

// extractIPsFromHeader will return a slice of IPs it found given a header name in the order they appear.
// When IP validation is enabled, any invalid IPs will be omitted.
func (c *Ctx) extractIPsFromHeader(header string) (ipsFound []string) {
func (c *Ctx) extractIPsFromHeader(header string) []string {
headerValue := c.Get(header)

// try to gather IPs in the input with minimal allocations to improve performance
ips := make([]string, bytes.Count([]byte(headerValue), []byte(","))+1)
var commaPos, i, validCount int
// We can't know how many IPs we will return, but we will try to guess with this constant division.
// Counting ',' makes function slower for about 50ns in general case.
estimatedCount := len(headerValue) / 8
if estimatedCount > 8 {
estimatedCount = 8 // Avoid big allocation on big header
}

ipsFound := make([]string, 0, estimatedCount)

i := 0
j := -1

iploop:
for {
commaPos = bytes.IndexByte([]byte(headerValue), ',')
if commaPos != -1 {
ips[i] = c.validateIPIfEnabled(utils.Trim(headerValue[:commaPos], ' '))
if ips[i] != "" {
validCount++
}
headerValue, i = headerValue[commaPos+1:], i+1
} else {
ips[i] = c.validateIPIfEnabled(utils.Trim(headerValue, ' '))
if ips[i] != "" {
validCount++
}
v4 := false
v6 := false

// Manually splitting string without allocating slice, working with parts directly
i, j = j+1, j+2

if j > len(headerValue) {
break
}
}

// filter out any invalid IP(s) that we found
if len(ips) == validCount {
ipsFound = ips
} else {
ipsFound = make([]string, validCount)
var validIndex int
for n := range ips {
if ips[n] != "" {
ipsFound[validIndex] = ips[n]
validIndex++
for j < len(headerValue) && headerValue[j] != ',' {
if headerValue[j] == ':' {
v6 = true
} else if headerValue[j] == '.' {
v4 = true
}
j++
}

for i < j && headerValue[i] == ' ' {
i++
}

s := utils.TrimRight(headerValue[i:j], ' ')

if c.app.config.EnableIPValidation {
// Skip validation if IP is clearly not IPv4/IPv6, otherwise validate without allocations
if (!v6 && !v4) || (v6 && !utils.IsIPv6(s)) || (v4 && !utils.IsIPv4(s)) {
continue iploop
}
}

ipsFound = append(ipsFound, s)
}
return

return ipsFound
}

// extractIPFromHeader will attempt to pull the real client IP from the given header when IP validation is enabled.
// currently, it will return the first valid IP address in header.
// when IP validation is disabled, it will simply return the value of the header without any inspection.
// Implementation is almost the same as in extractIPsFromHeader, but without allocation of []string.
func (c *Ctx) extractIPFromHeader(header string) string {
if c.app.config.EnableIPValidation {
// extract all IPs from the header's value
ips := c.extractIPsFromHeader(header)

// since X-Forwarded-For has no RFC, it's really up to the proxy to decide whether to append
// or prepend IPs to this list. For example, the AWS ALB will prepend but the F5 BIG-IP will append ;(
// for now lets just go with the first value in the list...
if len(ips) > 0 {
return ips[0]
headerValue := c.Get(header)

i := 0
j := -1

iploop:
for {
v4 := false
v6 := false
i, j = j+1, j+2

if j > len(headerValue) {
break
}

for j < len(headerValue) && headerValue[j] != ',' {
if headerValue[j] == ':' {
v6 = true
} else if headerValue[j] == '.' {
v4 = true
}
j++
}

for i < j && headerValue[i] == ' ' {
i++
}

s := utils.TrimRight(headerValue[i:j], ' ')

if c.app.config.EnableIPValidation {
if (!v6 && !v4) || (v6 && !utils.IsIPv6(s)) || (v4 && !utils.IsIPv4(s)) {
continue iploop
}
}

return s
}

// return the IP from the stack if we could not find any valid Ips
return c.fasthttp.RemoteIP().String()
}

Expand Down
36 changes: 36 additions & 0 deletions ctx_test.go
Expand Up @@ -1252,6 +1252,10 @@ func Test_Ctx_IPs(t *testing.T) {
c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.3, 127.0.0.1, 127.0.0.2")
utils.AssertEqual(t, []string{"127.0.0.3", "127.0.0.1", "127.0.0.2"}, c.IPs())

// ensure for IPv6
c.Request().Header.Set(HeaderXForwardedFor, "9396:9549:b4f7:8ed0:4791:1330:8c06:e62d, invalid, 2345:0425:2CA1::0567:5673:23b5")
utils.AssertEqual(t, []string{"9396:9549:b4f7:8ed0:4791:1330:8c06:e62d", "invalid", "2345:0425:2CA1::0567:5673:23b5"}, c.IPs())

// empty header
c.Request().Header.Set(HeaderXForwardedFor, "")
utils.AssertEqual(t, 0, len(c.IPs()))
Expand Down Expand Up @@ -1285,6 +1289,10 @@ func Test_Ctx_IPs_With_IP_Validation(t *testing.T) {
c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.3, 127.0.0.1, 127.0.0.2")
utils.AssertEqual(t, []string{"127.0.0.3", "127.0.0.1", "127.0.0.2"}, c.IPs())

// ensure for IPv6
c.Request().Header.Set(HeaderXForwardedFor, "f037:825e:eadb:1b7b:1667:6f0a:5356:f604, invalid, 9396:9549:b4f7:8ed0:4791:1330:8c06:e62d")
utils.AssertEqual(t, []string{"f037:825e:eadb:1b7b:1667:6f0a:5356:f604", "9396:9549:b4f7:8ed0:4791:1330:8c06:e62d"}, c.IPs())

// empty header
c.Request().Header.Set(HeaderXForwardedFor, "")
utils.AssertEqual(t, 0, len(c.IPs()))
Expand All @@ -1309,6 +1317,20 @@ func Benchmark_Ctx_IPs(b *testing.B) {
utils.AssertEqual(b, []string{"127.0.0.1", "invalid", "127.0.0.1"}, res)
}

func Benchmark_Ctx_IPs_v6(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set(HeaderXForwardedFor, "f037:825e:eadb:1b7b:1667:6f0a:5356:f604, invalid, 2345:0425:2CA1::0567:5673:23b5")
var res []string
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
res = c.IPs()
}
utils.AssertEqual(b, []string{"f037:825e:eadb:1b7b:1667:6f0a:5356:f604", "invalid", "2345:0425:2CA1::0567:5673:23b5"}, res)
}

func Benchmark_Ctx_IPs_With_IP_Validation(b *testing.B) {
app := New(Config{EnableIPValidation: true})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
Expand All @@ -1323,6 +1345,20 @@ func Benchmark_Ctx_IPs_With_IP_Validation(b *testing.B) {
utils.AssertEqual(b, []string{"127.0.0.1", "127.0.0.1"}, res)
}

func Benchmark_Ctx_IPs_v6_With_IP_Validation(b *testing.B) {
app := New(Config{EnableIPValidation: true})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set(HeaderXForwardedFor, "2345:0425:2CA1:0000:0000:0567:5673:23b5, invalid, 2345:0425:2CA1::0567:5673:23b5")
var res []string
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
res = c.IPs()
}
utils.AssertEqual(b, []string{"2345:0425:2CA1:0000:0000:0567:5673:23b5", "2345:0425:2CA1::0567:5673:23b5"}, res)
}

func Benchmark_Ctx_IP_With_ProxyHeader(b *testing.B) {
app := New(Config{ProxyHeader: HeaderXForwardedFor})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
Expand Down
141 changes: 141 additions & 0 deletions utils/ips.go
@@ -0,0 +1,141 @@
package utils

import "net"

// IsIPv4 works the same way as net.ParseIP,
// but without check for IPv6 case and without returning net.IP slice, whereby IsIPv4 makes no allocations.
func IsIPv4(s string) bool {
for i := 0; i < net.IPv4len; i++ {
if len(s) == 0 {
return false
}

if i > 0 {
if s[0] != '.' {
return false
}
s = s[1:]
}

n, ci := 0, 0

for ci = 0; ci < len(s) && '0' <= s[ci] && s[ci] <= '9'; ci++ {
n = n*10 + int(s[ci]-'0')
if n >= 0xFF {
return false
}
}

if ci == 0 || n > 0xFF || (ci > 1 && s[0] == '0') {
return false
}

s = s[ci:]
}

return len(s) == 0
}

// IsIPv6 works the same way as net.ParseIP,
// but without check for IPv4 case and without returning net.IP slice, whereby IsIPv6 makes no allocations.
func IsIPv6(s string) bool {
ellipsis := -1 // position of ellipsis in ip

// Might have leading ellipsis
if len(s) >= 2 && s[0] == ':' && s[1] == ':' {
ellipsis = 0
s = s[2:]
// Might be only ellipsis
if len(s) == 0 {
return true
}
}

// Loop, parsing hex numbers followed by colon.
i := 0
for i < net.IPv6len {
// Hex number.
n, ci := 0, 0

for ci = 0; ci < len(s); ci++ {
if '0' <= s[ci] && s[ci] <= '9' {
n *= 16
n += int(s[ci] - '0')
} else if 'a' <= s[ci] && s[ci] <= 'f' {
n *= 16
n += int(s[ci]-'a') + 10
} else if 'A' <= s[ci] && s[ci] <= 'F' {
n *= 16
n += int(s[ci]-'A') + 10
} else {
break
}
if n > 0xFFFF {
return false
}
}
if ci == 0 || n > 0xFFFF {
return false
}

if ci < len(s) && s[ci] == '.' {
if ellipsis < 0 && i != net.IPv6len-net.IPv4len {
return false
}
if i+net.IPv4len > net.IPv6len {
return false
}

if !IsIPv4(s) {
return false
}

s = ""
i += net.IPv4len
break
}

// Save this 16-bit chunk.
i += 2

// Stop at end of string.
s = s[ci:]
if len(s) == 0 {
break
}

// Otherwise must be followed by colon and more.
if s[0] != ':' || len(s) == 1 {
return false
}
s = s[1:]

// Look for ellipsis.
if s[0] == ':' {
if ellipsis >= 0 { // already have one
return false
}
ellipsis = i
s = s[1:]
if len(s) == 0 { // can be at end
break
}
}
}

// Must have used entire string.
if len(s) != 0 {
return false
}

// If didn't parse enough, expand ellipsis.
if i < net.IPv6len {
if ellipsis < 0 {
return false
}
} else if ellipsis >= 0 {
// Ellipsis must represent at least one 0 group.
return false
}
return true
}