diff --git a/middleware/cache/README.md b/middleware/cache/README.md index 7074100ffb..235b96c69b 100644 --- a/middleware/cache/README.md +++ b/middleware/cache/README.md @@ -50,6 +50,25 @@ app.Use(cache.New(cache.Config{ })) ``` +### Custom Cache Key Or Expiration + +```go +app.Use(New(Config{ + ExpirationGenerator: func(c *fiber.Ctx, cfg *Config) time.Duration { + newCacheTime, _ := strconv.Atoi(c.GetRespHeader("Cache-Time", "600")) + return time.Second * time.Duration(newCacheTime) + }, + KeyGenerator: func(c *fiber.Ctx) string { + return c.Path() + } +})) + +app.Get("/", func(c *fiber.Ctx) error { + c.Response().Header.Add("Cache-Time", "6000") + return c.SendString("hi") +}) +``` + ### Config ```go @@ -84,6 +103,11 @@ type Config struct { // } KeyGenerator func(*fiber.Ctx) string + // allows you to generate custom Expiration Key By Key, default is Expiration (Optional) + // + // Default: nil + ExpirationGenerator func(*fiber.Ctx, *Config) time.Duration + // Store is used to store the state of the middleware // // Default: an in memory store for this process only @@ -103,6 +127,7 @@ var ConfigDefault = Config{ KeyGenerator: func(c *fiber.Ctx) string { return utils.CopyString(c.Path()) }, + ExpirationGenerator : nil, Storage: nil, } ``` diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index fae24cd7e6..84652b28c0 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -41,9 +41,8 @@ func New(config ...Config) fiber.Handler { var ( // Cache settings - mux = &sync.RWMutex{} - timestamp = uint64(time.Now().Unix()) - expiration = uint64(cfg.Expiration.Seconds()) + mux = &sync.RWMutex{} + timestamp = uint64(time.Now().Unix()) ) // Create manager to simplify storage operations ( see manager.go ) manager := newManager(cfg.Storage) @@ -125,6 +124,13 @@ func New(config ...Config) fiber.Handler { e.status = c.Response().StatusCode() e.ctype = utils.CopyBytes(c.Response().Header.ContentType()) e.cencoding = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding)) + + // default cache expiration + expiration := uint64(cfg.Expiration.Seconds()) + // Calculate expiration by response header or other setting + if cfg.ExpirationGenerator != nil { + expiration = uint64(cfg.ExpirationGenerator(c, &cfg).Seconds()) + } e.exp = ts + expiration // For external Storage we store raw body separated diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index 2cdf995a7a..e8c272b990 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -280,6 +280,28 @@ func Test_CustomKey(t *testing.T) { utils.AssertEqual(t, true, called) } +func Test_CustomExpiration(t *testing.T) { + app := fiber.New() + var called bool + var newCacheTime int + app.Use(New(Config{ExpirationGenerator: func(c *fiber.Ctx, cfg *Config) time.Duration { + called = true + newCacheTime, _ = strconv.Atoi(c.GetRespHeader("Cache-Time", "600")) + return time.Second * time.Duration(newCacheTime) + }})) + + app.Get("/", func(c *fiber.Ctx) error { + c.Response().Header.Add("Cache-Time", "6000") + return c.SendString("hi") + }) + + req := httptest.NewRequest("GET", "/", nil) + _, err := app.Test(req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, called) + utils.AssertEqual(t, 6000, newCacheTime) +} + func Test_CacheHeader(t *testing.T) { app := fiber.New() diff --git a/middleware/cache/config.go b/middleware/cache/config.go index 4d230d14da..cd25c4ade5 100644 --- a/middleware/cache/config.go +++ b/middleware/cache/config.go @@ -39,6 +39,11 @@ type Config struct { // } KeyGenerator func(*fiber.Ctx) string + // allows you to generate custom Expiration Key By Key, default is Expiration (Optional) + // + // Default: nil + ExpirationGenerator func(*fiber.Ctx, *Config) time.Duration + // Store is used to store the state of the middleware // // Default: an in memory store for this process only @@ -60,7 +65,8 @@ var ConfigDefault = Config{ KeyGenerator: func(c *fiber.Ctx) string { return utils.CopyString(c.Path()) }, - Storage: nil, + ExpirationGenerator: nil, + Storage: nil, } // Helper function to set default values