From 51b0355262146afaffe70a2657df4fca1bb5e8c8 Mon Sep 17 00:00:00 2001 From: Franklin Harding Date: Wed, 23 Oct 2019 22:42:37 -0700 Subject: [PATCH 1/4] Remove context helpers in context.go --- context.go | 18 ------------------ mux.go | 19 +++++++++++-------- test_helpers.go | 2 +- 3 files changed, 12 insertions(+), 27 deletions(-) delete mode 100644 context.go diff --git a/context.go b/context.go deleted file mode 100644 index 665940a2..00000000 --- a/context.go +++ /dev/null @@ -1,18 +0,0 @@ -package mux - -import ( - "context" - "net/http" -) - -func contextGet(r *http.Request, key interface{}) interface{} { - return r.Context().Value(key) -} - -func contextSet(r *http.Request, key, val interface{}) *http.Request { - if val == nil { - return r - } - - return r.WithContext(context.WithValue(r.Context(), key, val)) -} diff --git a/mux.go b/mux.go index 26f9582a..ce795870 100644 --- a/mux.go +++ b/mux.go @@ -5,6 +5,7 @@ package mux import ( + "context" "errors" "fmt" "net/http" @@ -195,8 +196,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { var handler http.Handler if r.Match(req, &match) { handler = match.Handler - req = setVars(req, match.Vars) - req = setCurrentRoute(req, match.Route) + req = requestWithVars(req, match.Vars) + req = requestWithRoute(req, match.Route) } if handler == nil && match.MatchErr == ErrMethodMismatch { @@ -426,7 +427,7 @@ const ( // Vars returns the route variables for the current request, if any. func Vars(r *http.Request) map[string]string { - if rv := contextGet(r, varsKey); rv != nil { + if rv := r.Context().Value(varsKey); rv != nil { return rv.(map[string]string) } return nil @@ -438,18 +439,20 @@ func Vars(r *http.Request) map[string]string { // after the handler returns, unless the KeepContext option is set on the // Router. func CurrentRoute(r *http.Request) *Route { - if rv := contextGet(r, routeKey); rv != nil { + if rv := r.Context().Value(routeKey); rv != nil { return rv.(*Route) } return nil } -func setVars(r *http.Request, val interface{}) *http.Request { - return contextSet(r, varsKey, val) +func requestWithVars(r *http.Request, val interface{}) *http.Request { + ctx := context.WithValue(r.Context(), varsKey, val) + return r.WithContext(ctx) } -func setCurrentRoute(r *http.Request, val interface{}) *http.Request { - return contextSet(r, routeKey, val) +func requestWithRoute(r *http.Request, val interface{}) *http.Request { + ctx := context.WithValue(r.Context(), routeKey, val) + return r.WithContext(ctx) } // ---------------------------------------------------------------------------- diff --git a/test_helpers.go b/test_helpers.go index 32ecffde..5f5c496d 100644 --- a/test_helpers.go +++ b/test_helpers.go @@ -15,5 +15,5 @@ import "net/http" // can be set by making a route that captures the required variables, // starting a server and sending the request to that server. func SetURLVars(r *http.Request, val map[string]string) *http.Request { - return setVars(r, val) + return requestWithVars(r, val) } From faac3f3a7eab3fd370da6dd6f476eed3b665e880 Mon Sep 17 00:00:00 2001 From: Franklin Harding Date: Wed, 23 Oct 2019 22:47:15 -0700 Subject: [PATCH 2/4] Update request context funcs to take concrete types --- mux.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mux.go b/mux.go index ce795870..48c4d1e1 100644 --- a/mux.go +++ b/mux.go @@ -445,13 +445,13 @@ func CurrentRoute(r *http.Request) *Route { return nil } -func requestWithVars(r *http.Request, val interface{}) *http.Request { - ctx := context.WithValue(r.Context(), varsKey, val) +func requestWithVars(r *http.Request, vars map[string]string) *http.Request { + ctx := context.WithValue(r.Context(), varsKey, vars) return r.WithContext(ctx) } -func requestWithRoute(r *http.Request, val interface{}) *http.Request { - ctx := context.WithValue(r.Context(), routeKey, val) +func requestWithRoute(r *http.Request, route *Route) *http.Request { + ctx := context.WithValue(r.Context(), routeKey, route) return r.WithContext(ctx) } From 04da6ef88555f2cc69e86d198fdb10581b79baed Mon Sep 17 00:00:00 2001 From: Franklin Harding Date: Wed, 23 Oct 2019 22:55:37 -0700 Subject: [PATCH 3/4] Move TestNativeContextMiddleware to mux_test.go --- context_test.go | 30 ------------------------------ mux_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 24 insertions(+), 30 deletions(-) delete mode 100644 context_test.go diff --git a/context_test.go b/context_test.go deleted file mode 100644 index d8a56b42..00000000 --- a/context_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package mux - -import ( - "context" - "net/http" - "testing" - "time" -) - -func TestNativeContextMiddleware(t *testing.T) { - withTimeout := func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), time.Minute) - defer cancel() - h.ServeHTTP(w, r.WithContext(ctx)) - }) - } - - r := NewRouter() - r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - vars := Vars(r) - if vars["foo"] != "bar" { - t.Fatal("Expected foo var to be set") - } - }))) - - rec := NewRecorder() - req := newRequest("GET", "/path/bar") - r.ServeHTTP(rec, req) -} diff --git a/mux_test.go b/mux_test.go index 9a740bb8..1c906689 100644 --- a/mux_test.go +++ b/mux_test.go @@ -7,6 +7,7 @@ package mux import ( "bufio" "bytes" + "context" "errors" "fmt" "io/ioutil" @@ -16,6 +17,7 @@ import ( "reflect" "strings" "testing" + "time" ) func (r *Route) GoString() string { @@ -2804,6 +2806,28 @@ func TestSubrouterNotFound(t *testing.T) { } } +func TestContextMiddleware(t *testing.T) { + withTimeout := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), time.Minute) + defer cancel() + h.ServeHTTP(w, r.WithContext(ctx)) + }) + } + + r := NewRouter() + r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + vars := Vars(r) + if vars["foo"] != "bar" { + t.Fatal("Expected foo var to be set") + } + }))) + + rec := NewRecorder() + req := newRequest("GET", "/path/bar") + r.ServeHTTP(rec, req) +} + // mapToPairs converts a string map to a slice of string pairs func mapToPairs(m map[string]string) []string { var i int From d6708c0f9e15c9b336662a2a8e1c25e48eac7c50 Mon Sep 17 00:00:00 2001 From: Franklin Harding Date: Wed, 23 Oct 2019 23:13:43 -0700 Subject: [PATCH 4/4] Clarify KeepContext Go 1.7+ comment Mux doesn't build on Go < 1.7 so the comment doesn't really need to clarify anymore. --- mux.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mux.go b/mux.go index 48c4d1e1..c9ba6470 100644 --- a/mux.go +++ b/mux.go @@ -59,8 +59,7 @@ type Router struct { // If true, do not clear the request context after handling the request. // - // Deprecated: No effect when go1.7+ is used, since the context is stored - // on the request itself. + // Deprecated: No effect, since the context is stored on the request itself. KeepContext bool // Slice of middlewares to be called after a match is found