Skip to content

Commit

Permalink
Suppress cancellation notification and pass contexts down to API calls
Browse files Browse the repository at this point in the history
  • Loading branch information
skmcgrail committed Mar 23, 2020
1 parent a6999bf commit 590e17b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
35 changes: 32 additions & 3 deletions aws/credentials/credentials.go
Expand Up @@ -107,6 +107,13 @@ type Provider interface {
IsExpired() bool
}

// ProviderWithContext is a Provider that can retrieve credentials with a Context
type ProviderWithContext interface {
Provider

RetrieveWithContext(Context) (Value, error)
}

// An Expirer is an interface that Providers can implement to expose the expiration
// time, if known. If the Provider cannot accurately provide this info,
// it should not implement this interface.
Expand Down Expand Up @@ -233,7 +240,9 @@ func (c *Credentials) GetWithContext(ctx Context) (Value, error) {
// Cannot pass context down to the actual retrieve, because the first
// context would cancel the whole group when there is not direct
// association of items in the group.
resCh := c.sf.DoChan("", c.singleRetrieve)
resCh := c.sf.DoChan("", func() (interface{}, error) {
return c.singleRetrieve(&suppressedContext{ctx})
})
select {
case res := <-resCh:
return res.Val.(Value), res.Err
Expand All @@ -243,12 +252,16 @@ func (c *Credentials) GetWithContext(ctx Context) (Value, error) {
}
}

func (c *Credentials) singleRetrieve() (interface{}, error) {
func (c *Credentials) singleRetrieve(ctx Context) (creds interface{}, err error) {
if curCreds := c.creds.Load(); !c.isExpired(curCreds) {
return curCreds.(Value), nil
}

creds, err := c.provider.Retrieve()
if p, ok := c.provider.(ProviderWithContext); ok {
creds, err = p.RetrieveWithContext(ctx)
} else {
creds, err = c.provider.Retrieve()
}
if err == nil {
c.creds.Store(creds)
}
Expand Down Expand Up @@ -308,3 +321,19 @@ func (c *Credentials) ExpiresAt() (time.Time, error) {
}
return expirer.ExpiresAt(), nil
}

type suppressedContext struct {
Context
}

func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}

func (s *suppressedContext) Done() <-chan struct{} {
return nil
}

func (s *suppressedContext) Err() error {
return nil
}
19 changes: 18 additions & 1 deletion aws/credentials/stscreds/assume_role_provider.go
Expand Up @@ -87,6 +87,7 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkrand"
"github.com/aws/aws-sdk-go/service/sts"
)
Expand Down Expand Up @@ -118,6 +119,10 @@ type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
}

type assumeRolerWithContext interface {
AssumeRoleWithContext(aws.Context, *sts.AssumeRoleInput, ...request.Option) (*sts.AssumeRoleOutput, error)
}

// DefaultDuration is the default amount of time in minutes that the credentials
// will be valid for.
var DefaultDuration = time.Duration(15) * time.Minute
Expand Down Expand Up @@ -265,6 +270,10 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*

// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}

func (p *AssumeRoleProvider) RetrieveWithContext(ctx aws.Context) (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
Expand Down Expand Up @@ -304,7 +313,15 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
}
}

roleOutput, err := p.Client.AssumeRole(input)
var roleOutput *sts.AssumeRoleOutput
var err error

if c, ok := p.Client.(assumeRolerWithContext); ok {
roleOutput, err = c.AssumeRoleWithContext(ctx, input)
} else {
roleOutput, err = p.Client.AssumeRole(input)
}

if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
Expand Down

0 comments on commit 590e17b

Please sign in to comment.