Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove/cleanup request context helpers #525

Merged
merged 4 commits into from Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 0 additions & 18 deletions context.go

This file was deleted.

30 changes: 0 additions & 30 deletions context_test.go

This file was deleted.

22 changes: 12 additions & 10 deletions mux.go
Expand Up @@ -5,6 +5,7 @@
package mux

import (
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -58,8 +59,7 @@ type Router struct {

// If true, do not clear the request context after handling the request.
//
// Deprecated: No effect when go1.7+ is used, since the context is stored
// on the request itself.
// Deprecated: No effect, since the context is stored on the request itself.
KeepContext bool

// Slice of middlewares to be called after a match is found
Expand Down Expand Up @@ -195,8 +195,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var handler http.Handler
if r.Match(req, &match) {
handler = match.Handler
req = setVars(req, match.Vars)
req = setCurrentRoute(req, match.Route)
req = requestWithVars(req, match.Vars)
req = requestWithRoute(req, match.Route)
}

if handler == nil && match.MatchErr == ErrMethodMismatch {
Expand Down Expand Up @@ -426,7 +426,7 @@ const (

// Vars returns the route variables for the current request, if any.
func Vars(r *http.Request) map[string]string {
if rv := contextGet(r, varsKey); rv != nil {
if rv := r.Context().Value(varsKey); rv != nil {
return rv.(map[string]string)
}
return nil
Expand All @@ -438,18 +438,20 @@ func Vars(r *http.Request) map[string]string {
// after the handler returns, unless the KeepContext option is set on the
// Router.
func CurrentRoute(r *http.Request) *Route {
if rv := contextGet(r, routeKey); rv != nil {
if rv := r.Context().Value(routeKey); rv != nil {
return rv.(*Route)
}
return nil
}

func setVars(r *http.Request, val interface{}) *http.Request {
return contextSet(r, varsKey, val)
func requestWithVars(r *http.Request, vars map[string]string) *http.Request {
ctx := context.WithValue(r.Context(), varsKey, vars)
return r.WithContext(ctx)
}

func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
return contextSet(r, routeKey, val)
func requestWithRoute(r *http.Request, route *Route) *http.Request {
ctx := context.WithValue(r.Context(), routeKey, route)
return r.WithContext(ctx)
}

// ----------------------------------------------------------------------------
Expand Down
24 changes: 24 additions & 0 deletions mux_test.go
Expand Up @@ -7,6 +7,7 @@ package mux
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io/ioutil"
Expand All @@ -16,6 +17,7 @@ import (
"reflect"
"strings"
"testing"
"time"
)

func (r *Route) GoString() string {
Expand Down Expand Up @@ -2804,6 +2806,28 @@ func TestSubrouterNotFound(t *testing.T) {
}
}

func TestContextMiddleware(t *testing.T) {
withTimeout := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
defer cancel()
h.ServeHTTP(w, r.WithContext(ctx))
})
}

r := NewRouter()
r.Handle("/path/{foo}", withTimeout(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
vars := Vars(r)
if vars["foo"] != "bar" {
t.Fatal("Expected foo var to be set")
}
})))

rec := NewRecorder()
req := newRequest("GET", "/path/bar")
r.ServeHTTP(rec, req)
}

// mapToPairs converts a string map to a slice of string pairs
func mapToPairs(m map[string]string) []string {
var i int
Expand Down
2 changes: 1 addition & 1 deletion test_helpers.go
Expand Up @@ -15,5 +15,5 @@ import "net/http"
// can be set by making a route that captures the required variables,
// starting a server and sending the request to that server.
func SetURLVars(r *http.Request, val map[string]string) *http.Request {
return setVars(r, val)
return requestWithVars(r, val)
}