Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Guess the scheme if r.URL.Scheme is unset #474

Merged
merged 5 commits into from Oct 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 49 additions & 0 deletions mux_httpserver_test.go
@@ -0,0 +1,49 @@
// +build go1.9

package mux

import (
"bytes"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)

func TestSchemeMatchers(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense for this test to have different responses for the different schemes so we can use that to differentiate? I believe this test would pass if the httpsServer subtest was matching on the http route.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're right. I made it so there's one router that has both matchers, and that each matcher has a different response. I think that should better test for what we care about.
Thanks for the review!

router := NewRouter()
router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("hello http world"))
}).Schemes("http")
router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("hello https world"))
}).Schemes("https")

assertResponseBody := func(t *testing.T, s *httptest.Server, expectedBody string) {
resp, err := s.Client().Get(s.URL)
if err != nil {
t.Fatalf("unexpected error getting from server: %v", err)
}
if resp.StatusCode != 200 {
t.Fatalf("expected a status code of 200, got %v", resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unexpected error reading body: %v", err)
}
if !bytes.Equal(body, []byte(expectedBody)) {
t.Fatalf("response should be hello world, was: %q", string(body))
}
}

t.Run("httpServer", func(t *testing.T) {
s := httptest.NewServer(router)
defer s.Close()
assertResponseBody(t, s, "hello http world")
})
t.Run("httpsServer", func(t *testing.T) {
s := httptest.NewTLSServer(router)
defer s.Close()
assertResponseBody(t, s, "hello https world")
})
}
6 changes: 2 additions & 4 deletions mux_test.go
Expand Up @@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
Expand Down Expand Up @@ -2837,10 +2838,7 @@ func newRequestWithHeaders(method, url string, headers ...string) *http.Request

// newRequestHost a new request with a method, url, and host header
func newRequestHost(method, url, host string) *http.Request {
req, err := http.NewRequest(method, url, nil)
if err != nil {
panic(err)
}
req := httptest.NewRequest(method, url, nil)
req.Host = host
return req
}
17 changes: 5 additions & 12 deletions old_test.go
Expand Up @@ -385,6 +385,11 @@ var urlBuildingTests = []urlBuildingTest{
vars: []string{"subdomain", "foo", "category", "technology", "id", "42"},
url: "http://foo.domain.com/articles/technology/42",
},
{
route: new(Route).Host("example.com").Schemes("https", "http"),
vars: []string{},
url: "https://example.com",
},
}

func TestHeaderMatcher(t *testing.T) {
Expand Down Expand Up @@ -502,18 +507,6 @@ func TestUrlBuilding(t *testing.T) {
url := u.String()
if url != v.url {
t.Errorf("expected %v, got %v", v.url, url)
/*
reversePath := ""
reverseHost := ""
if v.route.pathTemplate != nil {
reversePath = v.route.pathTemplate.Reverse
}
if v.route.hostTemplate != nil {
reverseHost = v.route.hostTemplate.Reverse
}

t.Errorf("%#v:\nexpected: %q\ngot: %q\nreverse path: %q\nreverse host: %q", v.route, v.url, url, reversePath, reverseHost)
*/
}
}

Expand Down
32 changes: 29 additions & 3 deletions route.go
Expand Up @@ -412,11 +412,30 @@ func (r *Route) Queries(pairs ...string) *Route {
type schemeMatcher []string

func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, r.URL.Scheme)
scheme := r.URL.Scheme
// https://golang.org/pkg/net/http/#Request
// "For [most] server requests, fields other than Path and RawQuery will be
// empty."
// Since we're an http muxer, the scheme is either going to be http or https
// though, so we can just set it based on the tls termination state.
if scheme == "" {
if r.TLS == nil {
scheme = "http"
} else {
scheme = "https"
}
}
return matchInArray(m, scheme)
}

// Schemes adds a matcher for URL schemes.
// It accepts a sequence of schemes to be matched, e.g.: "http", "https".
// If the request's URL has a scheme set, it will be matched against.
// Generally, the URL scheme will only be set if a previous handler set it,
// such as the ProxyHeaders handler from gorilla/handlers.
// If unset, the scheme will be determined based on the request's TLS
// termination state.
// The first argument to Schemes will be used when constructing a route URL.
func (r *Route) Schemes(schemes ...string) *Route {
for k, v := range schemes {
schemes[k] = strings.ToLower(v)
Expand Down Expand Up @@ -493,15 +512,22 @@ func (r *Route) Subrouter() *Router {
// This also works for host variables:
//
// r := mux.NewRouter()
// r.Host("{subdomain}.domain.com").
// HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
// r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
// Host("{subdomain}.domain.com").
// Name("article")
//
// // url.String() will be "http://news.domain.com/articles/technology/42"
// url, err := r.Get("article").URL("subdomain", "news",
// "category", "technology",
// "id", "42")
//
// The scheme of the resulting url will be the first argument that was passed to Schemes:
//
// // url.String() will be "https://example.com"
// r := mux.NewRouter()
// url, err := r.Host("example.com")
// .Schemes("https", "http").URL()
//
// All variables defined in the route are required, and their values must
// conform to the corresponding patterns.
func (r *Route) URL(pairs ...string) (*url.URL, error) {
Expand Down