diff --git a/mux_httpserver_test.go b/mux_httpserver_test.go new file mode 100644 index 00000000..c46b647a --- /dev/null +++ b/mux_httpserver_test.go @@ -0,0 +1,46 @@ +// +build go1.9 + +package mux + +import ( + "bytes" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" +) + +func TestSchemeMatchers(t *testing.T) { + router := NewRouter() + router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { + rw.Write([]byte("hello world")) + }).Schemes("http", "https") + + assertHelloWorldResponse := func(t *testing.T, s *httptest.Server) { + resp, err := s.Client().Get(s.URL) + if err != nil { + t.Fatalf("unexpected error getting from server: %v", err) + } + if resp.StatusCode != 200 { + t.Fatalf("expected a status code of 200, got %v", resp.StatusCode) + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unexpected error reading body: %v", err) + } + if !bytes.Equal(body, []byte("hello world")) { + t.Fatalf("response should be hello world, was: %q", string(body)) + } + } + + t.Run("httpServer", func(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + assertHelloWorldResponse(t, s) + }) + t.Run("httpsServer", func(t *testing.T) { + s := httptest.NewTLSServer(router) + defer s.Close() + assertHelloWorldResponse(t, s) + }) +} diff --git a/mux_test.go b/mux_test.go index f5c1e9c5..696b5a2f 100644 --- a/mux_test.go +++ b/mux_test.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "net/http" + "net/http/httptest" "net/url" "reflect" "strings" @@ -2837,10 +2838,7 @@ func newRequestWithHeaders(method, url string, headers ...string) *http.Request // newRequestHost a new request with a method, url, and host header func newRequestHost(method, url, host string) *http.Request { - req, err := http.NewRequest(method, url, nil) - if err != nil { - panic(err) - } + req := httptest.NewRequest(method, url, nil) req.Host = host return req } diff --git a/route.go b/route.go index 8479c68c..4f9917a8 100644 --- a/route.go +++ b/route.go @@ -412,7 +412,18 @@ func (r *Route) Queries(pairs ...string) *Route { type schemeMatcher []string func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool { - return matchInArray(m, r.URL.Scheme) + scheme := r.URL.Scheme + // https://golang.org/pkg/net/http/#Request + // "For [most] server requests, fields other than Path and RawQuery will be + // empty." + // Since we're an http muxer, the scheme is either going to be http or https + // though, so we can just set it based on the tls termination state. + if r.URL.Scheme == "" && r.TLS == nil { + scheme = "http" + } else if r.URL.Scheme == "" && r.TLS != nil { + scheme = "https" + } + return matchInArray(m, scheme) } // Schemes adds a matcher for URL schemes.