Skip to content

Commit

Permalink
🚀 [Feature]: middleware/csrf custom extractor (#2052)
Browse files Browse the repository at this point in the history
* feat(middleware/csrf): allow custom Extractor

* test: update Test_CSRF_From_Custom

* docs: add comma

* docs: update KeyLookup docs
  • Loading branch information
sixcolors committed Aug 28, 2022
1 parent 506f0b2 commit 6272d75
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 23 deletions.
14 changes: 13 additions & 1 deletion middleware/csrf/README.md
Expand Up @@ -49,9 +49,12 @@ app.Use(csrf.New(csrf.Config{
CookieSameSite: "Lax",
Expiration: 1 * time.Hour,
KeyGenerator: utils.UUID,
Extractor: func(c *fiber.Ctx) (string, error) { ... },
}))
```

Note: KeyLookup will be ignored if Extractor is explicitly set.

### Custom Storage/Database

You can use any storage from our [storage](https://github.com/gofiber/storage/) package.
Expand All @@ -74,14 +77,16 @@ type Config struct {
Next func(c *fiber.Ctx) bool

// KeyLookup is a string in the form of "<source>:<key>" that is used
// to extract token from the request.
// to create an Extractor that extracts the token from the request.
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "param:<name>"
// - "form:<name>"
// - "cookie:<name>"
//
// Ignored if an Extractor is explicitly set.
//
// Optional. Default: "header:X-CSRF-Token"
KeyLookup string

Expand Down Expand Up @@ -133,6 +138,13 @@ type Config struct {
//
// Optional. Default: utils.UUID
KeyGenerator func() string

// Extractor returns the csrf token
//
// If set this will be used in place of an Extractor based on KeyLookup.
//
// Optional. Default will create an Extractor based on KeyLookup.
Extractor func(c *fiber.Ctx) (string, error)
}
```

Expand Down
40 changes: 24 additions & 16 deletions middleware/csrf/config.go
Expand Up @@ -18,14 +18,16 @@ type Config struct {
Next func(c *fiber.Ctx) bool

// KeyLookup is a string in the form of "<source>:<key>" that is used
// to extract token from the request.
// to create an Extractor that extracts the token from the request.
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "param:<name>"
// - "form:<name>"
// - "cookie:<name>"
//
// Ignored if an Extractor is explicitly set.
//
// Optional. Default: "header:X-CSRF-Token"
KeyLookup string

Expand Down Expand Up @@ -92,8 +94,12 @@ type Config struct {
// Optional. Default: DefaultErrorHandler
ErrorHandler fiber.ErrorHandler

// extractor returns the csrf token from the request based on KeyLookup
extractor func(c *fiber.Ctx) (string, error)
// Extractor returns the csrf token
//
// If set this will be used in place of an Extractor based on KeyLookup.
//
// Optional. Default will create an Extractor based on KeyLookup.
Extractor func(c *fiber.Ctx) (string, error)
}

// ConfigDefault is the default config
Expand All @@ -104,7 +110,7 @@ var ConfigDefault = Config{
Expiration: 1 * time.Hour,
KeyGenerator: utils.UUID,
ErrorHandler: defaultErrorHandler,
extractor: csrfFromHeader("X-Csrf-Token"),
Extractor: CsrfFromHeader("X-Csrf-Token"),
}

// default ErrorHandler that process return error from fiber.Handler
Expand Down Expand Up @@ -174,18 +180,20 @@ func configDefault(config ...Config) Config {
panic("[CSRF] KeyLookup must in the form of <source>:<key>")
}

// By default we extract from a header
cfg.extractor = csrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1]))

switch selectors[0] {
case "form":
cfg.extractor = csrfFromForm(selectors[1])
case "query":
cfg.extractor = csrfFromQuery(selectors[1])
case "param":
cfg.extractor = csrfFromParam(selectors[1])
case "cookie":
cfg.extractor = csrfFromCookie(selectors[1])
if cfg.Extractor == nil {
// By default we extract from a header
cfg.Extractor = CsrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1]))

switch selectors[0] {
case "form":
cfg.Extractor = CsrfFromForm(selectors[1])
case "query":
cfg.Extractor = CsrfFromQuery(selectors[1])
case "param":
cfg.Extractor = CsrfFromParam(selectors[1])
case "cookie":
cfg.Extractor = CsrfFromCookie(selectors[1])
}
}

return cfg
Expand Down
2 changes: 1 addition & 1 deletion middleware/csrf/csrf.go
Expand Up @@ -39,7 +39,7 @@ func New(config ...Config) fiber.Handler {
// Assume that anything not defined as 'safe' by RFC7231 needs protection

// Extract token from client request i.e. header, query, param, form or cookie
token, err = cfg.extractor(c)
token, err = cfg.Extractor(c)
if err != nil {
return cfg.ErrorHandler(c, err)
}
Expand Down
44 changes: 44 additions & 0 deletions middleware/csrf/csrf_test.go
Expand Up @@ -236,6 +236,50 @@ func Test_CSRF_From_Cookie(t *testing.T) {
utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
}

func Test_CSRF_From_Custom(t *testing.T) {
app := fiber.New()

extractor := func(c *fiber.Ctx) (string, error) {
body := string(c.Body())
// Generate the correct extractor to get the token from the correct location
selectors := strings.Split(body, "=")

if len(selectors) != 2 || selectors[1] == "" {
return "", errMissingParam
}
return selectors[1], nil
}

app.Use(New(Config{Extractor: extractor}))

app.Post("/", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

h := app.Handler()
ctx := &fasthttp.RequestCtx{}

// Invalid CSRF token
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())

// Generate CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]

ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
ctx.Request.SetBodyString("_csrf=" + token)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
}

func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
app := fiber.New()

Expand Down
10 changes: 5 additions & 5 deletions middleware/csrf/extractors.go
Expand Up @@ -15,7 +15,7 @@ var (
)

// csrfFromParam returns a function that extracts token from the url param string.
func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
func CsrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Params(param)
if token == "" {
Expand All @@ -26,7 +26,7 @@ func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
}

// csrfFromForm returns a function that extracts a token from a multipart-form.
func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
func CsrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.FormValue(param)
if token == "" {
Expand All @@ -37,7 +37,7 @@ func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
}

// csrfFromCookie returns a function that extracts token from the cookie header.
func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
func CsrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Cookies(param)
if token == "" {
Expand All @@ -48,7 +48,7 @@ func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
}

// csrfFromHeader returns a function that extracts token from the request header.
func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
func CsrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Get(param)
if token == "" {
Expand All @@ -59,7 +59,7 @@ func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
}

// csrfFromQuery returns a function that extracts token from the query string.
func csrfFromQuery(param string) func(c *fiber.Ctx) (string, error) {
func CsrfFromQuery(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Query(param)
if token == "" {
Expand Down

0 comments on commit 6272d75

Please sign in to comment.