diff --git a/middleware/csrf/config.go b/middleware/csrf/config.go index 5e45c7b83d..94c5287b65 100644 --- a/middleware/csrf/config.go +++ b/middleware/csrf/config.go @@ -102,15 +102,17 @@ type Config struct { Extractor func(c *fiber.Ctx) (string, error) } +const HeaderName = "X-Csrf-Token" + // ConfigDefault is the default config var ConfigDefault = Config{ - KeyLookup: "header:X-Csrf-Token", + KeyLookup: "header:" + HeaderName, CookieName: "csrf_", CookieSameSite: "Lax", Expiration: 1 * time.Hour, KeyGenerator: utils.UUID, ErrorHandler: defaultErrorHandler, - Extractor: CsrfFromHeader("X-Csrf-Token"), + Extractor: CsrfFromHeader(HeaderName), } // default ErrorHandler that process return error from fiber.Handler diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 90650bf327..dcdabc4628 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -40,7 +40,7 @@ func Test_CSRF(t *testing.T) { ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod("POST") - ctx.Request.Header.Set("X-CSRF-Token", "johndoe") + ctx.Request.Header.Set(HeaderName, "johndoe") h(ctx) utils.AssertEqual(t, 403, ctx.Response.StatusCode()) @@ -55,7 +55,7 @@ func Test_CSRF(t *testing.T) { ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod("POST") - ctx.Request.Header.Set("X-CSRF-Token", token) + ctx.Request.Header.Set(HeaderName, token) h(ctx) utils.AssertEqual(t, 200, ctx.Response.StatusCode()) } @@ -305,7 +305,7 @@ func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) { ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod("POST") - ctx.Request.Header.Set("X-CSRF-Token", "johndoe") + ctx.Request.Header.Set(HeaderName, "johndoe") h(ctx) utils.AssertEqual(t, 419, ctx.Response.StatusCode()) utils.AssertEqual(t, "invalid CSRF token", string(ctx.Response.Body())) @@ -340,3 +340,111 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) { utils.AssertEqual(t, 419, ctx.Response.StatusCode()) utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body())) } + +// TODO: use this test case and make the unsafe header value bug from https://github.com/gofiber/fiber/issues/2045 reproducible and permanently fixed/tested by this testcase +//func Test_CSRF_UnsafeHeaderValue(t *testing.T) { +// app := fiber.New() +// +// app.Use(New()) +// app.Get("/", func(c *fiber.Ctx) error { +// return c.SendStatus(fiber.StatusOK) +// }) +// app.Get("/test", func(c *fiber.Ctx) error { +// return c.SendStatus(fiber.StatusOK) +// }) +// app.Post("/", func(c *fiber.Ctx) error { +// return c.SendStatus(fiber.StatusOK) +// }) +// +// resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil)) +// utils.AssertEqual(t, nil, err) +// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) +// +// var token string +// for _, c := range resp.Cookies() { +// if c.Name != ConfigDefault.CookieName { +// continue +// } +// token = c.Value +// break +// } +// +// fmt.Println("token", token) +// +// getReq := httptest.NewRequest(http.MethodGet, "/", nil) +// getReq.Header.Set(HeaderName, token) +// resp, err = app.Test(getReq) +// +// getReq = httptest.NewRequest(http.MethodGet, "/test", nil) +// getReq.Header.Set("X-Requested-With", "XMLHttpRequest") +// getReq.Header.Set(fiber.HeaderCacheControl, "no") +// getReq.Header.Set(HeaderName, token) +// +// resp, err = app.Test(getReq) +// +// getReq.Header.Set(fiber.HeaderAccept, "*/*") +// getReq.Header.Del(HeaderName) +// resp, err = app.Test(getReq) +// +// postReq := httptest.NewRequest(http.MethodPost, "/", nil) +// postReq.Header.Set("X-Requested-With", "XMLHttpRequest") +// postReq.Header.Set(HeaderName, token) +// resp, err = app.Test(postReq) +//} + +// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4 +func Benchmark_Middleware_CSRF_Check(b *testing.B) { + app := fiber.New() + + app.Use(New()) + app.Get("/", func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusTeapot) + }) + + fctx := &fasthttp.RequestCtx{} + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + // Generate CSRF token + ctx.Request.Header.SetMethod("GET") + h(ctx) + token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + token = strings.Split(strings.Split(token, ";")[0], "=")[1] + + ctx.Request.Header.SetMethod("POST") + ctx.Request.Header.Set(HeaderName, token) + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + h(fctx) + } + + utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) +} + +// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4 +func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) { + app := fiber.New() + + app.Use(New()) + app.Get("/", func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusTeapot) + }) + + fctx := &fasthttp.RequestCtx{} + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + // Generate CSRF token + ctx.Request.Header.SetMethod("GET") + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + h(fctx) + } + + utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) +} diff --git a/middleware/csrf/manager.go b/middleware/csrf/manager.go index dd70c2ed1c..13f0ccb657 100644 --- a/middleware/csrf/manager.go +++ b/middleware/csrf/manager.go @@ -6,6 +6,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/internal/memory" + "github.com/gofiber/fiber/v2/utils" ) // go:generate msgp @@ -88,7 +89,8 @@ func (m *manager) set(key string, it *item, exp time.Duration) { _ = m.storage.Set(key, raw, exp) } } else { - m.memory.Set(key, it, exp) + // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here + m.memory.Set(utils.CopyString(key), it, exp) } } @@ -97,7 +99,8 @@ func (m *manager) setRaw(key string, raw []byte, exp time.Duration) { if m.storage != nil { _ = m.storage.Set(key, raw, exp) } else { - m.memory.Set(key, raw, exp) + // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here + m.memory.Set(utils.CopyString(key), raw, exp) } }