Skip to content

Commit

Permalink
aws/credentials: Add ProviderWithContext optional interface to suppor…
Browse files Browse the repository at this point in the history
…t passing contexts on credential retrieval (#3223)
  • Loading branch information
skmcgrail committed Apr 1, 2020
1 parent a660ff4 commit c075c54
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 18 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG_PENDING.md
@@ -1,5 +1,10 @@
### SDK Features

### SDK Enhancements
* `aws/credentials`: `ProviderWithContext` optional interface has been added to support passing contexts on credential retrieval ([#3223](https://github.com/aws/aws-sdk-go/pull/3223))
* Credential providers that implement the optional `ProviderWithContext` will have context passed to them
* `ec2rolecreds.EC2RoleProvider`, `endpointcreds.Provider`, `stscreds.AssumeRoleProvider`, `stscreds.WebIdentityRoleProvider` have been updated to support the `ProviderWithContext` interface
* Fixes [#3213](https://github.com/aws/aws-sdk-go/issues/3213)
* `aws/ec2metadata`: Context aware operations have been added `EC2Metadata` client ([#3223](https://github.com/aws/aws-sdk-go/pull/3223))

### SDK Bugs
35 changes: 32 additions & 3 deletions aws/credentials/credentials.go
Expand Up @@ -107,6 +107,13 @@ type Provider interface {
IsExpired() bool
}

// ProviderWithContext is a Provider that can retrieve credentials with a Context
type ProviderWithContext interface {
Provider

RetrieveWithContext(Context) (Value, error)
}

// An Expirer is an interface that Providers can implement to expose the expiration
// time, if known. If the Provider cannot accurately provide this info,
// it should not implement this interface.
Expand Down Expand Up @@ -233,7 +240,9 @@ func (c *Credentials) GetWithContext(ctx Context) (Value, error) {
// Cannot pass context down to the actual retrieve, because the first
// context would cancel the whole group when there is not direct
// association of items in the group.
resCh := c.sf.DoChan("", c.singleRetrieve)
resCh := c.sf.DoChan("", func() (interface{}, error) {
return c.singleRetrieve(&suppressedContext{ctx})
})
select {
case res := <-resCh:
return res.Val.(Value), res.Err
Expand All @@ -243,12 +252,16 @@ func (c *Credentials) GetWithContext(ctx Context) (Value, error) {
}
}

func (c *Credentials) singleRetrieve() (interface{}, error) {
func (c *Credentials) singleRetrieve(ctx Context) (creds interface{}, err error) {
if curCreds := c.creds.Load(); !c.isExpired(curCreds) {
return curCreds.(Value), nil
}

creds, err := c.provider.Retrieve()
if p, ok := c.provider.(ProviderWithContext); ok {
creds, err = p.RetrieveWithContext(ctx)
} else {
creds, err = c.provider.Retrieve()
}
if err == nil {
c.creds.Store(creds)
}
Expand Down Expand Up @@ -308,3 +321,19 @@ func (c *Credentials) ExpiresAt() (time.Time, error) {
}
return expirer.ExpiresAt(), nil
}

type suppressedContext struct {
Context
}

func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}

func (s *suppressedContext) Done() <-chan struct{} {
return nil
}

func (s *suppressedContext) Err() error {
return nil
}
20 changes: 14 additions & 6 deletions aws/credentials/ec2rolecreds/ec2_role_provider.go
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
Expand Down Expand Up @@ -87,7 +88,14 @@ func NewCredentialsWithClient(client *ec2metadata.EC2Metadata, options ...func(*
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
credsList, err := requestCredList(m.Client)
return m.RetrieveWithContext(aws.BackgroundContext())
}

// RetrieveWithContext retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
credsList, err := requestCredList(ctx, m.Client)
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
Expand All @@ -97,7 +105,7 @@ func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
}
credsName := credsList[0]

roleCreds, err := requestCred(m.Client, credsName)
roleCreds, err := requestCred(ctx, m.Client, credsName)
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
Expand Down Expand Up @@ -130,8 +138,8 @@ const iamSecurityCredsPath = "iam/security-credentials/"

// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request
func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
resp, err := client.GetMetadata(iamSecurityCredsPath)
func requestCredList(ctx aws.Context, client *ec2metadata.EC2Metadata) ([]string, error) {
resp, err := client.GetMetadataWithContext(ctx, iamSecurityCredsPath)
if err != nil {
return nil, awserr.New("EC2RoleRequestError", "no EC2 instance role found", err)
}
Expand All @@ -154,8 +162,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
//
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadata(sdkuri.PathJoin(iamSecurityCredsPath, credsName))
func requestCred(ctx aws.Context, client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadataWithContext(ctx, sdkuri.PathJoin(iamSecurityCredsPath, credsName))
if err != nil {
return ec2RoleCredRespBody{},
awserr.New("EC2RoleRequestError",
Expand Down
11 changes: 9 additions & 2 deletions aws/credentials/endpointcreds/provider.go
Expand Up @@ -116,7 +116,13 @@ func (p *Provider) IsExpired() bool {
// Retrieve will attempt to request the credentials from the endpoint the Provider
// was configured for. And error will be returned if the retrieval fails.
func (p *Provider) Retrieve() (credentials.Value, error) {
resp, err := p.getCredentials()
return p.RetrieveWithContext(aws.BackgroundContext())
}

// RetrieveWithContext will attempt to request the credentials from the endpoint the Provider
// was configured for. And error will be returned if the retrieval fails.
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
resp, err := p.getCredentials(ctx)
if err != nil {
return credentials.Value{ProviderName: ProviderName},
awserr.New("CredentialsEndpointError", "failed to load credentials", err)
Expand Down Expand Up @@ -148,14 +154,15 @@ type errorOutput struct {
Message string `json:"message"`
}

func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
func (p *Provider) getCredentials(ctx aws.Context) (*getCredentialsOutput, error) {
op := &request.Operation{
Name: "GetCredentials",
HTTPMethod: "GET",
}

out := &getCredentialsOutput{}
req := p.Client.NewRequest(op, nil, out)
req.SetContext(ctx)
req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.AuthorizationToken; len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken)
Expand Down
20 changes: 19 additions & 1 deletion aws/credentials/stscreds/assume_role_provider.go
Expand Up @@ -87,6 +87,7 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkrand"
"github.com/aws/aws-sdk-go/service/sts"
)
Expand Down Expand Up @@ -118,6 +119,10 @@ type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
}

type assumeRolerWithContext interface {
AssumeRoleWithContext(aws.Context, *sts.AssumeRoleInput, ...request.Option) (*sts.AssumeRoleOutput, error)
}

// DefaultDuration is the default amount of time in minutes that the credentials
// will be valid for.
var DefaultDuration = time.Duration(15) * time.Minute
Expand Down Expand Up @@ -265,6 +270,11 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*

// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}

// RetrieveWithContext generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
Expand Down Expand Up @@ -304,7 +314,15 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
}
}

roleOutput, err := p.Client.AssumeRole(input)
var roleOutput *sts.AssumeRoleOutput
var err error

if c, ok := p.Client.(assumeRolerWithContext); ok {
roleOutput, err = c.AssumeRoleWithContext(ctx, input)
} else {
roleOutput, err = p.Client.AssumeRole(input)
}

if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
Expand Down
41 changes: 41 additions & 0 deletions aws/credentials/stscreds/assume_role_provider_test.go
Expand Up @@ -6,6 +6,8 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/sts"
)

Expand All @@ -29,6 +31,16 @@ func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput,
}, nil
}

type stubSTSWithContext struct {
stubSTS
called chan struct{}
}

func (s *stubSTSWithContext) AssumeRoleWithContext(context credentials.Context, input *sts.AssumeRoleInput, option ...request.Option) (*sts.AssumeRoleOutput, error) {
<-s.called
return s.stubSTS.AssumeRole(input)
}

func TestAssumeRoleProvider(t *testing.T) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Expand Down Expand Up @@ -223,3 +235,32 @@ func TestAssumeRoleProvider_WithTags(t *testing.T) {
t.Errorf("expect error")
}
}

func TestAssumeRoleProvider_RetrieveWithContext(t *testing.T) {
stub := &stubSTSWithContext{
called: make(chan struct{}),
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}

go func() {
stub.called <- struct{}{}
}()

creds, err := p.RetrieveWithContext(aws.BackgroundContext())
if err != nil {
t.Errorf("expect nil, got %v", err)
}

if e, a := "roleARN", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSessionToken", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
10 changes: 10 additions & 0 deletions aws/credentials/stscreds/web_identity_provider.go
Expand Up @@ -64,6 +64,13 @@ func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, p
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}

// RetrieveWithContext will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
b, err := ioutil.ReadFile(p.tokenFilePath)
if err != nil {
errMsg := fmt.Sprintf("unable to read file at %s", p.tokenFilePath)
Expand All @@ -81,6 +88,9 @@ func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
})

req.SetContext(ctx)

// InvalidIdentityToken error is a temporary error that can occur
// when assuming an Role with a JWT web identity token.
req.RetryErrorCodes = append(req.RetryErrorCodes, sts.ErrCodeInvalidIdentityTokenException)
Expand Down

0 comments on commit c075c54

Please sign in to comment.