From aaa22381998b2c3d0313bb112c6039375af6c31a Mon Sep 17 00:00:00 2001 From: thylong Date: Sat, 5 Mar 2022 15:59:34 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Cache=20middleware:=20Store=20e2e?= =?UTF-8?q?=20headers.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As defined in RFC2616 - section-13.5.1, shared caches MUST store end-to-end headers from backend response and MUST be transmitted in any response formed from a cache entry. This commit ensures a stronger consistency between responses served from the handlers & from the cache middleware. --- middleware/cache/cache.go | 28 +++++++++ middleware/cache/cache_test.go | 20 ++++++ middleware/cache/manager.go | 12 ++-- middleware/cache/manager_msgp.go | 105 +++++++++++++++++++++++++++++-- 4 files changed, 155 insertions(+), 10 deletions(-) diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index aefe5c12a84..d8f7130365f 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -27,6 +27,19 @@ const ( cacheMiss = "miss" ) +var ignoreHeaders = map[string]interface{}{ + "Connection": nil, + "Keep-Alive": nil, + "Proxy-Authenticate": nil, + "Proxy-Authorization": nil, + "TE": nil, + "Trailers": nil, + "Transfer-Encoding": nil, + "Upgrade": nil, + "Content-Type": nil, // already stored explicitely by the cache manager + "Content-Encoding": nil, // already stored explicitely by the cache manager +} + // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config @@ -96,6 +109,9 @@ func New(config ...Config) fiber.Handler { if len(e.cencoding) > 0 { c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding) } + for k, v := range e.e2eHeaders { + c.Response().Header.AddBytesV(k, v) + } // Set Cache-Control header if enabled if cfg.CacheControl { maxAge := strconv.FormatUint(e.exp-ts, 10) @@ -133,6 +149,18 @@ func New(config ...Config) fiber.Handler { e.status = c.Response().StatusCode() e.ctype = utils.CopyBytes(c.Response().Header.ContentType()) e.cencoding = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding)) + e.e2eHeaders = make(map[string][]byte) + + // Store all end-to-end headers + // (more: https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1) + c.Response().Header.VisitAll( + func(key []byte, value []byte) { + keyS := string(key) + if _, isHopbyHop := ignoreHeaders[keyS]; !isHopbyHop { + e.e2eHeaders[keyS] = value + } + }, + ) // default cache expiration expiration := uint64(cfg.Expiration.Seconds()) diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index 654e91fdd6a..2d6027d2e33 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -302,6 +302,26 @@ func Test_CustomExpiration(t *testing.T) { utils.AssertEqual(t, 6000, newCacheTime) } +func Test_AdditionalE2EResponseHeaders(t *testing.T) { + app := fiber.New() + app.Use(New()) + + app.Get("/", func(c *fiber.Ctx) error { + c.Response().Header.Add("X-Foobar", "foobar") + return c.SendString("hi") + }) + + req := httptest.NewRequest("GET", "/", nil) + resp, err := app.Test(req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar")) + + req = httptest.NewRequest("GET", "/", nil) + resp, err = app.Test(req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar")) +} + func Test_CacheHeader(t *testing.T) { app := fiber.New() diff --git a/middleware/cache/manager.go b/middleware/cache/manager.go index 3a417089434..13cff7a7f98 100644 --- a/middleware/cache/manager.go +++ b/middleware/cache/manager.go @@ -13,11 +13,12 @@ import ( // don't forget to replace the msgp import path to: // "github.com/gofiber/fiber/v2/internal/msgp" type item struct { - body []byte - ctype []byte - cencoding []byte - status int - exp uint64 + body []byte + ctype []byte + cencoding []byte + status int + exp uint64 + e2eHeaders map[string][]byte } //msgp:ignore manager @@ -61,6 +62,7 @@ func (m *manager) release(e *item) { e.ctype = nil e.status = 0 e.exp = 0 + e.e2eHeaders = nil m.pool.Put(e) } diff --git a/middleware/cache/manager_msgp.go b/middleware/cache/manager_msgp.go index b664caf8b9a..85668799795 100644 --- a/middleware/cache/manager_msgp.go +++ b/middleware/cache/manager_msgp.go @@ -54,6 +54,36 @@ func (z *item) DecodeMsg(dc *msgp.Reader) (err error) { err = msgp.WrapError(err, "exp") return } + case "e2eHeaders": + var zb0002 uint32 + zb0002, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err, "e2eHeaders") + return + } + if z.e2eHeaders == nil { + z.e2eHeaders = make(map[string][]byte, zb0002) + } else if len(z.e2eHeaders) > 0 { + for key := range z.e2eHeaders { + delete(z.e2eHeaders, key) + } + } + for zb0002 > 0 { + zb0002-- + var za0001 string + var za0002 []byte + za0001, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "e2eHeaders") + return + } + za0002, err = dc.ReadBytes(za0002) + if err != nil { + err = msgp.WrapError(err, "e2eHeaders", za0001) + return + } + z.e2eHeaders[za0001] = za0002 + } default: err = dc.Skip() if err != nil { @@ -67,9 +97,9 @@ func (z *item) DecodeMsg(dc *msgp.Reader) (err error) { // EncodeMsg implements msgp.Encodable func (z *item) EncodeMsg(en *msgp.Writer) (err error) { - // map header, size 5 + // map header, size 6 // write "body" - err = en.Append(0x85, 0xa4, 0x62, 0x6f, 0x64, 0x79) + err = en.Append(0x86, 0xa4, 0x62, 0x6f, 0x64, 0x79) if err != nil { return } @@ -118,15 +148,37 @@ func (z *item) EncodeMsg(en *msgp.Writer) (err error) { err = msgp.WrapError(err, "exp") return } + // write "e2eHeaders" + err = en.Append(0xaa, 0x65, 0x32, 0x65, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73) + if err != nil { + return + } + err = en.WriteMapHeader(uint32(len(z.e2eHeaders))) + if err != nil { + err = msgp.WrapError(err, "e2eHeaders") + return + } + for za0001, za0002 := range z.e2eHeaders { + err = en.WriteString(za0001) + if err != nil { + err = msgp.WrapError(err, "e2eHeaders") + return + } + err = en.WriteBytes(za0002) + if err != nil { + err = msgp.WrapError(err, "e2eHeaders", za0001) + return + } + } return } // MarshalMsg implements msgp.Marshaler func (z *item) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) - // map header, size 5 + // map header, size 6 // string "body" - o = append(o, 0x85, 0xa4, 0x62, 0x6f, 0x64, 0x79) + o = append(o, 0x86, 0xa4, 0x62, 0x6f, 0x64, 0x79) o = msgp.AppendBytes(o, z.body) // string "ctype" o = append(o, 0xa5, 0x63, 0x74, 0x79, 0x70, 0x65) @@ -140,6 +192,13 @@ func (z *item) MarshalMsg(b []byte) (o []byte, err error) { // string "exp" o = append(o, 0xa3, 0x65, 0x78, 0x70) o = msgp.AppendUint64(o, z.exp) + // string "e2eHeaders" + o = append(o, 0xaa, 0x65, 0x32, 0x65, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73) + o = msgp.AppendMapHeader(o, uint32(len(z.e2eHeaders))) + for za0001, za0002 := range z.e2eHeaders { + o = msgp.AppendString(o, za0001) + o = msgp.AppendBytes(o, za0002) + } return } @@ -191,6 +250,36 @@ func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) { err = msgp.WrapError(err, "exp") return } + case "e2eHeaders": + var zb0002 uint32 + zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "e2eHeaders") + return + } + if z.e2eHeaders == nil { + z.e2eHeaders = make(map[string][]byte, zb0002) + } else if len(z.e2eHeaders) > 0 { + for key := range z.e2eHeaders { + delete(z.e2eHeaders, key) + } + } + for zb0002 > 0 { + var za0001 string + var za0002 []byte + zb0002-- + za0001, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "e2eHeaders") + return + } + za0002, bts, err = msgp.ReadBytesBytes(bts, za0002) + if err != nil { + err = msgp.WrapError(err, "e2eHeaders", za0001) + return + } + z.e2eHeaders[za0001] = za0002 + } default: bts, err = msgp.Skip(bts) if err != nil { @@ -205,6 +294,12 @@ func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) { // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message func (z *item) Msgsize() (s int) { - s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 7 + msgp.IntSize + 4 + msgp.Uint64Size + s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 7 + msgp.IntSize + 4 + msgp.Uint64Size + 11 + msgp.MapHeaderSize + if z.e2eHeaders != nil { + for za0001, za0002 := range z.e2eHeaders { + _ = za0002 + s += msgp.StringPrefixSize + len(za0001) + msgp.BytesPrefixSize + len(za0002) + } + } return }