From 6272d759ebe1c21755713d529f842665a7993eb9 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sun, 28 Aug 2022 13:57:47 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20[Feature]:=20middleware/csrf=20c?= =?UTF-8?q?ustom=20extractor=20(#2052)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(middleware/csrf): allow custom Extractor * test: update Test_CSRF_From_Custom * docs: add comma * docs: update KeyLookup docs --- middleware/csrf/README.md | 14 ++++++++++- middleware/csrf/config.go | 40 ++++++++++++++++++------------- middleware/csrf/csrf.go | 2 +- middleware/csrf/csrf_test.go | 44 +++++++++++++++++++++++++++++++++++ middleware/csrf/extractors.go | 10 ++++---- 5 files changed, 87 insertions(+), 23 deletions(-) diff --git a/middleware/csrf/README.md b/middleware/csrf/README.md index 8ab8171cc8..ac57428875 100644 --- a/middleware/csrf/README.md +++ b/middleware/csrf/README.md @@ -49,9 +49,12 @@ app.Use(csrf.New(csrf.Config{ CookieSameSite: "Lax", Expiration: 1 * time.Hour, KeyGenerator: utils.UUID, + Extractor: func(c *fiber.Ctx) (string, error) { ... }, })) ``` +Note: KeyLookup will be ignored if Extractor is explicitly set. + ### Custom Storage/Database You can use any storage from our [storage](https://github.com/gofiber/storage/) package. @@ -74,7 +77,7 @@ type Config struct { Next func(c *fiber.Ctx) bool // KeyLookup is a string in the form of ":" that is used - // to extract token from the request. + // to create an Extractor that extracts the token from the request. // Possible values: // - "header:" // - "query:" @@ -82,6 +85,8 @@ type Config struct { // - "form:" // - "cookie:" // + // Ignored if an Extractor is explicitly set. + // // Optional. Default: "header:X-CSRF-Token" KeyLookup string @@ -133,6 +138,13 @@ type Config struct { // // Optional. Default: utils.UUID KeyGenerator func() string + + // Extractor returns the csrf token + // + // If set this will be used in place of an Extractor based on KeyLookup. + // + // Optional. Default will create an Extractor based on KeyLookup. + Extractor func(c *fiber.Ctx) (string, error) } ``` diff --git a/middleware/csrf/config.go b/middleware/csrf/config.go index 4d9cd7b8a1..5e45c7b83d 100644 --- a/middleware/csrf/config.go +++ b/middleware/csrf/config.go @@ -18,7 +18,7 @@ type Config struct { Next func(c *fiber.Ctx) bool // KeyLookup is a string in the form of ":" that is used - // to extract token from the request. + // to create an Extractor that extracts the token from the request. // Possible values: // - "header:" // - "query:" @@ -26,6 +26,8 @@ type Config struct { // - "form:" // - "cookie:" // + // Ignored if an Extractor is explicitly set. + // // Optional. Default: "header:X-CSRF-Token" KeyLookup string @@ -92,8 +94,12 @@ type Config struct { // Optional. Default: DefaultErrorHandler ErrorHandler fiber.ErrorHandler - // extractor returns the csrf token from the request based on KeyLookup - extractor func(c *fiber.Ctx) (string, error) + // Extractor returns the csrf token + // + // If set this will be used in place of an Extractor based on KeyLookup. + // + // Optional. Default will create an Extractor based on KeyLookup. + Extractor func(c *fiber.Ctx) (string, error) } // ConfigDefault is the default config @@ -104,7 +110,7 @@ var ConfigDefault = Config{ Expiration: 1 * time.Hour, KeyGenerator: utils.UUID, ErrorHandler: defaultErrorHandler, - extractor: csrfFromHeader("X-Csrf-Token"), + Extractor: CsrfFromHeader("X-Csrf-Token"), } // default ErrorHandler that process return error from fiber.Handler @@ -174,18 +180,20 @@ func configDefault(config ...Config) Config { panic("[CSRF] KeyLookup must in the form of :") } - // By default we extract from a header - cfg.extractor = csrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1])) - - switch selectors[0] { - case "form": - cfg.extractor = csrfFromForm(selectors[1]) - case "query": - cfg.extractor = csrfFromQuery(selectors[1]) - case "param": - cfg.extractor = csrfFromParam(selectors[1]) - case "cookie": - cfg.extractor = csrfFromCookie(selectors[1]) + if cfg.Extractor == nil { + // By default we extract from a header + cfg.Extractor = CsrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1])) + + switch selectors[0] { + case "form": + cfg.Extractor = CsrfFromForm(selectors[1]) + case "query": + cfg.Extractor = CsrfFromQuery(selectors[1]) + case "param": + cfg.Extractor = CsrfFromParam(selectors[1]) + case "cookie": + cfg.Extractor = CsrfFromCookie(selectors[1]) + } } return cfg diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index 1de2aec5d6..e7ad4f2a72 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -39,7 +39,7 @@ func New(config ...Config) fiber.Handler { // Assume that anything not defined as 'safe' by RFC7231 needs protection // Extract token from client request i.e. header, query, param, form or cookie - token, err = cfg.extractor(c) + token, err = cfg.Extractor(c) if err != nil { return cfg.ErrorHandler(c, err) } diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 082b3519a1..90650bf327 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -236,6 +236,50 @@ func Test_CSRF_From_Cookie(t *testing.T) { utils.AssertEqual(t, "OK", string(ctx.Response.Body())) } +func Test_CSRF_From_Custom(t *testing.T) { + app := fiber.New() + + extractor := func(c *fiber.Ctx) (string, error) { + body := string(c.Body()) + // Generate the correct extractor to get the token from the correct location + selectors := strings.Split(body, "=") + + if len(selectors) != 2 || selectors[1] == "" { + return "", errMissingParam + } + return selectors[1], nil + } + + app.Use(New(Config{Extractor: extractor})) + + app.Post("/", func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + // Invalid CSRF token + ctx.Request.Header.SetMethod("POST") + ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain) + h(ctx) + utils.AssertEqual(t, 403, ctx.Response.StatusCode()) + + // Generate CSRF token + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod("GET") + h(ctx) + token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + token = strings.Split(strings.Split(token, ";")[0], "=")[1] + + ctx.Request.Header.SetMethod("POST") + ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain) + ctx.Request.SetBodyString("_csrf=" + token) + h(ctx) + utils.AssertEqual(t, 200, ctx.Response.StatusCode()) +} + func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) { app := fiber.New() diff --git a/middleware/csrf/extractors.go b/middleware/csrf/extractors.go index 2b577b1233..ab669bb47a 100644 --- a/middleware/csrf/extractors.go +++ b/middleware/csrf/extractors.go @@ -15,7 +15,7 @@ var ( ) // csrfFromParam returns a function that extracts token from the url param string. -func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) { +func CsrfFromParam(param string) func(c *fiber.Ctx) (string, error) { return func(c *fiber.Ctx) (string, error) { token := c.Params(param) if token == "" { @@ -26,7 +26,7 @@ func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) { } // csrfFromForm returns a function that extracts a token from a multipart-form. -func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) { +func CsrfFromForm(param string) func(c *fiber.Ctx) (string, error) { return func(c *fiber.Ctx) (string, error) { token := c.FormValue(param) if token == "" { @@ -37,7 +37,7 @@ func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) { } // csrfFromCookie returns a function that extracts token from the cookie header. -func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) { +func CsrfFromCookie(param string) func(c *fiber.Ctx) (string, error) { return func(c *fiber.Ctx) (string, error) { token := c.Cookies(param) if token == "" { @@ -48,7 +48,7 @@ func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) { } // csrfFromHeader returns a function that extracts token from the request header. -func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) { +func CsrfFromHeader(param string) func(c *fiber.Ctx) (string, error) { return func(c *fiber.Ctx) (string, error) { token := c.Get(param) if token == "" { @@ -59,7 +59,7 @@ func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) { } // csrfFromQuery returns a function that extracts token from the query string. -func csrfFromQuery(param string) func(c *fiber.Ctx) (string, error) { +func CsrfFromQuery(param string) func(c *fiber.Ctx) (string, error) { return func(c *fiber.Ctx) (string, error) { token := c.Query(param) if token == "" {