diff --git a/regexp.go b/regexp.go index b5a15ed9..96dd94ad 100644 --- a/regexp.go +++ b/regexp.go @@ -230,14 +230,51 @@ func (r *routeRegexp) getURLQuery(req *http.Request) string { return "" } templateKey := strings.SplitN(r.template, "=", 2)[0] - for key, vals := range req.URL.Query() { - if key == templateKey && len(vals) > 0 { - return key + "=" + vals[0] - } + val, ok := findFirstQueryKey(req.URL.RawQuery, templateKey) + if ok { + return templateKey + "=" + val } return "" } +// findFirstQueryKey returns the same result as (*url.URL).Query()[key][0]. +// If key was not found, empty string and false is returned. +func findFirstQueryKey(rawQuery, key string) (value string, ok bool) { + query := []byte(rawQuery) + for len(query) > 0 { + foundKey := query + if i := bytes.IndexAny(foundKey, "&;"); i >= 0 { + foundKey, query = foundKey[:i], foundKey[i+1:] + } else { + query = query[:0] + } + if len(foundKey) == 0 { + continue + } + var value []byte + if i := bytes.IndexByte(foundKey, '='); i >= 0 { + foundKey, value = foundKey[:i], foundKey[i+1:] + } + if len(foundKey) < len(key) { + // Cannot possibly be key. + continue + } + keyString, err := url.QueryUnescape(string(foundKey)) + if err != nil { + continue + } + if keyString != key { + continue + } + valueString, err := url.QueryUnescape(string(value)) + if err != nil { + continue + } + return valueString, true + } + return "", false +} + func (r *routeRegexp) matchQueryString(req *http.Request) bool { return r.regexp.MatchString(r.getURLQuery(req)) } diff --git a/regexp_test.go b/regexp_test.go new file mode 100644 index 00000000..0d80e6a5 --- /dev/null +++ b/regexp_test.go @@ -0,0 +1,91 @@ +package mux + +import ( + "net/url" + "reflect" + "strconv" + "testing" +) + +func Test_findFirstQueryKey(t *testing.T) { + tests := []string{ + "a=1&b=2", + "a=1&a=2&a=banana", + "ascii=%3Ckey%3A+0x90%3E", + "a=1;b=2", + "a=1&a=2;a=banana", + "a==", + "a=%2", + "a=20&%20%3F&=%23+%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09:%2F@$%27%28%29%2A%2C%3B&a=30", + "a=1& ?&=#+%!<>#\"{}|\\^[]`☺\t:/@$'()*,;&a=5", + "a=xxxxxxxxxxxxxxxx&b=YYYYYYYYYYYYYYY&c=ppppppppppppppppppp&f=ttttttttttttttttt&a=uuuuuuuuuuuuu", + } + for _, query := range tests { + t.Run(query, func(t *testing.T) { + // Check against url.ParseQuery, ignoring the error. + all, _ := url.ParseQuery(query) + for key, want := range all { + t.Run(key, func(t *testing.T) { + got, ok := findFirstQueryKey(query, key) + if !ok { + t.Error("Did not get expected key", key) + } + if !reflect.DeepEqual(got, want[0]) { + t.Errorf("findFirstQueryKey(%s,%s) = %v, want %v", query, key, got, want[0]) + } + }) + } + }) + } +} + +func Benchmark_findQueryKey(b *testing.B) { + tests := []string{ + "a=1&b=2", + "ascii=%3Ckey%3A+0x90%3E", + "a=20&%20%3F&=%23+%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09:%2F@$%27%28%29%2A%2C%3B&a=30", + "a=xxxxxxxxxxxxxxxx&bbb=YYYYYYYYYYYYYYY&cccc=ppppppppppppppppppp&ddddd=ttttttttttttttttt&a=uuuuuuuuuuuuu", + "a=;b=;c=;d=;e=;f=;g=;h=;i=,j=;k=", + } + for i, query := range tests { + b.Run(strconv.Itoa(i), func(b *testing.B) { + // Check against url.ParseQuery, ignoring the error. + all, _ := url.ParseQuery(query) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + for key, _ := range all { + _, _ = findFirstQueryKey(query, key) + } + } + }) + } +} + +func Benchmark_findQueryKeyGoLib(b *testing.B) { + tests := []string{ + "a=1&b=2", + "ascii=%3Ckey%3A+0x90%3E", + "a=20&%20%3F&=%23+%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09:%2F@$%27%28%29%2A%2C%3B&a=30", + "a=xxxxxxxxxxxxxxxx&bbb=YYYYYYYYYYYYYYY&cccc=ppppppppppppppppppp&ddddd=ttttttttttttttttt&a=uuuuuuuuuuuuu", + "a=;b=;c=;d=;e=;f=;g=;h=;i=,j=;k=", + } + for i, query := range tests { + b.Run(strconv.Itoa(i), func(b *testing.B) { + // Check against url.ParseQuery, ignoring the error. + all, _ := url.ParseQuery(query) + var u url.URL + u.RawQuery = query + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + for key, _ := range all { + v := u.Query()[key] + if len(v) > 0 { + _ = v[0] + } + } + } + }) + } +}