From 946b6237eb8d0ce3225f502b7fd4208d0b60ce5f Mon Sep 17 00:00:00 2001 From: Franklin Harding Date: Thu, 14 Nov 2019 12:19:09 -0800 Subject: [PATCH] Fix the CORSMethodMiddleware bug with subrouters * Adds a test case for the repro given in issue #534 * Fixes the logic in CORSMethodMiddleware to handle matching routes better --- middleware.go | 25 ++++++++++--------------- middleware_test.go | 20 ++++++++++++++++++++ 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/middleware.go b/middleware.go index cf2b26dc..cb51c565 100644 --- a/middleware.go +++ b/middleware.go @@ -58,22 +58,17 @@ func CORSMethodMiddleware(r *Router) MiddlewareFunc { func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) { var allMethods []string - err := r.Walk(func(route *Route, _ *Router, _ []*Route) error { - for _, m := range route.matchers { - if _, ok := m.(*routeRegexp); ok { - if m.Match(req, &RouteMatch{}) { - methods, err := route.GetMethods() - if err != nil { - return err - } - - allMethods = append(allMethods, methods...) - } - break + for _, route := range r.routes { + var match RouteMatch + if route.Match(req, &match) || match.MatchErr == ErrMethodMismatch { + methods, err := route.GetMethods() + if err != nil { + return nil, err } + + allMethods = append(allMethods, methods...) } - return nil - }) + } - return allMethods, err + return allMethods, nil } diff --git a/middleware_test.go b/middleware_test.go index 27647afe..e9f0ef55 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -478,6 +478,26 @@ func TestCORSMethodMiddleware(t *testing.T) { } } +func TestCORSMethodMiddlewareSubrouter(t *testing.T) { + router := NewRouter().StrictSlash(true) + + subrouter := router.PathPrefix("/test").Subrouter() + subrouter.HandleFunc("/hello", stringHandler("a")).Methods(http.MethodGet, http.MethodOptions, http.MethodPost) + subrouter.HandleFunc("/hello/{name}", stringHandler("b")).Methods(http.MethodGet, http.MethodOptions) + + subrouter.Use(CORSMethodMiddleware(subrouter)) + + rw := NewRecorder() + req := newRequest("GET", "/test/hello/asdf") + router.ServeHTTP(rw, req) + + actualMethods := rw.Header().Get("Access-Control-Allow-Methods") + expectedMethods := "GET,OPTIONS" + if actualMethods != expectedMethods { + t.Fatalf("expected methods %q but got: %q", expectedMethods, actualMethods) + } +} + func TestMiddlewareOnMultiSubrouter(t *testing.T) { first := "first" second := "second"