From 871fcc1692bfb059f192519f143a9dbf81052617 Mon Sep 17 00:00:00 2001 From: Pinank Solanki Date: Sun, 1 May 2022 15:13:46 +0530 Subject: [PATCH] :bug: Fix expiration time in cache middleware (#1881) * :bug: Fix: Expiration time in cache middleware * Custom expiration time using ExpirationGenerator is also functional now instead of default Expiration only * :rotating_light: Improve Test_CustomExpiration * - stabilization of the tests - speed up the cache tests - fix race conditions in client and client tests Co-authored-by: wernerr --- client.go | 15 ++++---- middleware/cache/cache.go | 12 +++---- middleware/cache/cache_test.go | 63 +++++++++++++++++++++++++++++++--- 3 files changed, 73 insertions(+), 17 deletions(-) diff --git a/client.go b/client.go index f0d8db7c42..e5dd104367 100644 --- a/client.go +++ b/client.go @@ -3,6 +3,7 @@ package fiber import ( "bytes" "crypto/tls" + "encoding/json" "encoding/xml" "fmt" "io" @@ -16,8 +17,6 @@ import ( "sync" "time" - "encoding/json" - "github.com/gofiber/fiber/v2/utils" "github.com/valyala/fasthttp" ) @@ -60,6 +59,7 @@ var defaultClient Client // // It is safe calling Client methods from concurrently running goroutines. type Client struct { + mutex sync.RWMutex // UserAgent is used in User-Agent request header. UserAgent string @@ -133,10 +133,15 @@ func (c *Client) createAgent(method, url string) *Agent { a.req.Header.SetMethod(method) a.req.SetRequestURI(url) + c.mutex.RLock() a.Name = c.UserAgent a.NoDefaultUserAgentHeader = c.NoDefaultUserAgentHeader a.jsonDecoder = c.JSONDecoder a.jsonEncoder = c.JSONEncoder + if a.jsonDecoder == nil { + a.jsonDecoder = json.Unmarshal + } + c.mutex.RUnlock() if err := a.Parse(); err != nil { a.errs = append(a.errs, err) @@ -810,10 +815,6 @@ func (a *Agent) String() (int, string, []error) { // Struct returns the status code, bytes body and errors of url. // And bytes body will be unmarshalled to given v. func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) { - if a.jsonDecoder == nil { - a.jsonDecoder = json.Unmarshal - } - if code, body, errs = a.Bytes(); len(errs) > 0 { return } @@ -886,6 +887,8 @@ func AcquireClient() *Client { func ReleaseClient(c *Client) { c.UserAgent = "" c.NoDefaultUserAgentHeader = false + c.JSONEncoder = nil + c.JSONDecoder = nil clientPool.Put(c) } diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index dee6d18a0e..367baef694 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -168,23 +168,23 @@ func New(config ...Config) fiber.Handler { } // default cache expiration - expiration := uint64(cfg.Expiration.Seconds()) + expiration := cfg.Expiration // Calculate expiration by response header or other setting if cfg.ExpirationGenerator != nil { - expiration = uint64(cfg.ExpirationGenerator(c, &cfg).Seconds()) + expiration = cfg.ExpirationGenerator(c, &cfg) } - e.exp = ts + expiration + e.exp = ts + uint64(expiration.Seconds()) // For external Storage we store raw body separated if cfg.Storage != nil { - manager.setRaw(key+"_body", e.body, cfg.Expiration) + manager.setRaw(key+"_body", e.body, expiration) // avoid body msgp encoding e.body = nil - manager.set(key, e, cfg.Expiration) + manager.set(key, e, expiration) manager.release(e) } else { // Store entry in memory - manager.set(key, e, cfg.Expiration) + manager.set(key, e, expiration) } c.Set(cfg.CacheHeader, cacheMiss) diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index 216d9cea27..f29fd3650c 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -19,6 +19,8 @@ import ( ) func Test_Cache_CacheControl(t *testing.T) { + t.Parallel() + app := fiber.New() app.Use(New(Config{ @@ -77,6 +79,8 @@ func Test_Cache_Expired(t *testing.T) { } func Test_Cache(t *testing.T) { + t.Parallel() + app := fiber.New() app.Use(New()) @@ -102,6 +106,8 @@ func Test_Cache(t *testing.T) { } func Test_Cache_WithSeveralRequests(t *testing.T) { + t.Parallel() + app := fiber.New() app.Use(New(Config{ @@ -135,6 +141,8 @@ func Test_Cache_WithSeveralRequests(t *testing.T) { } func Test_Cache_Invalid_Expiration(t *testing.T) { + t.Parallel() + app := fiber.New() cache := New(Config{Expiration: 0 * time.Second}) app.Use(cache) @@ -161,6 +169,8 @@ func Test_Cache_Invalid_Expiration(t *testing.T) { } func Test_Cache_Invalid_Method(t *testing.T) { + t.Parallel() + app := fiber.New() app.Use(New()) @@ -199,6 +209,8 @@ func Test_Cache_Invalid_Method(t *testing.T) { } func Test_Cache_NothingToCache(t *testing.T) { + t.Parallel() + app := fiber.New() app.Use(New(Config{Expiration: -(time.Second * 1)})) @@ -225,6 +237,8 @@ func Test_Cache_NothingToCache(t *testing.T) { } func Test_Cache_CustomNext(t *testing.T) { + t.Parallel() + app := fiber.New() app.Use(New(Config{ @@ -263,6 +277,8 @@ func Test_Cache_CustomNext(t *testing.T) { } func Test_CustomKey(t *testing.T) { + t.Parallel() + app := fiber.New() var called bool app.Use(New(Config{KeyGenerator: func(c *fiber.Ctx) string { @@ -281,6 +297,8 @@ func Test_CustomKey(t *testing.T) { } func Test_CustomExpiration(t *testing.T) { + t.Parallel() + app := fiber.New() var called bool var newCacheTime int @@ -291,18 +309,45 @@ func Test_CustomExpiration(t *testing.T) { }})) app.Get("/", func(c *fiber.Ctx) error { - c.Response().Header.Add("Cache-Time", "6000") - return c.SendString("hi") + c.Response().Header.Add("Cache-Time", "1") + now := fmt.Sprintf("%d", time.Now().UnixNano()) + return c.SendString(now) }) - req := httptest.NewRequest("GET", "/", nil) - _, err := app.Test(req) + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, true, called) - utils.AssertEqual(t, 6000, newCacheTime) + utils.AssertEqual(t, 1, newCacheTime) + + // Sleep until the cache is expired + time.Sleep(1 * time.Second) + + cachedResp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + + body, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + cachedBody, err := ioutil.ReadAll(cachedResp.Body) + utils.AssertEqual(t, nil, err) + + if bytes.Equal(body, cachedBody) { + t.Errorf("Cache should have expired: %s, %s", body, cachedBody) + } + + // Next response should be cached + cachedRespNextRound, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + cachedBodyNextRound, err := ioutil.ReadAll(cachedRespNextRound.Body) + utils.AssertEqual(t, nil, err) + + if !bytes.Equal(cachedBodyNextRound, cachedBody) { + t.Errorf("Cache should not have expired: %s, %s", cachedBodyNextRound, cachedBody) + } } func Test_AdditionalE2EResponseHeaders(t *testing.T) { + t.Parallel() + app := fiber.New() app.Use(New(Config{ StoreResponseHeaders: true, @@ -325,6 +370,8 @@ func Test_AdditionalE2EResponseHeaders(t *testing.T) { } func Test_CacheHeader(t *testing.T) { + t.Parallel() + app := fiber.New() app.Use(New(Config{ @@ -364,6 +411,8 @@ func Test_CacheHeader(t *testing.T) { } func Test_Cache_WithHead(t *testing.T) { + t.Parallel() + app := fiber.New() app.Use(New()) @@ -389,6 +438,8 @@ func Test_Cache_WithHead(t *testing.T) { } func Test_Cache_WithHeadThenGet(t *testing.T) { + t.Parallel() + app := fiber.New() app.Use(New()) app.Get("/", func(c *fiber.Ctx) error { @@ -425,6 +476,8 @@ func Test_Cache_WithHeadThenGet(t *testing.T) { } func Test_CustomCacheHeader(t *testing.T) { + t.Parallel() + app := fiber.New() app.Use(New(Config{