diff --git a/mux_test.go b/mux_test.go index 34c00dd2..e252d391 100644 --- a/mux_test.go +++ b/mux_test.go @@ -9,6 +9,7 @@ import ( "bytes" "errors" "fmt" + "io/ioutil" "net/http" "net/url" "reflect" @@ -2729,6 +2730,63 @@ func TestMethodNotAllowed(t *testing.T) { } } +type customMethodNotAllowedHandler struct { + msg string +} + +func (h customMethodNotAllowedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusMethodNotAllowed) + fmt.Fprint(w, h.msg) +} + +func TestSubrouterCustomMethodNotAllowed(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } + + router := NewRouter() + router.HandleFunc("/test", handler).Methods(http.MethodGet) + router.MethodNotAllowedHandler = customMethodNotAllowedHandler{msg: "custom router handler"} + + subrouter := router.PathPrefix("/sub").Subrouter() + subrouter.HandleFunc("/test", handler).Methods(http.MethodGet) + subrouter.MethodNotAllowedHandler = customMethodNotAllowedHandler{msg: "custom sub router handler"} + + testCases := map[string]struct { + path string + expMsg string + }{ + "router method not allowed": { + path: "/test", + expMsg: "custom router handler", + }, + "subrouter method not allowed": { + path: "/sub/test", + expMsg: "custom sub router handler", + }, + } + + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + w := NewRecorder() + req := newRequest(http.MethodPut, tc.path) + + router.ServeHTTP(w, req) + + if w.Code != http.StatusMethodNotAllowed { + tt.Errorf("Expected status code 405 (got %d)", w.Code) + } + + b, err := ioutil.ReadAll(w.Body) + if err != nil { + tt.Errorf("failed to read body: %v", err) + } + + if string(b) != tc.expMsg { + tt.Errorf("expected msg %q, got %q", tc.expMsg, string(b)) + } + }) + } +} + func TestSubrouterNotFound(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } router := NewRouter() diff --git a/route.go b/route.go index 8479c68c..7343d78a 100644 --- a/route.go +++ b/route.go @@ -74,7 +74,7 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool { return false } - if match.MatchErr == ErrMethodMismatch { + if match.MatchErr == ErrMethodMismatch && r.handler != nil { // We found a route which matches request method, clear MatchErr match.MatchErr = nil // Then override the mis-matched handler