Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

state.dynamodb: validate AWS connection #3285

Merged
merged 12 commits into from
Jan 2, 2024
36 changes: 30 additions & 6 deletions state/aws/dynamodb/dynamodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
"encoding/binary"
"errors"
"fmt"
"github.com/google/uuid"

Check failure on line 22 in state/aws/dynamodb/dynamodb.go

View workflow job for this annotation

GitHub Actions / Build linux_amd64 binaries

File is not `gofumpt`-ed (gofumpt)
"reflect"
"strconv"
"time"

Check failure on line 25 in state/aws/dynamodb/dynamodb.go

View workflow job for this annotation

GitHub Actions / Build linux_amd64 binaries

File is not `gofumpt`-ed (gofumpt)

Check failure on line 26 in state/aws/dynamodb/dynamodb.go

View workflow job for this annotation

GitHub Actions / Build linux_amd64 binaries

File is not `goimports`-ed with -local github.com/dapr/ (goimports)
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
Expand Down Expand Up @@ -74,25 +75,48 @@
}

// Init does metadata and connection parsing.
func (d *StateStore) Init(_ context.Context, metadata state.Metadata) error {
func (d *StateStore) Init(ctx context.Context, metadata state.Metadata) error {
meta, err := d.getDynamoDBMetadata(metadata)
if err != nil {
return err
}

client, err := d.getClient(meta)
if err != nil {
return err
// We have this check because we need to set the client to a mock in tests
if d.client == nil {
client, err := d.getClient(meta)
if err != nil {
return err
}
d.client = client
elena-kolevska marked this conversation as resolved.
Show resolved Hide resolved
}

d.client = client
d.table = meta.Table
d.ttlAttributeName = meta.TTLAttributeName
d.partitionKey = meta.PartitionKey

if err := d.validateTableAccess(ctx); err != nil {
return fmt.Errorf("error validating DynamoDB table '%s' access: %w", d.table, err)
}

return nil
}

// validateConnection runs a dummy Get operation to validate the connection credentials,
// as well as validating that the table exists, and we have access to it
func (d *StateStore) validateTableAccess(ctx context.Context) error {
input := &dynamodb.GetItemInput{
ConsistentRead: aws.Bool(false),
elena-kolevska marked this conversation as resolved.
Show resolved Hide resolved
TableName: aws.String(d.table),
elena-kolevska marked this conversation as resolved.
Show resolved Hide resolved
Key: map[string]*dynamodb.AttributeValue{
d.partitionKey: {
S: aws.String(uuid.NewString()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uuid.NewString can panic (if the kernel doesn't have entropy). Please use uuid.NewRandom and then call .String() on the result

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, ptr.Of

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ItalyPaleAle! Check out if you like how I did this. In case of error, I chose to use a fixed string, since I didn't want to have the initialisation of dapr fail because of it; it's just a dummy check anyway.

},
},
}

_, err := d.client.GetItemWithContext(ctx, input)
return err
}

// Features returns the features available in this state store.
func (d *StateStore) Features() []state.Feature {
// TTLs are enabled only if ttlAttributeName is set
Expand Down
24 changes: 24 additions & 0 deletions state/aws/dynamodb/dynamodb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package dynamodb

import (
"context"
"errors"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -76,6 +77,12 @@ func TestInit(t *testing.T) {
m := state.Metadata{}
s := &StateStore{
partitionKey: defaultPartitionKeyName,
client: &mockedDynamoDB{
// We're adding this so we can pass the connection check on Init
GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) {
return nil, nil
},
},
}

t.Run("NewDynamoDBStateStore Default Partition Key", func(t *testing.T) {
Expand Down Expand Up @@ -124,6 +131,23 @@ func TestInit(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, s.partitionKey, pkey)
})

t.Run("Init with bad table name or permissions", func(t *testing.T) {
m.Properties = map[string]string{
"Table": "does-not-exist",
"Region": "eu-west-1",
}

s.client = &mockedDynamoDB{
GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) {
return nil, errors.New("Requested resource not found")
},
}

err := s.Init(context.Background(), m)
require.Error(t, err)
require.Equal(t, "error validating DynamoDB table 'does-not-exist' access: Requested resource not found", err.Error())
elena-kolevska marked this conversation as resolved.
Show resolved Hide resolved
})
}

func TestGet(t *testing.T) {
Expand Down