Skip to content

Commit

Permalink
🐛 Fix expiration time in cache middleware (gofiber#1881)
Browse files Browse the repository at this point in the history
* 🐛 Fix: Expiration time in cache middleware

* Custom expiration time using ExpirationGenerator is also functional
now instead of default Expiration only

* 🚨 Improve Test_CustomExpiration

* - stabilization of the tests
- speed up the cache tests
- fix race conditions in client and client tests

Co-authored-by: wernerr <rene@gofiber.io>
  • Loading branch information
2 people authored and trim21 committed Aug 15, 2022
1 parent b829bc4 commit 871fcc1
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 17 deletions.
15 changes: 9 additions & 6 deletions client.go
Expand Up @@ -3,6 +3,7 @@ package fiber
import (
"bytes"
"crypto/tls"
"encoding/json"
"encoding/xml"
"fmt"
"io"
Expand All @@ -16,8 +17,6 @@ import (
"sync"
"time"

"encoding/json"

"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
Expand Down Expand Up @@ -60,6 +59,7 @@ var defaultClient Client
//
// It is safe calling Client methods from concurrently running goroutines.
type Client struct {
mutex sync.RWMutex
// UserAgent is used in User-Agent request header.
UserAgent string

Expand Down Expand Up @@ -133,10 +133,15 @@ func (c *Client) createAgent(method, url string) *Agent {
a.req.Header.SetMethod(method)
a.req.SetRequestURI(url)

c.mutex.RLock()
a.Name = c.UserAgent
a.NoDefaultUserAgentHeader = c.NoDefaultUserAgentHeader
a.jsonDecoder = c.JSONDecoder
a.jsonEncoder = c.JSONEncoder
if a.jsonDecoder == nil {
a.jsonDecoder = json.Unmarshal
}
c.mutex.RUnlock()

if err := a.Parse(); err != nil {
a.errs = append(a.errs, err)
Expand Down Expand Up @@ -810,10 +815,6 @@ func (a *Agent) String() (int, string, []error) {
// Struct returns the status code, bytes body and errors of url.
// And bytes body will be unmarshalled to given v.
func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) {
if a.jsonDecoder == nil {
a.jsonDecoder = json.Unmarshal
}

if code, body, errs = a.Bytes(); len(errs) > 0 {
return
}
Expand Down Expand Up @@ -886,6 +887,8 @@ func AcquireClient() *Client {
func ReleaseClient(c *Client) {
c.UserAgent = ""
c.NoDefaultUserAgentHeader = false
c.JSONEncoder = nil
c.JSONDecoder = nil

clientPool.Put(c)
}
Expand Down
12 changes: 6 additions & 6 deletions middleware/cache/cache.go
Expand Up @@ -168,23 +168,23 @@ func New(config ...Config) fiber.Handler {
}

// default cache expiration
expiration := uint64(cfg.Expiration.Seconds())
expiration := cfg.Expiration
// Calculate expiration by response header or other setting
if cfg.ExpirationGenerator != nil {
expiration = uint64(cfg.ExpirationGenerator(c, &cfg).Seconds())
expiration = cfg.ExpirationGenerator(c, &cfg)
}
e.exp = ts + expiration
e.exp = ts + uint64(expiration.Seconds())

// For external Storage we store raw body separated
if cfg.Storage != nil {
manager.setRaw(key+"_body", e.body, cfg.Expiration)
manager.setRaw(key+"_body", e.body, expiration)
// avoid body msgp encoding
e.body = nil
manager.set(key, e, cfg.Expiration)
manager.set(key, e, expiration)
manager.release(e)
} else {
// Store entry in memory
manager.set(key, e, cfg.Expiration)
manager.set(key, e, expiration)
}

c.Set(cfg.CacheHeader, cacheMiss)
Expand Down
63 changes: 58 additions & 5 deletions middleware/cache/cache_test.go
Expand Up @@ -19,6 +19,8 @@ import (
)

func Test_Cache_CacheControl(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New(Config{
Expand Down Expand Up @@ -77,6 +79,8 @@ func Test_Cache_Expired(t *testing.T) {
}

func Test_Cache(t *testing.T) {
t.Parallel()

app := fiber.New()
app.Use(New())

Expand All @@ -102,6 +106,8 @@ func Test_Cache(t *testing.T) {
}

func Test_Cache_WithSeveralRequests(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New(Config{
Expand Down Expand Up @@ -135,6 +141,8 @@ func Test_Cache_WithSeveralRequests(t *testing.T) {
}

func Test_Cache_Invalid_Expiration(t *testing.T) {
t.Parallel()

app := fiber.New()
cache := New(Config{Expiration: 0 * time.Second})
app.Use(cache)
Expand All @@ -161,6 +169,8 @@ func Test_Cache_Invalid_Expiration(t *testing.T) {
}

func Test_Cache_Invalid_Method(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New())
Expand Down Expand Up @@ -199,6 +209,8 @@ func Test_Cache_Invalid_Method(t *testing.T) {
}

func Test_Cache_NothingToCache(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New(Config{Expiration: -(time.Second * 1)}))
Expand All @@ -225,6 +237,8 @@ func Test_Cache_NothingToCache(t *testing.T) {
}

func Test_Cache_CustomNext(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New(Config{
Expand Down Expand Up @@ -263,6 +277,8 @@ func Test_Cache_CustomNext(t *testing.T) {
}

func Test_CustomKey(t *testing.T) {
t.Parallel()

app := fiber.New()
var called bool
app.Use(New(Config{KeyGenerator: func(c *fiber.Ctx) string {
Expand All @@ -281,6 +297,8 @@ func Test_CustomKey(t *testing.T) {
}

func Test_CustomExpiration(t *testing.T) {
t.Parallel()

app := fiber.New()
var called bool
var newCacheTime int
Expand All @@ -291,18 +309,45 @@ func Test_CustomExpiration(t *testing.T) {
}}))

app.Get("/", func(c *fiber.Ctx) error {
c.Response().Header.Add("Cache-Time", "6000")
return c.SendString("hi")
c.Response().Header.Add("Cache-Time", "1")
now := fmt.Sprintf("%d", time.Now().UnixNano())
return c.SendString(now)
})

req := httptest.NewRequest("GET", "/", nil)
_, err := app.Test(req)
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, called)
utils.AssertEqual(t, 6000, newCacheTime)
utils.AssertEqual(t, 1, newCacheTime)

// Sleep until the cache is expired
time.Sleep(1 * time.Second)

cachedResp, err := app.Test(httptest.NewRequest("GET", "/", nil))
utils.AssertEqual(t, nil, err)

body, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
cachedBody, err := ioutil.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err)

if bytes.Equal(body, cachedBody) {
t.Errorf("Cache should have expired: %s, %s", body, cachedBody)
}

// Next response should be cached
cachedRespNextRound, err := app.Test(httptest.NewRequest("GET", "/", nil))
utils.AssertEqual(t, nil, err)
cachedBodyNextRound, err := ioutil.ReadAll(cachedRespNextRound.Body)
utils.AssertEqual(t, nil, err)

if !bytes.Equal(cachedBodyNextRound, cachedBody) {
t.Errorf("Cache should not have expired: %s, %s", cachedBodyNextRound, cachedBody)
}
}

func Test_AdditionalE2EResponseHeaders(t *testing.T) {
t.Parallel()

app := fiber.New()
app.Use(New(Config{
StoreResponseHeaders: true,
Expand All @@ -325,6 +370,8 @@ func Test_AdditionalE2EResponseHeaders(t *testing.T) {
}

func Test_CacheHeader(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New(Config{
Expand Down Expand Up @@ -364,6 +411,8 @@ func Test_CacheHeader(t *testing.T) {
}

func Test_Cache_WithHead(t *testing.T) {
t.Parallel()

app := fiber.New()
app.Use(New())

Expand All @@ -389,6 +438,8 @@ func Test_Cache_WithHead(t *testing.T) {
}

func Test_Cache_WithHeadThenGet(t *testing.T) {
t.Parallel()

app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
Expand Down Expand Up @@ -425,6 +476,8 @@ func Test_Cache_WithHeadThenGet(t *testing.T) {
}

func Test_CustomCacheHeader(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New(Config{
Expand Down

0 comments on commit 871fcc1

Please sign in to comment.