From 5cc3349f521fa46d971a62334a80d86fd77932ce Mon Sep 17 00:00:00 2001 From: divyapola5 <87338962+divyapola5@users.noreply.github.com> Date: Wed, 15 Sep 2021 16:32:14 -0500 Subject: [PATCH] Enforce minimum cache size for transit backend (#12418) (#12551) * Enforce Minimum cache size for transit backend * enfore minimum cache size and log a warning during backend construction * Update documentation for transit backend cache configuration * Added changelog * Addressed review feedback and added unit test * Modify code in pathCacheConfigWrite to make use of the updated cache size * Updated code to refresh cache size on transit backend without restart * Update code to acquire read and write locks appropriately --- builtin/logical/transit/backend.go | 50 ++++++++++++++++++- builtin/logical/transit/path_cache_config.go | 12 ++--- .../logical/transit/path_cache_config_test.go | 43 +++++++++++++++- builtin/logical/transit/path_config.go | 2 +- builtin/logical/transit/path_datakey.go | 2 +- builtin/logical/transit/path_decrypt.go | 2 +- builtin/logical/transit/path_encrypt.go | 4 +- builtin/logical/transit/path_export.go | 2 +- builtin/logical/transit/path_hmac.go | 4 +- builtin/logical/transit/path_hmac_test.go | 4 +- builtin/logical/transit/path_keys.go | 4 +- builtin/logical/transit/path_rewrap.go | 2 +- builtin/logical/transit/path_rotate.go | 2 +- builtin/logical/transit/path_sign_verify.go | 4 +- .../logical/transit/path_sign_verify_test.go | 6 +-- builtin/logical/transit/path_trim.go | 2 +- builtin/logical/transit/path_trim_test.go | 2 +- changelog/12418.txt | 4 ++ sdk/helper/keysutil/lock_manager.go | 18 +++++++ website/content/api-docs/secret/transit.mdx | 2 +- 20 files changed, 141 insertions(+), 30 deletions(-) create mode 100644 changelog/12418.txt diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index f438ac1b0290b..0e2f8264153b8 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -3,13 +3,18 @@ package transit import ( "context" "fmt" + "io" "strings" + "sync" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" ) +// Minimum cache size for transit backend +const minCacheSize = 10 + func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { b, err := Backend(ctx, conf) if err != nil { @@ -68,6 +73,11 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error) if err != nil { return nil, fmt.Errorf("Error retrieving cache size from storage: %w", err) } + + if cacheSize != 0 && cacheSize < minCacheSize { + b.Logger().Warn("size %d is less than minimum %d. Cache size is set to %d", cacheSize, minCacheSize, minCacheSize) + cacheSize = minCacheSize + } } var err error @@ -82,6 +92,9 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error) type backend struct { *framework.Backend lm *keysutil.LockManager + // Lock to make changes to any of the backend's cache configuration. + configMutex sync.RWMutex + cacheSizeChanged bool } func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error) { @@ -100,7 +113,37 @@ func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error return size, nil } -func (b *backend) invalidate(_ context.Context, key string) { +// Update cache size and get policy +func (b *backend) GetPolicy(ctx context.Context, polReq keysutil.PolicyRequest, rand io.Reader) (retP *keysutil.Policy, retUpserted bool, retErr error) { + // Acquire read lock to read cacheSizeChanged + b.configMutex.RLock() + if b.lm.GetUseCache() && b.cacheSizeChanged { + var err error + currentCacheSize := b.lm.GetCacheSize() + storedCacheSize, err := GetCacheSizeFromStorage(ctx, polReq.Storage) + if err != nil { + return nil, false, err + } + if currentCacheSize != storedCacheSize { + err = b.lm.InitCache(storedCacheSize) + if err != nil { + return nil, false, err + } + } + // Release the read lock and acquire the write lock + b.configMutex.RUnlock() + b.configMutex.Lock() + defer b.configMutex.Unlock() + b.cacheSizeChanged = false + } + p, _, err := b.lm.GetPolicy(ctx, polReq, rand) + if err != nil { + return p, false, err + } + return p, true, nil +} + +func (b *backend) invalidate(ctx context.Context, key string) { if b.Logger().IsDebug() { b.Logger().Debug("invalidating key", "key", key) } @@ -108,5 +151,10 @@ func (b *backend) invalidate(_ context.Context, key string) { case strings.HasPrefix(key, "policy/"): name := strings.TrimPrefix(key, "policy/") b.lm.InvalidatePolicy(name) + case strings.HasPrefix(key, "cache-config/"): + // Acquire the lock to set the flag to indicate that cache size needs to be refreshed from storage + b.configMutex.Lock() + defer b.configMutex.Unlock() + b.cacheSizeChanged = true } } diff --git a/builtin/logical/transit/path_cache_config.go b/builtin/logical/transit/path_cache_config.go index 6239555b37668..6610548ce1351 100644 --- a/builtin/logical/transit/path_cache_config.go +++ b/builtin/logical/transit/path_cache_config.go @@ -3,7 +3,6 @@ package transit import ( "context" "errors" - "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" ) @@ -45,8 +44,8 @@ func (b *backend) pathCacheConfig() *framework.Path { func (b *backend) pathCacheConfigWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { // get target size cacheSize := d.Get("size").(int) - if cacheSize < 0 { - return logical.ErrorResponse("size must be greater or equal to 0"), logical.ErrInvalidRequest + if cacheSize != 0 && cacheSize < minCacheSize { + return logical.ErrorResponse("size must be 0 or a value greater or equal to %d", minCacheSize), logical.ErrInvalidRequest } // store cache size @@ -60,11 +59,12 @@ func (b *backend) pathCacheConfigWrite(ctx context.Context, req *logical.Request return nil, err } - resp := &logical.Response{ - Warnings: []string{"cache configurations will be applied when this backend is restarted"}, + err = b.lm.InitCache(cacheSize) + if err != nil { + return nil, err } - return resp, nil + return nil, nil } type configCache struct { diff --git a/builtin/logical/transit/path_cache_config_test.go b/builtin/logical/transit/path_cache_config_test.go index 6cca1b265676a..2d74129f95bfd 100644 --- a/builtin/logical/transit/path_cache_config_test.go +++ b/builtin/logical/transit/path_cache_config_test.go @@ -8,6 +8,7 @@ import ( ) const targetCacheSize = 12345 +const smallCacheSize = 3 func TestTransit_CacheConfig(t *testing.T) { b1, storage := createBackendWithSysView(t) @@ -58,17 +59,51 @@ func TestTransit_CacheConfig(t *testing.T) { }, } + writeSmallCacheSizeReq := &logical.Request{ + Storage: storage, + Operation: logical.UpdateOperation, + Path: "cache-config", + Data: map[string]interface{}{ + "size": smallCacheSize, + }, + } + readReq := &logical.Request{ Storage: storage, Operation: logical.ReadOperation, Path: "cache-config", } + polReq := &logical.Request{ + Storage: storage, + Operation: logical.UpdateOperation, + Path: "keys/aes256", + Data: map[string]interface{}{ + "derived": true, + }, + } + + // test steps // b1 should spin up with an unlimited cache validateResponse(doReq(b1, readReq), 0, false) + + // Change cache size to targetCacheSize 12345 and validate that cache size is updated doReq(b1, writeReq) - validateResponse(doReq(b1, readReq), targetCacheSize, true) + validateResponse(doReq(b1, readReq), targetCacheSize, false) + b1.invalidate(context.Background(), "cache-config/") + + // Change the cache size to 1000 to mock the scenario where + // current cache size and stored cache size are different and + // a cache update is needed + b1.lm.InitCache(1000) + + // Write a new policy which in its code path detects that cache size has changed + // and refreshes the cache to 12345 + doReq(b1, polReq) + + // Validate that cache size is updated to 12345 + validateResponse(doReq(b1, readReq), targetCacheSize, false) // b2 should spin up with a configured cache b2 := createBackendWithSysViewWithStorage(t, storage) @@ -77,4 +112,10 @@ func TestTransit_CacheConfig(t *testing.T) { // b3 enables transit without a cache, trying to read it should error b3 := createBackendWithForceNoCacheWithSysViewWithStorage(t, storage) doErrReq(b3, readReq) + + // b4 should spin up with a size less than minimum cache size (10) + b4, storage := createBackendWithSysView(t) + doErrReq(b4, writeSmallCacheSizeReq) + + } diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go index 4641e2b6a7cd1..1c41cd0d49dd1 100644 --- a/builtin/logical/transit/path_config.go +++ b/builtin/logical/transit/path_config.go @@ -62,7 +62,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d * name := d.Get("name").(string) // Check if the policy already exists before we lock everything - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: name, }, b.GetRandomReader()) diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index a287bea34415f..9e9ef2c173404 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -99,7 +99,7 @@ func (b *backend) pathDatakeyWrite(ctx context.Context, req *logical.Request, d } // Get the policy - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: name, }, b.GetRandomReader()) diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index 5d8510da89f46..cf6d450604062 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -121,7 +121,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d } // Get the policy - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: d.Get("name").(string), }, b.GetRandomReader()) diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index 321e920998943..c328951645a18 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -217,7 +217,7 @@ func decodeBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error { func (b *backend) pathEncryptExistenceCheck(ctx context.Context, req *logical.Request, d *framework.FieldData) (bool, error) { name := d.Get("name").(string) - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: name, }, b.GetRandomReader()) @@ -336,7 +336,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d } } - p, upserted, err = b.lm.GetPolicy(ctx, polReq, b.GetRandomReader()) + p, upserted, err = b.GetPolicy(ctx, polReq, b.GetRandomReader()) if err != nil { return nil, err } diff --git a/builtin/logical/transit/path_export.go b/builtin/logical/transit/path_export.go index 33a76cf33b738..3b0d97e15e735 100644 --- a/builtin/logical/transit/path_export.go +++ b/builtin/logical/transit/path_export.go @@ -64,7 +64,7 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request return logical.ErrorResponse(fmt.Sprintf("invalid export type: %s", exportType)), logical.ErrInvalidRequest } - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: name, }, b.GetRandomReader()) diff --git a/builtin/logical/transit/path_hmac.go b/builtin/logical/transit/path_hmac.go index 025a39efdf01f..30a79a40789ab 100644 --- a/builtin/logical/transit/path_hmac.go +++ b/builtin/logical/transit/path_hmac.go @@ -96,7 +96,7 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr } // Get the policy - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: name, }, b.GetRandomReader()) @@ -224,7 +224,7 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f } // Get the policy - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: name, }, b.GetRandomReader()) diff --git a/builtin/logical/transit/path_hmac_test.go b/builtin/logical/transit/path_hmac_test.go index b9d6bbc813325..756dc77e559f1 100644 --- a/builtin/logical/transit/path_hmac_test.go +++ b/builtin/logical/transit/path_hmac_test.go @@ -26,7 +26,7 @@ func TestTransit_HMAC(t *testing.T) { } // Now, change the key value to something we control - p, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{ Storage: storage, Name: "foo", }, b.GetRandomReader()) @@ -196,7 +196,7 @@ func TestTransit_batchHMAC(t *testing.T) { } // Now, change the key value to something we control - p, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{ Storage: storage, Name: "foo", }, b.GetRandomReader()) diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 4cc25f66c4009..8c43ab593b5b4 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -162,7 +162,7 @@ func (b *backend) pathPolicyWrite(ctx context.Context, req *logical.Request, d * return logical.ErrorResponse(fmt.Sprintf("unknown key type %v", keyType)), logical.ErrInvalidRequest } - p, upserted, err := b.lm.GetPolicy(ctx, polReq, b.GetRandomReader()) + p, upserted, err := b.GetPolicy(ctx, polReq, b.GetRandomReader()) if err != nil { return nil, err } @@ -191,7 +191,7 @@ type asymKey struct { func (b *backend) pathPolicyRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: name, }, b.GetRandomReader()) diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index 9d473d256948d..c32fddc99976a 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -114,7 +114,7 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d * } // Get the policy - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: d.Get("name").(string), }, b.GetRandomReader()) diff --git a/builtin/logical/transit/path_rotate.go b/builtin/logical/transit/path_rotate.go index 3d2c2cdf40453..a74e69980512e 100644 --- a/builtin/logical/transit/path_rotate.go +++ b/builtin/logical/transit/path_rotate.go @@ -31,7 +31,7 @@ func (b *backend) pathRotateWrite(ctx context.Context, req *logical.Request, d * name := d.Get("name").(string) // Get the policy - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: name, }, b.GetRandomReader()) diff --git a/builtin/logical/transit/path_sign_verify.go b/builtin/logical/transit/path_sign_verify.go index 659e6a2091c66..265d63cec1984 100644 --- a/builtin/logical/transit/path_sign_verify.go +++ b/builtin/logical/transit/path_sign_verify.go @@ -246,7 +246,7 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr sigAlgorithm := d.Get("signature_algorithm").(string) // Get the policy - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: name, }, b.GetRandomReader()) @@ -464,7 +464,7 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * sigAlgorithm := d.Get("signature_algorithm").(string) // Get the policy - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: name, }, b.GetRandomReader()) diff --git a/builtin/logical/transit/path_sign_verify_test.go b/builtin/logical/transit/path_sign_verify_test.go index 652dc186c9162..40df90838b19e 100644 --- a/builtin/logical/transit/path_sign_verify_test.go +++ b/builtin/logical/transit/path_sign_verify_test.go @@ -55,7 +55,7 @@ func testTransit_SignVerify_ECDSA(t *testing.T, bits int) { } // Now, change the key value to something we control - p, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{ Storage: storage, Name: "foo", }, b.GetRandomReader()) @@ -377,7 +377,7 @@ func TestTransit_SignVerify_ED25519(t *testing.T) { } // Get the keys for later - fooP, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{ + fooP, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{ Storage: storage, Name: "foo", }, b.GetRandomReader()) @@ -385,7 +385,7 @@ func TestTransit_SignVerify_ED25519(t *testing.T) { t.Fatal(err) } - barP, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{ + barP, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{ Storage: storage, Name: "bar", }, b.GetRandomReader()) diff --git a/builtin/logical/transit/path_trim.go b/builtin/logical/transit/path_trim.go index cec7a5648ef7f..d8587f1c18d49 100644 --- a/builtin/logical/transit/path_trim.go +++ b/builtin/logical/transit/path_trim.go @@ -40,7 +40,7 @@ func (b *backend) pathTrimUpdate() framework.OperationFunc { return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (resp *logical.Response, retErr error) { name := d.Get("name").(string) - p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ Storage: req.Storage, Name: name, }, b.GetRandomReader()) diff --git a/builtin/logical/transit/path_trim_test.go b/builtin/logical/transit/path_trim_test.go index 6b3dfaa9ec2fc..be989b1642459 100644 --- a/builtin/logical/transit/path_trim_test.go +++ b/builtin/logical/transit/path_trim_test.go @@ -36,7 +36,7 @@ func TestTransit_Trim(t *testing.T) { doReq(t, req) // Get the policy and check that the archive has correct number of keys - p, _, err := b.lm.GetPolicy(namespace.RootContext(nil), keysutil.PolicyRequest{ + p, _, err := b.GetPolicy(namespace.RootContext(nil), keysutil.PolicyRequest{ Storage: storage, Name: "aes", }, b.GetRandomReader()) diff --git a/changelog/12418.txt b/changelog/12418.txt new file mode 100644 index 0000000000000..5ec2f6055393b --- /dev/null +++ b/changelog/12418.txt @@ -0,0 +1,4 @@ +```release-note:bug +Enforce minimum cache size for transit backend. +Init cache size on transit backend without restart. +``` diff --git a/sdk/helper/keysutil/lock_manager.go b/sdk/helper/keysutil/lock_manager.go index 039b05ad05356..c6a0a23d61457 100644 --- a/sdk/helper/keysutil/lock_manager.go +++ b/sdk/helper/keysutil/lock_manager.go @@ -101,6 +101,24 @@ func (lm *LockManager) InvalidatePolicy(name string) { } } +func (lm *LockManager) InitCache(cacheSize int) error { + if lm.useCache { + switch { + case cacheSize < 0: + return errors.New("cache size must be greater or equal to zero") + case cacheSize == 0: + lm.cache = NewTransitSyncMap() + case cacheSize > 0: + newLRUCache, err := NewTransitLRU(cacheSize) + if err != nil { + return errwrap.Wrapf("failed to create cache: {{err}}", err) + } + lm.cache = newLRUCache + } + } + return nil +} + // RestorePolicy acquires an exclusive lock on the policy name and restores the // given policy along with the archive. func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storage, name, backup string, force bool) error { diff --git a/website/content/api-docs/secret/transit.mdx b/website/content/api-docs/secret/transit.mdx index e2a27cdd266af..3c63e7eb8c762 100644 --- a/website/content/api-docs/secret/transit.mdx +++ b/website/content/api-docs/secret/transit.mdx @@ -1307,7 +1307,7 @@ using the [`/sys/plugins/reload/backend`][sys-plugin-reload-backend] endpoint. - `size` `(int: 0)` - Specifies the size in terms of number of entries. A size of `0` means unlimited. A _Least Recently Used_ (LRU) caching strategy is used for a - non-zero cache size. + non-zero cache size. Must be 0 (default) or a value greater or equal to 10 (minimum cache size). ### Sample Payload