Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce minimum cache size for transit backend #12418

Merged
merged 13 commits into from Sep 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 49 additions & 1 deletion builtin/logical/transit/backend.go
Expand Up @@ -3,13 +3,18 @@ package transit
import (
"context"
"fmt"
"io"
"strings"
"sync"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/logical"
)

// Minimum cache size for transit backend
const minCacheSize = 10

func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err := Backend(ctx, conf)
if err != nil {
Expand Down Expand Up @@ -68,6 +73,11 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error)
if err != nil {
return nil, fmt.Errorf("Error retrieving cache size from storage: %w", err)
}

if cacheSize != 0 && cacheSize < minCacheSize {
b.Logger().Warn("size %d is less than minimum %d. Cache size is set to %d", cacheSize, minCacheSize, minCacheSize)
cacheSize = minCacheSize
}
}

var err error
Expand All @@ -82,6 +92,9 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error)
type backend struct {
*framework.Backend
lm *keysutil.LockManager
// Lock to make changes to any of the backend's cache configuration.
configMutex sync.RWMutex
cacheSizeChanged bool
}

func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error) {
Expand All @@ -100,13 +113,48 @@ func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error
return size, nil
}

func (b *backend) invalidate(_ context.Context, key string) {
// Update cache size and get policy
func (b *backend) GetPolicy(ctx context.Context, polReq keysutil.PolicyRequest, rand io.Reader) (retP *keysutil.Policy, retUpserted bool, retErr error) {
// Acquire read lock to read cacheSizeChanged
b.configMutex.RLock()
if b.lm.GetUseCache() && b.cacheSizeChanged {
divyapola5 marked this conversation as resolved.
Show resolved Hide resolved
var err error
currentCacheSize := b.lm.GetCacheSize()
storedCacheSize, err := GetCacheSizeFromStorage(ctx, polReq.Storage)
if err != nil {
return nil, false, err
}
if currentCacheSize != storedCacheSize {
err = b.lm.InitCache(storedCacheSize)
if err != nil {
return nil, false, err
}
}
// Release the read lock and acquire the write lock
b.configMutex.RUnlock()
b.configMutex.Lock()
defer b.configMutex.Unlock()
b.cacheSizeChanged = false
divyapola5 marked this conversation as resolved.
Show resolved Hide resolved
}
p, _, err := b.lm.GetPolicy(ctx, polReq, rand)
if err != nil {
return p, false, err
}
return p, true, nil
}

func (b *backend) invalidate(ctx context.Context, key string) {
if b.Logger().IsDebug() {
b.Logger().Debug("invalidating key", "key", key)
}
switch {
case strings.HasPrefix(key, "policy/"):
name := strings.TrimPrefix(key, "policy/")
b.lm.InvalidatePolicy(name)
case strings.HasPrefix(key, "cache-config/"):
// Acquire the lock to set the flag to indicate that cache size needs to be refreshed from storage
b.configMutex.Lock()
defer b.configMutex.Unlock()
b.cacheSizeChanged = true
}
}
12 changes: 6 additions & 6 deletions builtin/logical/transit/path_cache_config.go
Expand Up @@ -3,7 +3,6 @@ package transit
import (
"context"
"errors"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
Expand Down Expand Up @@ -45,8 +44,8 @@ func (b *backend) pathCacheConfig() *framework.Path {
func (b *backend) pathCacheConfigWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// get target size
cacheSize := d.Get("size").(int)
if cacheSize < 0 {
return logical.ErrorResponse("size must be greater or equal to 0"), logical.ErrInvalidRequest
if cacheSize != 0 && cacheSize < minCacheSize {
return logical.ErrorResponse("size must be 0 or a value greater or equal to %d", minCacheSize), logical.ErrInvalidRequest
}

// store cache size
Expand All @@ -60,11 +59,12 @@ func (b *backend) pathCacheConfigWrite(ctx context.Context, req *logical.Request
return nil, err
}

resp := &logical.Response{
Warnings: []string{"cache configurations will be applied when this backend is restarted"},
err = b.lm.InitCache(cacheSize)
if err != nil {
return nil, err
}

return resp, nil
return nil, nil
}

type configCache struct {
Expand Down
43 changes: 42 additions & 1 deletion builtin/logical/transit/path_cache_config_test.go
Expand Up @@ -8,6 +8,7 @@ import (
)

const targetCacheSize = 12345
const smallCacheSize = 3

func TestTransit_CacheConfig(t *testing.T) {
b1, storage := createBackendWithSysView(t)
Expand Down Expand Up @@ -58,17 +59,51 @@ func TestTransit_CacheConfig(t *testing.T) {
},
}

writeSmallCacheSizeReq := &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "cache-config",
Data: map[string]interface{}{
"size": smallCacheSize,
},
}

readReq := &logical.Request{
Storage: storage,
Operation: logical.ReadOperation,
Path: "cache-config",
}

polReq := &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "keys/aes256",
Data: map[string]interface{}{
"derived": true,
},
}


// test steps
// b1 should spin up with an unlimited cache
validateResponse(doReq(b1, readReq), 0, false)

// Change cache size to targetCacheSize 12345 and validate that cache size is updated
doReq(b1, writeReq)
validateResponse(doReq(b1, readReq), targetCacheSize, true)
validateResponse(doReq(b1, readReq), targetCacheSize, false)
b1.invalidate(context.Background(), "cache-config/")

// Change the cache size to 1000 to mock the scenario where
// current cache size and stored cache size are different and
// a cache update is needed
b1.lm.InitCache(1000)

// Write a new policy which in its code path detects that cache size has changed
// and refreshes the cache to 12345
doReq(b1, polReq)

// Validate that cache size is updated to 12345
validateResponse(doReq(b1, readReq), targetCacheSize, false)

// b2 should spin up with a configured cache
b2 := createBackendWithSysViewWithStorage(t, storage)
Expand All @@ -77,4 +112,10 @@ func TestTransit_CacheConfig(t *testing.T) {
// b3 enables transit without a cache, trying to read it should error
b3 := createBackendWithForceNoCacheWithSysViewWithStorage(t, storage)
doErrReq(b3, readReq)

// b4 should spin up with a size less than minimum cache size (10)
b4, storage := createBackendWithSysView(t)
doErrReq(b4, writeSmallCacheSizeReq)


}
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_config.go
Expand Up @@ -62,7 +62,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
name := d.Get("name").(string)

// Check if the policy already exists before we lock everything
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_datakey.go
Expand Up @@ -99,7 +99,7 @@ func (b *backend) pathDatakeyWrite(ctx context.Context, req *logical.Request, d
}

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_decrypt.go
Expand Up @@ -121,7 +121,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
}

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: d.Get("name").(string),
}, b.GetRandomReader())
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/transit/path_encrypt.go
Expand Up @@ -217,7 +217,7 @@ func decodeBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error {

func (b *backend) pathEncryptExistenceCheck(ctx context.Context, req *logical.Request, d *framework.FieldData) (bool, error) {
name := d.Get("name").(string)
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down Expand Up @@ -336,7 +336,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
}
}

p, upserted, err = b.lm.GetPolicy(ctx, polReq, b.GetRandomReader())
p, upserted, err = b.GetPolicy(ctx, polReq, b.GetRandomReader())
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_export.go
Expand Up @@ -64,7 +64,7 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request
return logical.ErrorResponse(fmt.Sprintf("invalid export type: %s", exportType)), logical.ErrInvalidRequest
}

p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/transit/path_hmac.go
Expand Up @@ -96,7 +96,7 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr
}

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down Expand Up @@ -224,7 +224,7 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f
}

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/transit/path_hmac_test.go
Expand Up @@ -26,7 +26,7 @@ func TestTransit_HMAC(t *testing.T) {
}

// Now, change the key value to something we control
p, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{
p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: "foo",
}, b.GetRandomReader())
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestTransit_batchHMAC(t *testing.T) {
}

// Now, change the key value to something we control
p, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{
p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: "foo",
}, b.GetRandomReader())
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/transit/path_keys.go
Expand Up @@ -162,7 +162,7 @@ func (b *backend) pathPolicyWrite(ctx context.Context, req *logical.Request, d *
return logical.ErrorResponse(fmt.Sprintf("unknown key type %v", keyType)), logical.ErrInvalidRequest
}

p, upserted, err := b.lm.GetPolicy(ctx, polReq, b.GetRandomReader())
p, upserted, err := b.GetPolicy(ctx, polReq, b.GetRandomReader())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -191,7 +191,7 @@ type asymKey struct {
func (b *backend) pathPolicyRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)

p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_rewrap.go
Expand Up @@ -114,7 +114,7 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d *
}

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: d.Get("name").(string),
}, b.GetRandomReader())
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_rotate.go
Expand Up @@ -31,7 +31,7 @@ func (b *backend) pathRotateWrite(ctx context.Context, req *logical.Request, d *
name := d.Get("name").(string)

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/transit/path_sign_verify.go
Expand Up @@ -246,7 +246,7 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
sigAlgorithm := d.Get("signature_algorithm").(string)

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down Expand Up @@ -464,7 +464,7 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
sigAlgorithm := d.Get("signature_algorithm").(string)

// Get the policy
p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
6 changes: 3 additions & 3 deletions builtin/logical/transit/path_sign_verify_test.go
Expand Up @@ -55,7 +55,7 @@ func testTransit_SignVerify_ECDSA(t *testing.T, bits int) {
}

// Now, change the key value to something we control
p, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{
p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: "foo",
}, b.GetRandomReader())
Expand Down Expand Up @@ -377,15 +377,15 @@ func TestTransit_SignVerify_ED25519(t *testing.T) {
}

// Get the keys for later
fooP, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{
fooP, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: "foo",
}, b.GetRandomReader())
if err != nil {
t.Fatal(err)
}

barP, _, err := b.lm.GetPolicy(context.Background(), keysutil.PolicyRequest{
barP, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: "bar",
}, b.GetRandomReader())
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_trim.go
Expand Up @@ -40,7 +40,7 @@ func (b *backend) pathTrimUpdate() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (resp *logical.Response, retErr error) {
name := d.Get("name").(string)

p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_trim_test.go
Expand Up @@ -36,7 +36,7 @@ func TestTransit_Trim(t *testing.T) {
doReq(t, req)

// Get the policy and check that the archive has correct number of keys
p, _, err := b.lm.GetPolicy(namespace.RootContext(nil), keysutil.PolicyRequest{
p, _, err := b.GetPolicy(namespace.RootContext(nil), keysutil.PolicyRequest{
Storage: storage,
Name: "aes",
}, b.GetRandomReader())
Expand Down
4 changes: 4 additions & 0 deletions changelog/12418.txt
@@ -0,0 +1,4 @@
```release-note:bug
Enforce minimum cache size for transit backend.
Init cache size on transit backend without restart.
```