Skip to content

Commit

Permalink
Remove/cleanup request context helpers (#525)
Browse files Browse the repository at this point in the history
* Remove context helpers in context.go
* Update request context funcs to take concrete types
* Move TestNativeContextMiddleware to mux_test.go
* Clarify KeepContext Go 1.7+ comment

Mux doesn't build on Go < 1.7 so the comment doesn't really need to
clarify anymore.
  • Loading branch information
fharding1 authored and elithrar committed Oct 24, 2019
1 parent ff4e71f commit f395758
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 59 deletions.
18 changes: 0 additions & 18 deletions context.go

This file was deleted.

30 changes: 0 additions & 30 deletions context_test.go

This file was deleted.

22 changes: 12 additions & 10 deletions mux.go
Expand Up @@ -5,6 +5,7 @@
package mux

import (
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -58,8 +59,7 @@ type Router struct {

// If true, do not clear the request context after handling the request.
//
// Deprecated: No effect when go1.7+ is used, since the context is stored
// on the request itself.
// Deprecated: No effect, since the context is stored on the request itself.
KeepContext bool

// Slice of middlewares to be called after a match is found
Expand Down Expand Up @@ -195,8 +195,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var handler http.Handler
if r.Match(req, &match) {
handler = match.Handler
req = setVars(req, match.Vars)
req = setCurrentRoute(req, match.Route)
req = requestWithVars(req, match.Vars)
req = requestWithRoute(req, match.Route)
}

if handler == nil && match.MatchErr == ErrMethodMismatch {
Expand Down Expand Up @@ -426,7 +426,7 @@ const (

// Vars returns the route variables for the current request, if any.
func Vars(r *http.Request) map[string]string {
if rv := contextGet(r, varsKey); rv != nil {
if rv := r.Context().Value(varsKey); rv != nil {
return rv.(map[string]string)
}
return nil
Expand All @@ -438,18 +438,20 @@ func Vars(r *http.Request) map[string]string {
// after the handler returns, unless the KeepContext option is set on the
// Router.
func CurrentRoute(r *http.Request) *Route {
if rv := contextGet(r, routeKey); rv != nil {
if rv := r.Context().Value(routeKey); rv != nil {
return rv.(*Route)
}
return nil
}

func setVars(r *http.Request, val interface{}) *http.Request {
return contextSet(r, varsKey, val)
func requestWithVars(r *http.Request, vars map[string]string) *http.Request {
ctx := context.WithValue(r.Context(), varsKey, vars)
return r.WithContext(ctx)
}

func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
return contextSet(r, routeKey, val)
func requestWithRoute(r *http.Request, route *Route) *http.Request {
ctx := context.WithValue(r.Context(), routeKey, route)
return r.WithContext(ctx)
}

// ----------------------------------------------------------------------------
Expand Down
24 changes: 24 additions & 0 deletions mux_test.go
Expand Up @@ -7,6 +7,7 @@ package mux
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io/ioutil"
Expand All @@ -16,6 +17,7 @@ import (
"reflect"
"strings"
"testing"
"time"
)

func (r *Route) GoString() string {
Expand Down Expand Up @@ -2804,6 +2806,28 @@ func TestSubrouterNotFound(t *testing.T) {
}
}

func TestContextMiddleware(t *testing.T) {
withTimeout := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
defer cancel()
h.ServeHTTP(w, r.WithContext(ctx))
})
}

r := NewRouter()
r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
vars := Vars(r)
if vars["foo"] != "bar" {
t.Fatal("Expected foo var to be set")
}
})))

rec := NewRecorder()
req := newRequest("GET", "/path/bar")
r.ServeHTTP(rec, req)
}

// mapToPairs converts a string map to a slice of string pairs
func mapToPairs(m map[string]string) []string {
var i int
Expand Down
2 changes: 1 addition & 1 deletion test_helpers.go
Expand Up @@ -15,5 +15,5 @@ import "net/http"
// can be set by making a route that captures the required variables,
// starting a server and sending the request to that server.
func SetURLVars(r *http.Request, val map[string]string) *http.Request {
return setVars(r, val)
return requestWithVars(r, val)
}

0 comments on commit f395758

Please sign in to comment.