diff --git a/ctx.go b/ctx.go index a6959e6941..327b319386 100644 --- a/ctx.go +++ b/ctx.go @@ -688,72 +688,108 @@ func (c *Ctx) IP() string { return c.fasthttp.RemoteIP().String() } -// validateIPIfEnabled will return the input IP when validation is disabled. -// when validation is enabled, it will return an empty string if the input is not a valid IP. -func (c *Ctx) validateIPIfEnabled(ip string) string { - if c.app.config.EnableIPValidation && net.ParseIP(ip) == nil { - return "" - } - return ip -} - // extractIPsFromHeader will return a slice of IPs it found given a header name in the order they appear. // When IP validation is enabled, any invalid IPs will be omitted. -func (c *Ctx) extractIPsFromHeader(header string) (ipsFound []string) { +func (c *Ctx) extractIPsFromHeader(header string) []string { headerValue := c.Get(header) - // try to gather IPs in the input with minimal allocations to improve performance - ips := make([]string, bytes.Count([]byte(headerValue), []byte(","))+1) - var commaPos, i, validCount int + // We can't know how many IPs we will return, but we will try to guess with this constant division. + // Counting ',' makes function slower for about 50ns in general case. + estimatedCount := len(headerValue) / 8 + if estimatedCount > 8 { + estimatedCount = 8 // Avoid big allocation on big header + } + + ipsFound := make([]string, 0, estimatedCount) + + i := 0 + j := -1 + +iploop: for { - commaPos = bytes.IndexByte([]byte(headerValue), ',') - if commaPos != -1 { - ips[i] = c.validateIPIfEnabled(utils.Trim(headerValue[:commaPos], ' ')) - if ips[i] != "" { - validCount++ - } - headerValue, i = headerValue[commaPos+1:], i+1 - } else { - ips[i] = c.validateIPIfEnabled(utils.Trim(headerValue, ' ')) - if ips[i] != "" { - validCount++ - } + v4 := false + v6 := false + + // Manually splitting string without allocating slice, working with parts directly + i, j = j+1, j+2 + + if j > len(headerValue) { break } - } - // filter out any invalid IP(s) that we found - if len(ips) == validCount { - ipsFound = ips - } else { - ipsFound = make([]string, validCount) - var validIndex int - for n := range ips { - if ips[n] != "" { - ipsFound[validIndex] = ips[n] - validIndex++ + for j < len(headerValue) && headerValue[j] != ',' { + if headerValue[j] == ':' { + v6 = true + } else if headerValue[j] == '.' { + v4 = true } + j++ } + + for i < j && headerValue[i] == ' ' { + i++ + } + + s := utils.TrimRight(headerValue[i:j], ' ') + + if c.app.config.EnableIPValidation { + // Skip validation if IP is clearly not IPv4/IPv6, otherwise validate without allocations + if (!v6 && !v4) || (v6 && !utils.IsIPv6(s)) || (v4 && !utils.IsIPv4(s)) { + continue iploop + } + } + + ipsFound = append(ipsFound, s) } - return + + return ipsFound } // extractIPFromHeader will attempt to pull the real client IP from the given header when IP validation is enabled. // currently, it will return the first valid IP address in header. // when IP validation is disabled, it will simply return the value of the header without any inspection. +// Implementation is almost the same as in extractIPsFromHeader, but without allocation of []string. func (c *Ctx) extractIPFromHeader(header string) string { if c.app.config.EnableIPValidation { - // extract all IPs from the header's value - ips := c.extractIPsFromHeader(header) - - // since X-Forwarded-For has no RFC, it's really up to the proxy to decide whether to append - // or prepend IPs to this list. For example, the AWS ALB will prepend but the F5 BIG-IP will append ;( - // for now lets just go with the first value in the list... - if len(ips) > 0 { - return ips[0] + headerValue := c.Get(header) + + i := 0 + j := -1 + + iploop: + for { + v4 := false + v6 := false + i, j = j+1, j+2 + + if j > len(headerValue) { + break + } + + for j < len(headerValue) && headerValue[j] != ',' { + if headerValue[j] == ':' { + v6 = true + } else if headerValue[j] == '.' { + v4 = true + } + j++ + } + + for i < j && headerValue[i] == ' ' { + i++ + } + + s := utils.TrimRight(headerValue[i:j], ' ') + + if c.app.config.EnableIPValidation { + if (!v6 && !v4) || (v6 && !utils.IsIPv6(s)) || (v4 && !utils.IsIPv4(s)) { + continue iploop + } + } + + return s } - // return the IP from the stack if we could not find any valid Ips return c.fasthttp.RemoteIP().String() } diff --git a/ctx_test.go b/ctx_test.go index 9a60fd447a..ad975de9cf 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -1265,6 +1265,10 @@ func Test_Ctx_IPs(t *testing.T) { c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.3, 127.0.0.1, 127.0.0.2") utils.AssertEqual(t, []string{"127.0.0.3", "127.0.0.1", "127.0.0.2"}, c.IPs()) + // ensure for IPv6 + c.Request().Header.Set(HeaderXForwardedFor, "9396:9549:b4f7:8ed0:4791:1330:8c06:e62d, invalid, 2345:0425:2CA1::0567:5673:23b5") + utils.AssertEqual(t, []string{"9396:9549:b4f7:8ed0:4791:1330:8c06:e62d", "invalid", "2345:0425:2CA1::0567:5673:23b5"}, c.IPs()) + // empty header c.Request().Header.Set(HeaderXForwardedFor, "") utils.AssertEqual(t, 0, len(c.IPs())) @@ -1298,6 +1302,10 @@ func Test_Ctx_IPs_With_IP_Validation(t *testing.T) { c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.3, 127.0.0.1, 127.0.0.2") utils.AssertEqual(t, []string{"127.0.0.3", "127.0.0.1", "127.0.0.2"}, c.IPs()) + // ensure for IPv6 + c.Request().Header.Set(HeaderXForwardedFor, "f037:825e:eadb:1b7b:1667:6f0a:5356:f604, invalid, 9396:9549:b4f7:8ed0:4791:1330:8c06:e62d") + utils.AssertEqual(t, []string{"f037:825e:eadb:1b7b:1667:6f0a:5356:f604", "9396:9549:b4f7:8ed0:4791:1330:8c06:e62d"}, c.IPs()) + // empty header c.Request().Header.Set(HeaderXForwardedFor, "") utils.AssertEqual(t, 0, len(c.IPs())) @@ -1322,6 +1330,20 @@ func Benchmark_Ctx_IPs(b *testing.B) { utils.AssertEqual(b, []string{"127.0.0.1", "invalid", "127.0.0.1"}, res) } +func Benchmark_Ctx_IPs_v6(b *testing.B) { + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.Request().Header.Set(HeaderXForwardedFor, "f037:825e:eadb:1b7b:1667:6f0a:5356:f604, invalid, 2345:0425:2CA1::0567:5673:23b5") + var res []string + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = c.IPs() + } + utils.AssertEqual(b, []string{"f037:825e:eadb:1b7b:1667:6f0a:5356:f604", "invalid", "2345:0425:2CA1::0567:5673:23b5"}, res) +} + func Benchmark_Ctx_IPs_With_IP_Validation(b *testing.B) { app := New(Config{EnableIPValidation: true}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -1336,6 +1358,20 @@ func Benchmark_Ctx_IPs_With_IP_Validation(b *testing.B) { utils.AssertEqual(b, []string{"127.0.0.1", "127.0.0.1"}, res) } +func Benchmark_Ctx_IPs_v6_With_IP_Validation(b *testing.B) { + app := New(Config{EnableIPValidation: true}) + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.Request().Header.Set(HeaderXForwardedFor, "2345:0425:2CA1:0000:0000:0567:5673:23b5, invalid, 2345:0425:2CA1::0567:5673:23b5") + var res []string + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + res = c.IPs() + } + utils.AssertEqual(b, []string{"2345:0425:2CA1:0000:0000:0567:5673:23b5", "2345:0425:2CA1::0567:5673:23b5"}, res) +} + func Benchmark_Ctx_IP_With_ProxyHeader(b *testing.B) { app := New(Config{ProxyHeader: HeaderXForwardedFor}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) diff --git a/utils/ips.go b/utils/ips.go new file mode 100644 index 0000000000..6d2446470a --- /dev/null +++ b/utils/ips.go @@ -0,0 +1,141 @@ +package utils + +import "net" + +// IsIPv4 works the same way as net.ParseIP, +// but without check for IPv6 case and without returning net.IP slice, whereby IsIPv4 makes no allocations. +func IsIPv4(s string) bool { + for i := 0; i < net.IPv4len; i++ { + if len(s) == 0 { + return false + } + + if i > 0 { + if s[0] != '.' { + return false + } + s = s[1:] + } + + n, ci := 0, 0 + + for ci = 0; ci < len(s) && '0' <= s[ci] && s[ci] <= '9'; ci++ { + n = n*10 + int(s[ci]-'0') + if n >= 0xFF { + return false + } + } + + if ci == 0 || n > 0xFF || (ci > 1 && s[0] == '0') { + return false + } + + s = s[ci:] + } + + return len(s) == 0 +} + +// IsIPv6 works the same way as net.ParseIP, +// but without check for IPv4 case and without returning net.IP slice, whereby IsIPv6 makes no allocations. +func IsIPv6(s string) bool { + ellipsis := -1 // position of ellipsis in ip + + // Might have leading ellipsis + if len(s) >= 2 && s[0] == ':' && s[1] == ':' { + ellipsis = 0 + s = s[2:] + // Might be only ellipsis + if len(s) == 0 { + return true + } + } + + // Loop, parsing hex numbers followed by colon. + i := 0 + for i < net.IPv6len { + // Hex number. + n, ci := 0, 0 + + for ci = 0; ci < len(s); ci++ { + if '0' <= s[ci] && s[ci] <= '9' { + n *= 16 + n += int(s[ci] - '0') + } else if 'a' <= s[ci] && s[ci] <= 'f' { + n *= 16 + n += int(s[ci]-'a') + 10 + } else if 'A' <= s[ci] && s[ci] <= 'F' { + n *= 16 + n += int(s[ci]-'A') + 10 + } else { + break + } + if n > 0xFFFF { + return false + } + } + if ci == 0 || n > 0xFFFF { + return false + } + + if ci < len(s) && s[ci] == '.' { + if ellipsis < 0 && i != net.IPv6len-net.IPv4len { + return false + } + if i+net.IPv4len > net.IPv6len { + return false + } + + if !IsIPv4(s) { + return false + } + + s = "" + i += net.IPv4len + break + } + + // Save this 16-bit chunk. + i += 2 + + // Stop at end of string. + s = s[ci:] + if len(s) == 0 { + break + } + + // Otherwise must be followed by colon and more. + if s[0] != ':' || len(s) == 1 { + return false + } + s = s[1:] + + // Look for ellipsis. + if s[0] == ':' { + if ellipsis >= 0 { // already have one + return false + } + ellipsis = i + s = s[1:] + if len(s) == 0 { // can be at end + break + } + } + } + + // Must have used entire string. + if len(s) != 0 { + return false + } + + // If didn't parse enough, expand ellipsis. + if i < net.IPv6len { + if ellipsis < 0 { + return false + } + } else if ellipsis >= 0 { + // Ellipsis must represent at least one 0 group. + return false + } + return true +} diff --git a/utils/ips_test.go b/utils/ips_test.go new file mode 100644 index 0000000000..49c258e604 --- /dev/null +++ b/utils/ips_test.go @@ -0,0 +1,85 @@ +// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ +// 🤖 Github Repository: https://github.com/gofiber/fiber +// 📌 API Documentation: https://docs.gofiber.io + +package utils + +import ( + "net" + "testing" +) + +func Test_IsIPv4(t *testing.T) { + t.Parallel() + + AssertEqual(t, true, IsIPv4("174.23.33.100")) + AssertEqual(t, true, IsIPv4("127.0.0.1")) + AssertEqual(t, true, IsIPv4("0.0.0.0")) + + AssertEqual(t, false, IsIPv4(".0.0.0")) + AssertEqual(t, false, IsIPv4("0.0.0.")) + AssertEqual(t, false, IsIPv4("0.0.0")) + AssertEqual(t, false, IsIPv4(".0.0.0.")) + AssertEqual(t, false, IsIPv4("0.0.0.0.0")) + AssertEqual(t, false, IsIPv4("0")) + AssertEqual(t, false, IsIPv4("")) + AssertEqual(t, false, IsIPv4("2345:0425:2CA1::0567:5673:23b5")) + AssertEqual(t, false, IsIPv4("invalid")) +} + +// go test -v -run=^$ -bench=UnsafeString -benchmem -count=2 + +func Benchmark_IsIPv4(b *testing.B) { + ip := "174.23.33.100" + var res bool + + b.Run("fiber", func(b *testing.B) { + for n := 0; n < b.N; n++ { + res = IsIPv4(ip) + } + AssertEqual(b, true, res) + }) + + b.Run("default", func(b *testing.B) { + for n := 0; n < b.N; n++ { + res = net.ParseIP(ip) != nil + } + AssertEqual(b, true, res) + }) +} + +func Test_IsIPv6(t *testing.T) { + t.Parallel() + + AssertEqual(t, true, IsIPv6("9396:9549:b4f7:8ed0:4791:1330:8c06:e62d")) + AssertEqual(t, true, IsIPv6("2345:0425:2CA1::0567:5673:23b5")) + AssertEqual(t, true, IsIPv6("2001:1:2:3:4:5:6:7")) + + AssertEqual(t, false, IsIPv6("1.1.1.1")) + AssertEqual(t, false, IsIPv6("2001:1:2:3:4:5:6:")) + AssertEqual(t, false, IsIPv6(":1:2:3:4:5:6:")) + AssertEqual(t, false, IsIPv6("1:2:3:4:5:6:")) + AssertEqual(t, false, IsIPv6("")) + AssertEqual(t, false, IsIPv6("invalid")) +} + +// go test -v -run=^$ -bench=UnsafeString -benchmem -count=2 + +func Benchmark_IsIPv6(b *testing.B) { + ip := "9396:9549:b4f7:8ed0:4791:1330:8c06:e62d" + var res bool + + b.Run("fiber", func(b *testing.B) { + for n := 0; n < b.N; n++ { + res = IsIPv6(ip) + } + AssertEqual(b, true, res) + }) + + b.Run("default", func(b *testing.B) { + for n := 0; n < b.N; n++ { + res = net.ParseIP(ip) != nil + } + AssertEqual(b, true, res) + }) +}