Skip to content

Commit

Permalink
secret store: AWS connection validation for parameter store and secre…
Browse files Browse the repository at this point in the history
…ts manager (#3301)

Signed-off-by: Elena Kolevska <elena@kolevska.com>
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
  • Loading branch information
elena-kolevska and ItalyPaleAle committed Jan 16, 2024
1 parent 3b0f320 commit 353447c
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 29 deletions.
15 changes: 15 additions & 0 deletions common/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"encoding/json"
"strconv"

"github.com/google/uuid"
"github.com/spf13/cast"
)

Expand Down Expand Up @@ -73,3 +74,17 @@ func Unquote(data []byte) (res string) {
}
return res
}

// GetRandOrDefaultString is used when we need to generate a random string,
// but don't want to fail if the entropy pool is empty (e.g. in tests)
// One example usage is for validating the aws connection on dapr initialisation
func GetRandOrDefaultString(defaultVal string) string {
var res string
if random, err := uuid.NewRandom(); err == nil {
res = random.String()
} else {
res = defaultVal
}

return res
}
35 changes: 27 additions & 8 deletions secretstores/aws/parameterstore/parameterstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package parameterstore

import (
"context"
"errors"
"fmt"
"reflect"

Expand All @@ -23,10 +24,12 @@ import (
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"

awsAuth "github.com/dapr/components-contrib/common/authentication/aws"
"github.com/dapr/components-contrib/common/utils"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/secretstores"
"github.com/dapr/kit/logger"
kitmd "github.com/dapr/kit/metadata"
"github.com/dapr/kit/ptr"
)

// Constant literals.
Expand Down Expand Up @@ -55,23 +58,39 @@ type ssmSecretStore struct {
logger logger.Logger
}

// Init creates a AWS secret manager client.
func (s *ssmSecretStore) Init(_ context.Context, metadata secretstores.Metadata) error {
// Init creates an AWS secret manager client.
func (s *ssmSecretStore) Init(ctx context.Context, metadata secretstores.Metadata) error {
meta, err := s.getSecretManagerMetadata(metadata)
if err != nil {
return err
}

client, err := s.getClient(meta)
if err != nil {
return err
// This check is needed because d.client is set to a mock in tests
if s.client == nil {
s.client, err = s.getClient(meta)
if err != nil {
return err
}
}
s.client = client
s.prefix = meta.Prefix

// Validate client connection
var notFoundErr *ssm.ParameterNotFound
if err := s.validateConnection(ctx); err != nil && !errors.As(err, &notFoundErr) {
return fmt.Errorf("error validating access to the aws.parameterstore secret store: %w", err)
}
return nil
}

// validateConnection runs a dummy GetParameterWithContext operation
// to validate the connection credentials
func (s *ssmSecretStore) validateConnection(ctx context.Context) error {
_, err := s.client.GetParameterWithContext(ctx, &ssm.GetParameterInput{
Name: ptr.Of(s.prefix + utils.GetRandOrDefaultString("dapr-test-param")),
})
return err
}

// GetSecret retrieves a secret using a key and returns a map of decrypted string/string values.
func (s *ssmSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecretRequest) (secretstores.GetSecretResponse, error) {
name := req.Name
Expand All @@ -83,8 +102,8 @@ func (s *ssmSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecr
}

output, err := s.client.GetParameterWithContext(ctx, &ssm.GetParameterInput{
Name: aws.String(s.prefix + name),
WithDecryption: aws.Bool(true),
Name: ptr.Of(s.prefix + name),
WithDecryption: ptr.Of(true),
})
if err != nil {
return secretstores.GetSecretResponse{Data: nil}, fmt.Errorf("couldn't get secret: %s", err)
Expand Down
20 changes: 20 additions & 0 deletions secretstores/aws/parameterstore/parameterstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ func (m *mockedSSM) DescribeParametersWithContext(ctx context.Context, input *ss
func TestInit(t *testing.T) {
m := secretstores.Metadata{}
s := NewParameterStore(logger.NewLogger("test"))
s.(*ssmSecretStore).client = &mockedSSM{
GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) {
// Simulate a non error response from AWS SSM
return nil, nil
},
}

t.Run("Init with valid metadata", func(t *testing.T) {
m.Properties = map[string]string{
"AccessKey": "a",
Expand All @@ -61,6 +68,19 @@ func TestInit(t *testing.T) {
err := s.Init(context.Background(), m)
require.NoError(t, err)
})

t.Run("Init with invalid connection details", func(t *testing.T) {
s.(*ssmSecretStore).client = &mockedSSM{
GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) {
// Simulate a failure that resembles what AWS SSM would return
return nil, fmt.Errorf("wrong-credentials")
},
}

err := s.Init(context.Background(), m)
require.Error(t, err)
require.EqualError(t, err, "error validating access to the aws.parameterstore secret store: wrong-credentials")
})
}

func TestGetSecret(t *testing.T) {
Expand Down
30 changes: 26 additions & 4 deletions secretstores/aws/secretmanager/secretmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@ package secretmanager
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"

"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"

awsAuth "github.com/dapr/components-contrib/common/authentication/aws"
"github.com/dapr/components-contrib/common/utils"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/secretstores"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
)

const (
Expand All @@ -52,22 +55,41 @@ type smSecretStore struct {
logger logger.Logger
}

// Init creates a AWS secret manager client.
func (s *smSecretStore) Init(_ context.Context, metadata secretstores.Metadata) error {
// Init creates an AWS secret manager client.
func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata) error {
meta, err := s.getSecretManagerMetadata(metadata)
if err != nil {
return err
}

client, err := s.getClient(meta)
// This check is needed because d.client is set to a mock in tests
if s.client == nil {
s.client, err = s.getClient(meta)
if err != nil {
return err
}
}
if err != nil {
return err
}
s.client = client

var notFoundErr *secretsmanager.ResourceNotFoundException
if err := s.validateConnection(ctx); err != nil && !errors.As(err, &notFoundErr) {
return fmt.Errorf("error validating access to the aws.secretmanager secret store: %w", err)
}
return nil
}

// validateConnection runs a dummy GetSecretValueWithContext operation
// to validate the connection credentials
func (s *smSecretStore) validateConnection(ctx context.Context) error {
_, err := s.client.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{
SecretId: ptr.Of(utils.GetRandOrDefaultString("dapr-test-secret")),
})

return err
}

// GetSecret retrieves a secret using a key and returns a map of decrypted string/string values.
func (s *smSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecretRequest) (secretstores.GetSecretResponse, error) {
var versionID *string
Expand Down
20 changes: 20 additions & 0 deletions secretstores/aws/secretmanager/secretmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ func (m *mockedSM) GetSecretValueWithContext(ctx context.Context, input *secrets
func TestInit(t *testing.T) {
m := secretstores.Metadata{}
s := NewSecretManager(logger.NewLogger("test"))
s.(*smSecretStore).client = &mockedSM{
GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) {
// Simulate a non error response
return nil, nil
},
}

t.Run("Init with valid metadata", func(t *testing.T) {
m.Properties = map[string]string{
"AccessKey": "a",
Expand All @@ -54,6 +61,19 @@ func TestInit(t *testing.T) {
err := s.Init(context.Background(), m)
require.NoError(t, err)
})

t.Run("Init with invalid connection details", func(t *testing.T) {
s.(*smSecretStore).client = &mockedSM{
GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) {
// Simulate a failure that resembles what AWS SM would return
return nil, fmt.Errorf("wrong-credentials")
},
}

err := s.Init(context.Background(), m)
require.Error(t, err)
require.EqualError(t, err, "error validating access to the aws.secretmanager secret store: wrong-credentials")
})
}

func TestGetSecret(t *testing.T) {
Expand Down
21 changes: 4 additions & 17 deletions state/aws/dynamodb/dynamodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,18 @@ import (
"strconv"
"time"

"github.com/google/uuid"

"github.com/dapr/kit/ptr"

"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
jsoniterator "github.com/json-iterator/go"

awsAuth "github.com/dapr/components-contrib/common/authentication/aws"
"github.com/dapr/components-contrib/common/utils"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/kit/logger"
kitmd "github.com/dapr/kit/metadata"
"github.com/dapr/kit/ptr"
)

// StateStore is a DynamoDB state store.
Expand Down Expand Up @@ -83,7 +81,7 @@ func (d *StateStore) Init(ctx context.Context, metadata state.Metadata) error {
return err
}

// We have this check because we need to set the client to a mock in tests
// This check is needed because d.client is set to a mock in tests
if d.client == nil {
d.client, err = d.getClient(meta)
if err != nil {
Expand All @@ -104,23 +102,12 @@ func (d *StateStore) Init(ctx context.Context, metadata state.Metadata) error {
// validateConnection runs a dummy Get operation to validate the connection credentials,
// as well as validating that the table exists, and we have access to it
func (d *StateStore) validateTableAccess(ctx context.Context) error {
var tableName string
if random, err := uuid.NewRandom(); err == nil {
tableName = random.String()
} else {
// We would get to this block if the entropy pool is empty.
// We don't want to fail initialising Dapr because of it though,
// since it's a dummy table that is only needed to check access, anyway
// So we'll just use a hardcoded table name
tableName = "dapr-test-table"
}

input := &dynamodb.GetItemInput{
ConsistentRead: ptr.Of(false),
TableName: ptr.Of(d.table),
Key: map[string]*dynamodb.AttributeValue{
d.partitionKey: {
S: ptr.Of(tableName),
S: ptr.Of(utils.GetRandOrDefaultString("dapr-test-table")),
},
},
}
Expand Down

0 comments on commit 353447c

Please sign in to comment.