From 1cddc56f134f8a80289fb8bc9b682c05f7012458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9otime=20L=C3=A9v=C3=AAque?= Date: Tue, 8 Mar 2022 10:18:04 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Cache=20middleware:=20Store=20e2e?= =?UTF-8?q?=20headers.=20(#1807)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ Cache middleware: Store e2e headers. As defined in RFC2616 - section-13.5.1, shared caches MUST store end-to-end headers from backend response and MUST be transmitted in any response formed from a cache entry. This commit ensures a stronger consistency between responses served from the handlers & from the cache middleware. * ✨ Cache middleware: Add flag for e2e headers. Set flag to prevent e2e headers caching to be the default behavior of the cache middleware. This would otherwise change quite a lot the experience for cache middleware current users. * ✨ Cache middleware: Add Benchmark for additionalHeaders feature. * ✨ Cache middleware: Rename E2Eheaders into StoreResponseHeaders. E2E is an acronym commonly associated with test. While in the present case it refers to end-to-end HTTP headers (by opposition to hop-by-hop), this still remains confusing. This commits renames it to a more generic name. * ✨ Cache middleware: Update README * ✨ Cache middleware: Move map instanciation. This will prevent an extra memory allocation for users not interested in this feature. * ✨ Cache middleware: Prevent memory allocation when StoreResponseHeaders is disabled. * ✨ Cache middleware: Store e2e headers. #1807 - use set instead of add for the headers - copy value from the headers -> prevent problems with mutable values Co-authored-by: wernerr --- middleware/cache/README.md | 7 +++ middleware/cache/cache.go | 33 ++++++++++ middleware/cache/cache_test.go | 50 +++++++++++++++ middleware/cache/config.go | 10 ++- middleware/cache/manager.go | 2 + middleware/cache/manager_msgp.go | 105 +++++++++++++++++++++++++++++-- 6 files changed, 200 insertions(+), 7 deletions(-) diff --git a/middleware/cache/README.md b/middleware/cache/README.md index 3f3607b54d..9f76418f1a 100644 --- a/middleware/cache/README.md +++ b/middleware/cache/README.md @@ -10,6 +10,7 @@ Cache middleware for [Fiber](https://github.com/gofiber/fiber) designed to inter - [Examples](#examples) - [Default Config](#default-config) - [Custom Config](#custom-config) + - [Custom Cache Key Or Expiration](#custom-cache-key-or-expiration) - [Config](#config) - [Default Config](#default-config-1) @@ -112,6 +113,11 @@ type Config struct { // // Default: an in memory store for this process only Storage fiber.Storage + + // allows you to store additional headers generated by next middlewares & handler + // + // Default: false + StoreResponseHeaders bool } ``` @@ -128,6 +134,7 @@ var ConfigDefault = Config{ return utils.CopyString(c.Path()) }, ExpirationGenerator : nil, + StoreResponseHeaders: false, Storage: nil, } ``` diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index aefe5c12a8..dee6d18a0e 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -27,6 +27,19 @@ const ( cacheMiss = "miss" ) +var ignoreHeaders = map[string]interface{}{ + "Connection": nil, + "Keep-Alive": nil, + "Proxy-Authenticate": nil, + "Proxy-Authorization": nil, + "TE": nil, + "Trailers": nil, + "Transfer-Encoding": nil, + "Upgrade": nil, + "Content-Type": nil, // already stored explicitely by the cache manager + "Content-Encoding": nil, // already stored explicitely by the cache manager +} + // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config @@ -96,6 +109,11 @@ func New(config ...Config) fiber.Handler { if len(e.cencoding) > 0 { c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding) } + if e.headers != nil { + for k, v := range e.headers { + c.Response().Header.SetBytesV(k, v) + } + } // Set Cache-Control header if enabled if cfg.CacheControl { maxAge := strconv.FormatUint(e.exp-ts, 10) @@ -134,6 +152,21 @@ func New(config ...Config) fiber.Handler { e.ctype = utils.CopyBytes(c.Response().Header.ContentType()) e.cencoding = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding)) + // Store all response headers + // (more: https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1) + if cfg.StoreResponseHeaders { + e.headers = make(map[string][]byte) + c.Response().Header.VisitAll( + func(key []byte, value []byte) { + // create real copy + keyS := string(key) + if _, ok := ignoreHeaders[keyS]; !ok { + e.headers[keyS] = utils.CopyBytes(value) + } + }, + ) + } + // default cache expiration expiration := uint64(cfg.Expiration.Seconds()) // Calculate expiration by response header or other setting diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index 654e91fdd6..29ca9730f7 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -302,6 +302,28 @@ func Test_CustomExpiration(t *testing.T) { utils.AssertEqual(t, 6000, newCacheTime) } +func Test_AdditionalE2EResponseHeaders(t *testing.T) { + app := fiber.New() + app.Use(New(Config{ + StoreResponseHeaders: true, + })) + + app.Get("/", func(c *fiber.Ctx) error { + c.Response().Header.Add("X-Foobar", "foobar") + return c.SendString("hi") + }) + + req := httptest.NewRequest("GET", "/", nil) + resp, err := app.Test(req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar")) + + req = httptest.NewRequest("GET", "/", nil) + resp, err = app.Test(req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar")) +} + func Test_CacheHeader(t *testing.T) { app := fiber.New() @@ -475,3 +497,31 @@ func Benchmark_Cache_Storage(b *testing.B) { utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) utils.AssertEqual(b, true, len(fctx.Response.Body()) > 30000) } + +func Benchmark_Cache_AdditionalHeaders(b *testing.B) { + app := fiber.New() + app.Use(New(Config{ + StoreResponseHeaders: true, + })) + + app.Get("/demo", func(c *fiber.Ctx) error { + c.Response().Header.Add("X-Foobar", "foobar") + return c.SendStatus(418) + }) + + h := app.Handler() + + fctx := &fasthttp.RequestCtx{} + fctx.Request.Header.SetMethod("GET") + fctx.Request.SetRequestURI("/demo") + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + h(fctx) + } + + utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) + utils.AssertEqual(b, []byte("foobar"), fctx.Response.Header.Peek("X-Foobar")) +} diff --git a/middleware/cache/config.go b/middleware/cache/config.go index cd25c4ade5..636eb0141e 100644 --- a/middleware/cache/config.go +++ b/middleware/cache/config.go @@ -54,6 +54,11 @@ type Config struct { // Deprecated, use KeyGenerator instead Key func(*fiber.Ctx) string + + // allows you to store additional headers generated by next middlewares & handler + // + // Default: false + StoreResponseHeaders bool } // ConfigDefault is the default config @@ -65,8 +70,9 @@ var ConfigDefault = Config{ KeyGenerator: func(c *fiber.Ctx) string { return utils.CopyString(c.Path()) }, - ExpirationGenerator: nil, - Storage: nil, + ExpirationGenerator: nil, + StoreResponseHeaders: false, + Storage: nil, } // Helper function to set default values diff --git a/middleware/cache/manager.go b/middleware/cache/manager.go index 3a41708943..5ba9260c09 100644 --- a/middleware/cache/manager.go +++ b/middleware/cache/manager.go @@ -18,6 +18,7 @@ type item struct { cencoding []byte status int exp uint64 + headers map[string][]byte } //msgp:ignore manager @@ -61,6 +62,7 @@ func (m *manager) release(e *item) { e.ctype = nil e.status = 0 e.exp = 0 + e.headers = nil m.pool.Put(e) } diff --git a/middleware/cache/manager_msgp.go b/middleware/cache/manager_msgp.go index b664caf8b9..9d373dae05 100644 --- a/middleware/cache/manager_msgp.go +++ b/middleware/cache/manager_msgp.go @@ -54,6 +54,36 @@ func (z *item) DecodeMsg(dc *msgp.Reader) (err error) { err = msgp.WrapError(err, "exp") return } + case "headers": + var zb0002 uint32 + zb0002, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err, "headers") + return + } + if z.headers == nil { + z.headers = make(map[string][]byte, zb0002) + } else if len(z.headers) > 0 { + for key := range z.headers { + delete(z.headers, key) + } + } + for zb0002 > 0 { + zb0002-- + var za0001 string + var za0002 []byte + za0001, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "headers") + return + } + za0002, err = dc.ReadBytes(za0002) + if err != nil { + err = msgp.WrapError(err, "headers", za0001) + return + } + z.headers[za0001] = za0002 + } default: err = dc.Skip() if err != nil { @@ -67,9 +97,9 @@ func (z *item) DecodeMsg(dc *msgp.Reader) (err error) { // EncodeMsg implements msgp.Encodable func (z *item) EncodeMsg(en *msgp.Writer) (err error) { - // map header, size 5 + // map header, size 6 // write "body" - err = en.Append(0x85, 0xa4, 0x62, 0x6f, 0x64, 0x79) + err = en.Append(0x86, 0xa4, 0x62, 0x6f, 0x64, 0x79) if err != nil { return } @@ -118,15 +148,37 @@ func (z *item) EncodeMsg(en *msgp.Writer) (err error) { err = msgp.WrapError(err, "exp") return } + // write "headers" + err = en.Append(0xaa, 0x65, 0x32, 0x65, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73) + if err != nil { + return + } + err = en.WriteMapHeader(uint32(len(z.headers))) + if err != nil { + err = msgp.WrapError(err, "headers") + return + } + for za0001, za0002 := range z.headers { + err = en.WriteString(za0001) + if err != nil { + err = msgp.WrapError(err, "headers") + return + } + err = en.WriteBytes(za0002) + if err != nil { + err = msgp.WrapError(err, "headers", za0001) + return + } + } return } // MarshalMsg implements msgp.Marshaler func (z *item) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) - // map header, size 5 + // map header, size 6 // string "body" - o = append(o, 0x85, 0xa4, 0x62, 0x6f, 0x64, 0x79) + o = append(o, 0x86, 0xa4, 0x62, 0x6f, 0x64, 0x79) o = msgp.AppendBytes(o, z.body) // string "ctype" o = append(o, 0xa5, 0x63, 0x74, 0x79, 0x70, 0x65) @@ -140,6 +192,13 @@ func (z *item) MarshalMsg(b []byte) (o []byte, err error) { // string "exp" o = append(o, 0xa3, 0x65, 0x78, 0x70) o = msgp.AppendUint64(o, z.exp) + // string "headers" + o = append(o, 0xaa, 0x65, 0x32, 0x65, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73) + o = msgp.AppendMapHeader(o, uint32(len(z.headers))) + for za0001, za0002 := range z.headers { + o = msgp.AppendString(o, za0001) + o = msgp.AppendBytes(o, za0002) + } return } @@ -191,6 +250,36 @@ func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) { err = msgp.WrapError(err, "exp") return } + case "headers": + var zb0002 uint32 + zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "headers") + return + } + if z.headers == nil { + z.headers = make(map[string][]byte, zb0002) + } else if len(z.headers) > 0 { + for key := range z.headers { + delete(z.headers, key) + } + } + for zb0002 > 0 { + var za0001 string + var za0002 []byte + zb0002-- + za0001, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "headers") + return + } + za0002, bts, err = msgp.ReadBytesBytes(bts, za0002) + if err != nil { + err = msgp.WrapError(err, "headers", za0001) + return + } + z.headers[za0001] = za0002 + } default: bts, err = msgp.Skip(bts) if err != nil { @@ -205,6 +294,12 @@ func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) { // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message func (z *item) Msgsize() (s int) { - s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 7 + msgp.IntSize + 4 + msgp.Uint64Size + s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 7 + msgp.IntSize + 4 + msgp.Uint64Size + 11 + msgp.MapHeaderSize + if z.headers != nil { + for za0001, za0002 := range z.headers { + _ = za0002 + s += msgp.StringPrefixSize + len(za0001) + msgp.BytesPrefixSize + len(za0002) + } + } return }