Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 Fix expiration time in cache middleware #1881

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 9 additions & 6 deletions client.go
Expand Up @@ -3,6 +3,7 @@ package fiber
import (
"bytes"
"crypto/tls"
"encoding/json"
"encoding/xml"
"fmt"
"io"
Expand All @@ -16,8 +17,6 @@ import (
"sync"
"time"

"encoding/json"

"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
Expand Down Expand Up @@ -60,6 +59,7 @@ var defaultClient Client
//
// It is safe calling Client methods from concurrently running goroutines.
type Client struct {
mutex sync.RWMutex
// UserAgent is used in User-Agent request header.
UserAgent string

Expand Down Expand Up @@ -133,10 +133,15 @@ func (c *Client) createAgent(method, url string) *Agent {
a.req.Header.SetMethod(method)
a.req.SetRequestURI(url)

c.mutex.RLock()
a.Name = c.UserAgent
a.NoDefaultUserAgentHeader = c.NoDefaultUserAgentHeader
a.jsonDecoder = c.JSONDecoder
a.jsonEncoder = c.JSONEncoder
if a.jsonDecoder == nil {
a.jsonDecoder = json.Unmarshal
}
c.mutex.RUnlock()

if err := a.Parse(); err != nil {
a.errs = append(a.errs, err)
Expand Down Expand Up @@ -810,10 +815,6 @@ func (a *Agent) String() (int, string, []error) {
// Struct returns the status code, bytes body and errors of url.
// And bytes body will be unmarshalled to given v.
func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) {
if a.jsonDecoder == nil {
a.jsonDecoder = json.Unmarshal
}

if code, body, errs = a.Bytes(); len(errs) > 0 {
return
}
Expand Down Expand Up @@ -886,6 +887,8 @@ func AcquireClient() *Client {
func ReleaseClient(c *Client) {
c.UserAgent = ""
c.NoDefaultUserAgentHeader = false
c.JSONEncoder = nil
c.JSONDecoder = nil

clientPool.Put(c)
}
Expand Down
12 changes: 6 additions & 6 deletions middleware/cache/cache.go
Expand Up @@ -168,23 +168,23 @@ func New(config ...Config) fiber.Handler {
}

// default cache expiration
expiration := uint64(cfg.Expiration.Seconds())
expiration := cfg.Expiration
// Calculate expiration by response header or other setting
if cfg.ExpirationGenerator != nil {
expiration = uint64(cfg.ExpirationGenerator(c, &cfg).Seconds())
expiration = cfg.ExpirationGenerator(c, &cfg)
}
e.exp = ts + expiration
e.exp = ts + uint64(expiration.Seconds())

// For external Storage we store raw body separated
if cfg.Storage != nil {
manager.setRaw(key+"_body", e.body, cfg.Expiration)
manager.setRaw(key+"_body", e.body, expiration)
// avoid body msgp encoding
e.body = nil
manager.set(key, e, cfg.Expiration)
manager.set(key, e, expiration)
manager.release(e)
} else {
// Store entry in memory
manager.set(key, e, cfg.Expiration)
manager.set(key, e, expiration)
}

c.Set(cfg.CacheHeader, cacheMiss)
Expand Down
63 changes: 58 additions & 5 deletions middleware/cache/cache_test.go
Expand Up @@ -19,6 +19,8 @@ import (
)

func Test_Cache_CacheControl(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New(Config{
Expand Down Expand Up @@ -77,6 +79,8 @@ func Test_Cache_Expired(t *testing.T) {
}

func Test_Cache(t *testing.T) {
t.Parallel()

app := fiber.New()
app.Use(New())

Expand All @@ -102,6 +106,8 @@ func Test_Cache(t *testing.T) {
}

func Test_Cache_WithSeveralRequests(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New(Config{
Expand Down Expand Up @@ -135,6 +141,8 @@ func Test_Cache_WithSeveralRequests(t *testing.T) {
}

func Test_Cache_Invalid_Expiration(t *testing.T) {
t.Parallel()

app := fiber.New()
cache := New(Config{Expiration: 0 * time.Second})
app.Use(cache)
Expand All @@ -161,6 +169,8 @@ func Test_Cache_Invalid_Expiration(t *testing.T) {
}

func Test_Cache_Invalid_Method(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New())
Expand Down Expand Up @@ -199,6 +209,8 @@ func Test_Cache_Invalid_Method(t *testing.T) {
}

func Test_Cache_NothingToCache(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New(Config{Expiration: -(time.Second * 1)}))
Expand All @@ -225,6 +237,8 @@ func Test_Cache_NothingToCache(t *testing.T) {
}

func Test_Cache_CustomNext(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New(Config{
Expand Down Expand Up @@ -263,6 +277,8 @@ func Test_Cache_CustomNext(t *testing.T) {
}

func Test_CustomKey(t *testing.T) {
t.Parallel()

app := fiber.New()
var called bool
app.Use(New(Config{KeyGenerator: func(c *fiber.Ctx) string {
Expand All @@ -281,6 +297,8 @@ func Test_CustomKey(t *testing.T) {
}

func Test_CustomExpiration(t *testing.T) {
t.Parallel()

app := fiber.New()
var called bool
var newCacheTime int
Expand All @@ -291,18 +309,45 @@ func Test_CustomExpiration(t *testing.T) {
}}))

app.Get("/", func(c *fiber.Ctx) error {
c.Response().Header.Add("Cache-Time", "6000")
return c.SendString("hi")
c.Response().Header.Add("Cache-Time", "1")
now := fmt.Sprintf("%d", time.Now().UnixNano())
return c.SendString(now)
})

req := httptest.NewRequest("GET", "/", nil)
_, err := app.Test(req)
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, called)
utils.AssertEqual(t, 6000, newCacheTime)
utils.AssertEqual(t, 1, newCacheTime)

// Sleep until the cache is expired
time.Sleep(1 * time.Second)

cachedResp, err := app.Test(httptest.NewRequest("GET", "/", nil))
utils.AssertEqual(t, nil, err)

body, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
cachedBody, err := ioutil.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err)

if bytes.Equal(body, cachedBody) {
t.Errorf("Cache should have expired: %s, %s", body, cachedBody)
}

// Next response should be cached
cachedRespNextRound, err := app.Test(httptest.NewRequest("GET", "/", nil))
utils.AssertEqual(t, nil, err)
cachedBodyNextRound, err := ioutil.ReadAll(cachedRespNextRound.Body)
utils.AssertEqual(t, nil, err)

if !bytes.Equal(cachedBodyNextRound, cachedBody) {
t.Errorf("Cache should not have expired: %s, %s", cachedBodyNextRound, cachedBody)
}
}

func Test_AdditionalE2EResponseHeaders(t *testing.T) {
t.Parallel()

app := fiber.New()
app.Use(New(Config{
StoreResponseHeaders: true,
Expand All @@ -325,6 +370,8 @@ func Test_AdditionalE2EResponseHeaders(t *testing.T) {
}

func Test_CacheHeader(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New(Config{
Expand Down Expand Up @@ -364,6 +411,8 @@ func Test_CacheHeader(t *testing.T) {
}

func Test_Cache_WithHead(t *testing.T) {
t.Parallel()

app := fiber.New()
app.Use(New())

Expand All @@ -389,6 +438,8 @@ func Test_Cache_WithHead(t *testing.T) {
}

func Test_Cache_WithHeadThenGet(t *testing.T) {
t.Parallel()

app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
Expand Down Expand Up @@ -425,6 +476,8 @@ func Test_Cache_WithHeadThenGet(t *testing.T) {
}

func Test_CustomCacheHeader(t *testing.T) {
t.Parallel()

app := fiber.New()

app.Use(New(Config{
Expand Down