From 989c92264c1757d84836120439d5987a869c2aa7 Mon Sep 17 00:00:00 2001 From: Franklin Harding Date: Wed, 23 Oct 2019 22:55:37 -0700 Subject: [PATCH] Move TestNativeContextMiddleware to mux_test.go --- context_test.go | 30 ------------------------------ mux_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 24 insertions(+), 30 deletions(-) delete mode 100644 context_test.go diff --git a/context_test.go b/context_test.go deleted file mode 100644 index d8a56b42..00000000 --- a/context_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package mux - -import ( - "context" - "net/http" - "testing" - "time" -) - -func TestNativeContextMiddleware(t *testing.T) { - withTimeout := func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), time.Minute) - defer cancel() - h.ServeHTTP(w, r.WithContext(ctx)) - }) - } - - r := NewRouter() - r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - vars := Vars(r) - if vars["foo"] != "bar" { - t.Fatal("Expected foo var to be set") - } - }))) - - rec := NewRecorder() - req := newRequest("GET", "/path/bar") - r.ServeHTTP(rec, req) -} diff --git a/mux_test.go b/mux_test.go index 9a740bb8..41b3d372 100644 --- a/mux_test.go +++ b/mux_test.go @@ -7,6 +7,7 @@ package mux import ( "bufio" "bytes" + "context" "errors" "fmt" "io/ioutil" @@ -16,6 +17,7 @@ import ( "reflect" "strings" "testing" + "time" ) func (r *Route) GoString() string { @@ -2804,6 +2806,28 @@ func TestSubrouterNotFound(t *testing.T) { } } +func TestNativeContextMiddleware(t *testing.T) { + withTimeout := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), time.Minute) + defer cancel() + h.ServeHTTP(w, r.WithContext(ctx)) + }) + } + + r := NewRouter() + r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + vars := Vars(r) + if vars["foo"] != "bar" { + t.Fatal("Expected foo var to be set") + } + }))) + + rec := NewRecorder() + req := newRequest("GET", "/path/bar") + r.ServeHTTP(rec, req) +} + // mapToPairs converts a string map to a slice of string pairs func mapToPairs(m map[string]string) []string { var i int