From ff4e71f144166b1dfe3017a146f8ed32a82e688b Mon Sep 17 00:00:00 2001 From: Euan Kemp Date: Thu, 17 Oct 2019 17:48:19 -0700 Subject: [PATCH] Guess the scheme if r.URL.Scheme is unset (#474) * Guess the scheme if r.URL.Scheme is unset It's not expected that the request's URL is fully populated when used on the server-side (it's more of a client-side field), so we shouldn't expect it to be present. In practice, it's only rarely set at all on the server, making mux's `Schemes` matcher tricky to use as it is. This commit adds a test which would have failed before demonstrating the problem, as well as a fix which I think makes `.Schemes` match what users expect. * [doc] Add more detail to Schemes and URL godocs * Add route url test for schemes * Make httpserver test use more specific scheme matchers * Update test to have different responses per route --- mux_httpserver_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++ mux_test.go | 6 ++---- old_test.go | 17 +++++---------- route.go | 32 ++++++++++++++++++++++++--- 4 files changed, 85 insertions(+), 19 deletions(-) create mode 100644 mux_httpserver_test.go diff --git a/mux_httpserver_test.go b/mux_httpserver_test.go new file mode 100644 index 00000000..5d2f4d3a --- /dev/null +++ b/mux_httpserver_test.go @@ -0,0 +1,49 @@ +// +build go1.9 + +package mux + +import ( + "bytes" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" +) + +func TestSchemeMatchers(t *testing.T) { + router := NewRouter() + router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { + rw.Write([]byte("hello http world")) + }).Schemes("http") + router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { + rw.Write([]byte("hello https world")) + }).Schemes("https") + + assertResponseBody := func(t *testing.T, s *httptest.Server, expectedBody string) { + resp, err := s.Client().Get(s.URL) + if err != nil { + t.Fatalf("unexpected error getting from server: %v", err) + } + if resp.StatusCode != 200 { + t.Fatalf("expected a status code of 200, got %v", resp.StatusCode) + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unexpected error reading body: %v", err) + } + if !bytes.Equal(body, []byte(expectedBody)) { + t.Fatalf("response should be hello world, was: %q", string(body)) + } + } + + t.Run("httpServer", func(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + assertResponseBody(t, s, "hello http world") + }) + t.Run("httpsServer", func(t *testing.T) { + s := httptest.NewTLSServer(router) + defer s.Close() + assertResponseBody(t, s, "hello https world") + }) +} diff --git a/mux_test.go b/mux_test.go index edcee572..9a740bb8 100644 --- a/mux_test.go +++ b/mux_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/http/httptest" "net/url" "reflect" "strings" @@ -2895,10 +2896,7 @@ func newRequestWithHeaders(method, url string, headers ...string) *http.Request // newRequestHost a new request with a method, url, and host header func newRequestHost(method, url, host string) *http.Request { - req, err := http.NewRequest(method, url, nil) - if err != nil { - panic(err) - } + req := httptest.NewRequest(method, url, nil) req.Host = host return req } diff --git a/old_test.go b/old_test.go index b228983c..f088a951 100644 --- a/old_test.go +++ b/old_test.go @@ -385,6 +385,11 @@ var urlBuildingTests = []urlBuildingTest{ vars: []string{"subdomain", "foo", "category", "technology", "id", "42"}, url: "http://foo.domain.com/articles/technology/42", }, + { + route: new(Route).Host("example.com").Schemes("https", "http"), + vars: []string{}, + url: "https://example.com", + }, } func TestHeaderMatcher(t *testing.T) { @@ -502,18 +507,6 @@ func TestUrlBuilding(t *testing.T) { url := u.String() if url != v.url { t.Errorf("expected %v, got %v", v.url, url) - /* - reversePath := "" - reverseHost := "" - if v.route.pathTemplate != nil { - reversePath = v.route.pathTemplate.Reverse - } - if v.route.hostTemplate != nil { - reverseHost = v.route.hostTemplate.Reverse - } - - t.Errorf("%#v:\nexpected: %q\ngot: %q\nreverse path: %q\nreverse host: %q", v.route, v.url, url, reversePath, reverseHost) - */ } } diff --git a/route.go b/route.go index 4098cb79..750afe57 100644 --- a/route.go +++ b/route.go @@ -412,11 +412,30 @@ func (r *Route) Queries(pairs ...string) *Route { type schemeMatcher []string func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool { - return matchInArray(m, r.URL.Scheme) + scheme := r.URL.Scheme + // https://golang.org/pkg/net/http/#Request + // "For [most] server requests, fields other than Path and RawQuery will be + // empty." + // Since we're an http muxer, the scheme is either going to be http or https + // though, so we can just set it based on the tls termination state. + if scheme == "" { + if r.TLS == nil { + scheme = "http" + } else { + scheme = "https" + } + } + return matchInArray(m, scheme) } // Schemes adds a matcher for URL schemes. // It accepts a sequence of schemes to be matched, e.g.: "http", "https". +// If the request's URL has a scheme set, it will be matched against. +// Generally, the URL scheme will only be set if a previous handler set it, +// such as the ProxyHeaders handler from gorilla/handlers. +// If unset, the scheme will be determined based on the request's TLS +// termination state. +// The first argument to Schemes will be used when constructing a route URL. func (r *Route) Schemes(schemes ...string) *Route { for k, v := range schemes { schemes[k] = strings.ToLower(v) @@ -493,8 +512,8 @@ func (r *Route) Subrouter() *Router { // This also works for host variables: // // r := mux.NewRouter() -// r.Host("{subdomain}.domain.com"). -// HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). +// r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). +// Host("{subdomain}.domain.com"). // Name("article") // // // url.String() will be "http://news.domain.com/articles/technology/42" @@ -502,6 +521,13 @@ func (r *Route) Subrouter() *Router { // "category", "technology", // "id", "42") // +// The scheme of the resulting url will be the first argument that was passed to Schemes: +// +// // url.String() will be "https://example.com" +// r := mux.NewRouter() +// url, err := r.Host("example.com") +// .Schemes("https", "http").URL() +// // All variables defined in the route are required, and their values must // conform to the corresponding patterns. func (r *Route) URL(pairs ...string) (*url.URL, error) {