Skip to content

Commit

Permalink
awsutil: add ability to mock IAM and STS APIs (#12)
Browse files Browse the repository at this point in the history
* awsutil: add ability to mock IAM and STS APIs

This adds the ability to supply mock IAM and STS interfaces to use with
the RotateKeys, CreateAccessKey, DeleteAccessKey, and GetCallerIdentity
methods.

Additionally, work has been done to incorporate the existing IAM mock
object into the package, in addition to adding a similar STS mock that
allows for the mocking of the GetCallerIdentity method.

Factory functions are also included to allow these to be incorporated
seamlessly into the call path, in addition to allowing for introspection
into any session data being received by the constructor functions.

Also adds MockOptionErr to expose mocking out an options error.
  • Loading branch information
vancluever committed Oct 6, 2021
1 parent 5b6393a commit f7bda98
Show file tree
Hide file tree
Showing 8 changed files with 767 additions and 42 deletions.
90 changes: 90 additions & 0 deletions awsutil/clients.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package awsutil

import (
"errors"
"fmt"

"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

// IAMAPIFunc is a factory function for returning an IAM interface,
// useful for supplying mock interfaces for testing IAM. The session
// is passed into the function in the same way as done with the
// standard iam.New() constructor.
type IAMAPIFunc func(sess *session.Session) (iamiface.IAMAPI, error)

// STSAPIFunc is a factory function for returning a STS interface,
// useful for supplying mock interfaces for testing STS. The session
// is passed into the function in the same way as done with the
// standard sts.New() constructor.
type STSAPIFunc func(sess *session.Session) (stsiface.STSAPI, error)

// IAMClient returns an IAM client.
//
// Supported options: WithSession, WithIAMAPIFunc.
//
// If WithIAMAPIFunc is supplied, the included function is used as
// the IAM client constructor instead. This can be used for Mocking
// the IAM API.
func (c *CredentialsConfig) IAMClient(opt ...Option) (iamiface.IAMAPI, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("error reading options: %w", err)
}

sess := opts.withAwsSession
if sess == nil {
sess, err = c.GetSession(opt...)
if err != nil {
return nil, fmt.Errorf("error calling GetSession: %w", err)
}
}

if opts.withIAMAPIFunc != nil {
return opts.withIAMAPIFunc(sess)
}

client := iam.New(sess)
if client == nil {
return nil, errors.New("could not obtain iam client from session")
}

return client, nil
}

// STSClient returns a STS client.
//
// Supported options: WithSession, WithSTSAPIFunc.
//
// If WithSTSAPIFunc is supplied, the included function is used as
// the STS client constructor instead. This can be used for Mocking
// the STS API.
func (c *CredentialsConfig) STSClient(opt ...Option) (stsiface.STSAPI, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("error reading options: %w", err)
}

sess := opts.withAwsSession
if sess == nil {
sess, err = c.GetSession(opt...)
if err != nil {
return nil, fmt.Errorf("error calling GetSession: %w", err)
}
}

if opts.withSTSAPIFunc != nil {
return opts.withSTSAPIFunc(sess)
}

client := sts.New(sess)
if client == nil {
return nil, errors.New("could not obtain sts client from session")
}

return client, nil
}
137 changes: 137 additions & 0 deletions awsutil/clients_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package awsutil

import (
"errors"
"fmt"
"testing"

"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/stretchr/testify/require"
)

const testOptionErr = "test option error"
const testBadClientType = "badclienttype"

func testWithBadClientType(o *options) error {
o.withClientType = testBadClientType
return nil
}

func TestCredentialsConfigIAMClient(t *testing.T) {
cases := []struct {
name string
credentialsConfig *CredentialsConfig
opts []Option
require func(t *testing.T, actual iamiface.IAMAPI)
requireErr string
}{
{
name: "options error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{MockOptionErr(errors.New(testOptionErr))},
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr),
},
{
name: "session error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{testWithBadClientType},
requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType),
},
{
name: "with mock IAM session",
credentialsConfig: &CredentialsConfig{},
opts: []Option{WithIAMAPIFunc(NewMockIAM())},
require: func(t *testing.T, actual iamiface.IAMAPI) {
t.Helper()
require := require.New(t)
require.Equal(&MockIAM{}, actual)
},
},
{
name: "no mock client",
credentialsConfig: &CredentialsConfig{},
opts: []Option{},
require: func(t *testing.T, actual iamiface.IAMAPI) {
t.Helper()
require := require.New(t)
require.IsType(&iam.IAM{}, actual)
},
},
}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
actual, err := tc.credentialsConfig.IAMClient(tc.opts...)
if tc.requireErr != "" {
require.EqualError(err, tc.requireErr)
return
}

require.NoError(err)
tc.require(t, actual)
})
}
}

func TestCredentialsConfigSTSClient(t *testing.T) {
cases := []struct {
name string
credentialsConfig *CredentialsConfig
opts []Option
require func(t *testing.T, actual stsiface.STSAPI)
requireErr string
}{
{
name: "options error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{MockOptionErr(errors.New(testOptionErr))},
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr),
},
{
name: "session error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{testWithBadClientType},
requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType),
},
{
name: "with mock STS session",
credentialsConfig: &CredentialsConfig{},
opts: []Option{WithSTSAPIFunc(NewMockSTS())},
require: func(t *testing.T, actual stsiface.STSAPI) {
t.Helper()
require := require.New(t)
require.Equal(&MockSTS{}, actual)
},
},
{
name: "no mock client",
credentialsConfig: &CredentialsConfig{},
opts: []Option{},
require: func(t *testing.T, actual stsiface.STSAPI) {
t.Helper()
require := require.New(t)
require.IsType(&sts.STS{}, actual)
},
},
}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
actual, err := tc.credentialsConfig.STSClient(tc.opts...)
if tc.requireErr != "" {
require.EqualError(err, tc.requireErr)
return
}

require.NoError(err)
tc.require(t, actual)
})
}
}
142 changes: 140 additions & 2 deletions awsutil/mocks.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,164 @@
package awsutil

import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

// MockOptionErr provides a mock option error for use with testing.
func MockOptionErr(withErr error) Option {
return func(_ *options) error {
return withErr
}
}

// MockIAM provides a way to mock the AWS IAM API.
type MockIAM struct {
iamiface.IAMAPI

CreateAccessKeyOutput *iam.CreateAccessKeyOutput
DeleteAccessKeyOutput *iam.DeleteAccessKeyOutput
CreateAccessKeyError error
DeleteAccessKeyError error
GetUserOutput *iam.GetUserOutput
GetUserError error
}

// MockIAMOption is a function for setting the various fields on a MockIAM
// object.
type MockIAMOption func(m *MockIAM) error

// WithCreateAccessKeyOutput sets the output for the CreateAccessKey method.
func WithCreateAccessKeyOutput(o *iam.CreateAccessKeyOutput) MockIAMOption {
return func(m *MockIAM) error {
m.CreateAccessKeyOutput = o
return nil
}
}

// WithCreateAccessKeyError sets the error output for the CreateAccessKey
// method.
func WithCreateAccessKeyError(e error) MockIAMOption {
return func(m *MockIAM) error {
m.CreateAccessKeyError = e
return nil
}
}

// WithDeleteAccessKeyError sets the error output for the DeleteAccessKey
// method.
func WithDeleteAccessKeyError(e error) MockIAMOption {
return func(m *MockIAM) error {
m.DeleteAccessKeyError = e
return nil
}
}

// WithGetUserOutput sets the output for the GetUser method.
func WithGetUserOutput(o *iam.GetUserOutput) MockIAMOption {
return func(m *MockIAM) error {
m.GetUserOutput = o
return nil
}
}

// WithGetUserError sets the error output for the GetUser method.
func WithGetUserError(e error) MockIAMOption {
return func(m *MockIAM) error {
m.GetUserError = e
return nil
}
}

// NewMockIAM provides a factory function to use with the WithIAMAPIFunc
// option.
func NewMockIAM(opts ...MockIAMOption) IAMAPIFunc {
return func(_ *session.Session) (iamiface.IAMAPI, error) {
m := new(MockIAM)
for _, opt := range opts {
if err := opt(m); err != nil {
return nil, err
}
}

return m, nil
}
}

func (m *MockIAM) CreateAccessKey(*iam.CreateAccessKeyInput) (*iam.CreateAccessKeyOutput, error) {
if m.CreateAccessKeyError != nil {
return nil, m.CreateAccessKeyError
}

return m.CreateAccessKeyOutput, nil
}

func (m *MockIAM) DeleteAccessKey(*iam.DeleteAccessKeyInput) (*iam.DeleteAccessKeyOutput, error) {
return m.DeleteAccessKeyOutput, nil
return &iam.DeleteAccessKeyOutput{}, m.DeleteAccessKeyError
}

func (m *MockIAM) GetUser(*iam.GetUserInput) (*iam.GetUserOutput, error) {
if m.GetUserError != nil {
return nil, m.GetUserError
}

return m.GetUserOutput, nil
}

// MockSTS provides a way to mock the AWS STS API.
type MockSTS struct {
stsiface.STSAPI

GetCallerIdentityOutput *sts.GetCallerIdentityOutput
GetCallerIdentityError error
}

// MockSTSOption is a function for setting the various fields on a MockSTS
// object.
type MockSTSOption func(m *MockSTS) error

// WithGetCallerIdentityOutput sets the output for the GetCallerIdentity
// method.
func WithGetCallerIdentityOutput(o *sts.GetCallerIdentityOutput) MockSTSOption {
return func(m *MockSTS) error {
m.GetCallerIdentityOutput = o
return nil
}
}

// WithGetCallerIdentityError sets the error output for the GetCallerIdentity
// method.
func WithGetCallerIdentityError(e error) MockSTSOption {
return func(m *MockSTS) error {
m.GetCallerIdentityError = e
return nil
}
}

// NewMockSTS provides a factory function to use with the WithSTSAPIFunc
// option.
//
// If withGetCallerIdentityError is supplied, calls to GetCallerIdentity will
// return the supplied error. Otherwise, a basic mock API output is returned.
func NewMockSTS(opts ...MockSTSOption) STSAPIFunc {
return func(_ *session.Session) (stsiface.STSAPI, error) {
m := new(MockSTS)
for _, opt := range opts {
if err := opt(m); err != nil {
return nil, err
}
}

return m, nil
}
}

func (m *MockSTS) GetCallerIdentity(_ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) {
if m.GetCallerIdentityError != nil {
return nil, m.GetCallerIdentityError
}

return m.GetCallerIdentityOutput, nil
}

0 comments on commit f7bda98

Please sign in to comment.