Skip to content

Commit

Permalink
Enforce minimum cache size for transit backend (#12418)
Browse files Browse the repository at this point in the history
* Enforce Minimum cache size for transit backend

* enfore minimum cache size and log a warning during backend construction

* Update documentation for transit backend cache configuration

* Added changelog

* Addressed review feedback and added unit test

* Modify code in pathCacheConfigWrite to make use of the updated cache size

* Updated code to refresh cache size on transit backend without restart

* Update code to acquire read and write locks appropriately
  • Loading branch information
divyapola5 committed Sep 13, 2021
1 parent dd19e12 commit 94d4fdb
Show file tree
Hide file tree
Showing 20 changed files with 141 additions and 30 deletions.
50 changes: 49 additions & 1 deletion builtin/logical/transit/backend.go
Expand Up @@ -3,13 +3,18 @@ package transit
import (
"context"
"fmt"
"io"
"strings"
"sync"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/logical"
)

// Minimum cache size for transit backend
const minCacheSize = 10

func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err := Backend(ctx, conf)
if err != nil {
Expand Down Expand Up @@ -68,6 +73,11 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error)
if err != nil {
return nil, fmt.Errorf("Error retrieving cache size from storage: %w", err)
}

if cacheSize != 0 && cacheSize < minCacheSize {
b.Logger().Warn("size %d is less than minimum %d. Cache size is set to %d", cacheSize, minCacheSize, minCacheSize)
cacheSize = minCacheSize
}
}

var err error
Expand All @@ -82,6 +92,9 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error)
type backend struct {
*framework.Backend
lm *keysutil.LockManager
// Lock to make changes to any of the backend's cache configuration.
configMutex sync.RWMutex
cacheSizeChanged bool
}

func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error) {
Expand All @@ -100,13 +113,48 @@ func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error
return size, nil
}

func (b *backend) invalidate(_ context.Context, key string) {
// Update cache size and get policy
func (b *backend) GetPolicy(ctx context.Context, polReq keysutil.PolicyRequest, rand io.Reader) (retP *keysutil.Policy, retUpserted bool, retErr error) {
// Acquire read lock to read cacheSizeChanged
b.configMutex.RLock()
if b.lm.GetUseCache() && b.cacheSizeChanged {
var err error
currentCacheSize := b.lm.GetCacheSize()
storedCacheSize, err := GetCacheSizeFromStorage(ctx, polReq.Storage)
if err != nil {
return nil, false, err
}
if currentCacheSize != storedCacheSize {
err = b.lm.InitCache(storedCacheSize)
if err != nil {
return nil, false, err
}
}
// Release the read lock and acquire the write lock
b.configMutex.RUnlock()
b.configMutex.Lock()
defer b.configMutex.Unlock()
b.cacheSizeChanged = false
}
p, _, err := b.lm.GetPolicy(ctx, polReq, rand)
if err != nil {
return p, false, err
}
return p, true, nil
}

func (b *backend) invalidate(ctx context.Context, key string) {
if b.Logger().IsDebug() {
b.Logger().Debug("invalidating key", "key", key)
}
switch {
case strings.HasPrefix(key, "policy/"):
name := strings.TrimPrefix(key, "policy/")
b.lm.InvalidatePolicy(name)
case strings.HasPrefix(key, "cache-config/"):
// Acquire the lock to set the flag to indicate that cache size needs to be refreshed from storage
b.configMutex.Lock()
defer b.configMutex.Unlock()
b.cacheSizeChanged = true
}
}
12 changes: 6 additions & 6 deletions builtin/logical/transit/path_cache_config.go
Expand Up @@ -3,7 +3,6 @@ package transit
import (
"context"
"errors"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
Expand Down Expand Up @@ -45,8 +44,8 @@ func (b *backend) pathCacheConfig() *framework.Path {
func (b *backend) pathCacheConfigWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// get target size
cacheSize := d.Get("size").(int)
if cacheSize < 0 {
return logical.ErrorResponse("size must be greater or equal to 0"), logical.ErrInvalidRequest
if cacheSize != 0 && cacheSize < minCacheSize {
return logical.ErrorResponse("size must be 0 or a value greater or equal to %d", minCacheSize), logical.ErrInvalidRequest
}

// store cache size
Expand All @@ -60,11 +59,12 @@ func (b *backend) pathCacheConfigWrite(ctx context.Context, req *logical.Request
return nil, err
}

resp := &logical.Response{
Warnings: []string{"cache configurations will be applied when this backend is restarted"},
err = b.lm.InitCache(cacheSize)
if err != nil {
return nil, err
}

return resp, nil
return nil, nil
}

type configCache struct {
Expand Down
43 changes: 42 additions & 1 deletion builtin/logical/transit/path_cache_config_test.go
Expand Up @@ -8,6 +8,7 @@ import (
)

const targetCacheSize = 12345
const smallCacheSize = 3

func TestTransit_CacheConfig(t *testing.T) {
b1, storage := createBackendWithSysView(t)
Expand Down Expand Up @@ -58,17 +59,51 @@ func TestTransit_CacheConfig(t *testing.T) {
},
}

writeSmallCacheSizeReq := &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "cache-config",
Data: map[string]interface{}{
"size": smallCacheSize,
},
}

readReq := &logical.Request{
Storage: storage,
Operation: logical.ReadOperation,
Path: "cache-config",
}

polReq := &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "keys/aes256",
Data: map[string]interface{}{
"derived": true,
},
}


// test steps
// b1 should spin up with an unlimited cache
validateResponse(doReq(b1, readReq), 0, false)

// Change cache size to targetCacheSize 12345 and validate that cache size is updated
doReq(b1, writeReq)
validateResponse(doReq(b1, readReq), targetCacheSize, true)
validateResponse(doReq(b1, readReq), targetCacheSize, false)
b1.invalidate(context.Background(), "cache-config/")

// Change the cache size to 1000 to mock the scenario where
// current cache size and stored cache size are different and
// a cache update is needed
b1.lm.InitCache(1000)

// Write a new policy which in its code path detects that cache size has changed
// and refreshes the cache to 12345
doReq(b1, polReq)

// Validate that cache size is updated to 12345
validateResponse(doReq(b1, readReq), targetCacheSize, false)

// b2 should spin up with a configured cache
b2 := createBackendWithSysViewWithStorage(t, storage)
Expand All @@ -77,4 +112,10 @@ func TestTransit_CacheConfig(t *testing.T) {
// b3 enables transit without a cache, trying to read it should error
b3 := createBackendWithForceNoCacheWithSysViewWithStorage(t, storage)
doErrReq(b3, readReq)

// b4 should spin up with a size less than minimum cache size (10)
b4, storage := createBackendWithSysView(t)
doErrReq(b4, writeSmallCacheSizeReq)


}
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_config.go
Expand Up @@ -62,7 +62,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
name := d.Get("name").(string)

// Check if the policy already exists before we lock everything
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_datakey.go
Expand Up @@ -99,7 +99,7 @@ func (b *backend) pathDatakeyWrite(ctx context.Context, req *logical.Request, d
}

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_decrypt.go
Expand Up @@ -121,7 +121,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
}

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: d.Get("name").(string),
}, b.GetRandomReader())
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/transit/path_encrypt.go
Expand Up @@ -217,7 +217,7 @@ func decodeBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error {

func (b *backend) pathEncryptExistenceCheck(ctx context.Context, req *logical.Request, d *framework.FieldData) (bool, error) {
name := d.Get("name").(string)
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down Expand Up @@ -336,7 +336,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
}
}

p, upserted, err = b.lm.GetPolicy(ctx, polReq, b.GetRandomReader())
p, upserted, err = b.GetPolicy(ctx, polReq, b.GetRandomReader())
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_export.go
Expand Up @@ -64,7 +64,7 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request
return logical.ErrorResponse(fmt.Sprintf("invalid export type: %s", exportType)), logical.ErrInvalidRequest
}

p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/transit/path_hmac.go
Expand Up @@ -96,7 +96,7 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr
}

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down Expand Up @@ -224,7 +224,7 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f
}

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/transit/path_hmac_test.go
Expand Up @@ -26,7 +26,7 @@ func TestTransit_HMAC(t *testing.T) {
}

// Now, change the key value to something we control
p, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{
p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: "foo",
}, b.GetRandomReader())
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestTransit_batchHMAC(t *testing.T) {
}

// Now, change the key value to something we control
p, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{
p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: "foo",
}, b.GetRandomReader())
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/transit/path_keys.go
Expand Up @@ -162,7 +162,7 @@ func (b *backend) pathPolicyWrite(ctx context.Context, req *logical.Request, d *
return logical.ErrorResponse(fmt.Sprintf("unknown key type %v", keyType)), logical.ErrInvalidRequest
}

p, upserted, err := b.lm.GetPolicy(ctx, polReq, b.GetRandomReader())
p, upserted, err := b.GetPolicy(ctx, polReq, b.GetRandomReader())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -191,7 +191,7 @@ type asymKey struct {
func (b *backend) pathPolicyRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)

p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_rewrap.go
Expand Up @@ -114,7 +114,7 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d *
}

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: d.Get("name").(string),
}, b.GetRandomReader())
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_rotate.go
Expand Up @@ -31,7 +31,7 @@ func (b *backend) pathRotateWrite(ctx context.Context, req *logical.Request, d *
name := d.Get("name").(string)

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/transit/path_sign_verify.go
Expand Up @@ -246,7 +246,7 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
sigAlgorithm := d.Get("signature_algorithm").(string)

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down Expand Up @@ -464,7 +464,7 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
sigAlgorithm := d.Get("signature_algorithm").(string)

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
6 changes: 3 additions & 3 deletions builtin/logical/transit/path_sign_verify_test.go
Expand Up @@ -55,7 +55,7 @@ func testTransit_SignVerify_ECDSA(t *testing.T, bits int) {
}

// Now, change the key value to something we control
p, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{
p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: "foo",
}, b.GetRandomReader())
Expand Down Expand Up @@ -377,15 +377,15 @@ func TestTransit_SignVerify_ED25519(t *testing.T) {
}

// Get the keys for later
fooP, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{
fooP, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: "foo",
}, b.GetRandomReader())
if err != nil {
t.Fatal(err)
}

barP, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{
barP, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: "bar",
}, b.GetRandomReader())
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_trim.go
Expand Up @@ -40,7 +40,7 @@ func (b *backend) pathTrimUpdate() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (resp *logical.Response, retErr error) {
name := d.Get("name").(string)

p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_trim_test.go
Expand Up @@ -36,7 +36,7 @@ func TestTransit_Trim(t *testing.T) {
doReq(t, req)

// Get the policy and check that the archive has correct number of keys
p, _, err := b.lm.GetPolicy(namespace.RootContext(nil), keysutil.PolicyRequest{
p, _, err := b.GetPolicy(namespace.RootContext(nil), keysutil.PolicyRequest{
Storage: storage,
Name: "aes",
}, b.GetRandomReader())
Expand Down
4 changes: 4 additions & 0 deletions changelog/12418.txt
@@ -0,0 +1,4 @@
```release-note:bug
Enforce minimum cache size for transit backend.
Init cache size on transit backend without restart.
```

0 comments on commit 94d4fdb

Please sign in to comment.