diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39c..4b4f0d26a2 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -1,5 +1,10 @@ ### SDK Features ### SDK Enhancements +* `aws/credentials`: `ProviderWithContext` optional interface has been added to support passing contexts on credential retrieval ([#3223](https://github.com/aws/aws-sdk-go/pull/3223)) + * Credential providers that implement the optional `ProviderWithContext` will have context passed to them + * `ec2rolecreds.EC2RoleProvider`, `endpointcreds.Provider`, `stscreds.AssumeRoleProvider`, `stscreds.WebIdentityRoleProvider` have been updated to support the `ProviderWithContext` interface + * Fixes [#3213](https://github.com/aws/aws-sdk-go/issues/3213) +* `aws/ec2metadata`: Context aware operations have been added `EC2Metadata` client ([#3223](https://github.com/aws/aws-sdk-go/pull/3223)) ### SDK Bugs diff --git a/aws/credentials/credentials.go b/aws/credentials/credentials.go index c75d7bba03..9f8fd92a50 100644 --- a/aws/credentials/credentials.go +++ b/aws/credentials/credentials.go @@ -107,6 +107,13 @@ type Provider interface { IsExpired() bool } +// ProviderWithContext is a Provider that can retrieve credentials with a Context +type ProviderWithContext interface { + Provider + + RetrieveWithContext(Context) (Value, error) +} + // An Expirer is an interface that Providers can implement to expose the expiration // time, if known. If the Provider cannot accurately provide this info, // it should not implement this interface. @@ -233,7 +240,9 @@ func (c *Credentials) GetWithContext(ctx Context) (Value, error) { // Cannot pass context down to the actual retrieve, because the first // context would cancel the whole group when there is not direct // association of items in the group. - resCh := c.sf.DoChan("", c.singleRetrieve) + resCh := c.sf.DoChan("", func() (interface{}, error) { + return c.singleRetrieve(&suppressedContext{ctx}) + }) select { case res := <-resCh: return res.Val.(Value), res.Err @@ -243,12 +252,16 @@ func (c *Credentials) GetWithContext(ctx Context) (Value, error) { } } -func (c *Credentials) singleRetrieve() (interface{}, error) { +func (c *Credentials) singleRetrieve(ctx Context) (creds interface{}, err error) { if curCreds := c.creds.Load(); !c.isExpired(curCreds) { return curCreds.(Value), nil } - creds, err := c.provider.Retrieve() + if p, ok := c.provider.(ProviderWithContext); ok { + creds, err = p.RetrieveWithContext(ctx) + } else { + creds, err = c.provider.Retrieve() + } if err == nil { c.creds.Store(creds) } @@ -308,3 +321,19 @@ func (c *Credentials) ExpiresAt() (time.Time, error) { } return expirer.ExpiresAt(), nil } + +type suppressedContext struct { + Context +} + +func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) { + return time.Time{}, false +} + +func (s *suppressedContext) Done() <-chan struct{} { + return nil +} + +func (s *suppressedContext) Err() error { + return nil +} diff --git a/aws/credentials/ec2rolecreds/ec2_role_provider.go b/aws/credentials/ec2rolecreds/ec2_role_provider.go index 43d4ed386a..92af5b7250 100644 --- a/aws/credentials/ec2rolecreds/ec2_role_provider.go +++ b/aws/credentials/ec2rolecreds/ec2_role_provider.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/credentials" @@ -87,7 +88,14 @@ func NewCredentialsWithClient(client *ec2metadata.EC2Metadata, options ...func(* // Error will be returned if the request fails, or unable to extract // the desired credentials. func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) { - credsList, err := requestCredList(m.Client) + return m.RetrieveWithContext(aws.BackgroundContext()) +} + +// RetrieveWithContext retrieves credentials from the EC2 service. +// Error will be returned if the request fails, or unable to extract +// the desired credentials. +func (m *EC2RoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { + credsList, err := requestCredList(ctx, m.Client) if err != nil { return credentials.Value{ProviderName: ProviderName}, err } @@ -97,7 +105,7 @@ func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) { } credsName := credsList[0] - roleCreds, err := requestCred(m.Client, credsName) + roleCreds, err := requestCred(ctx, m.Client, credsName) if err != nil { return credentials.Value{ProviderName: ProviderName}, err } @@ -130,8 +138,8 @@ const iamSecurityCredsPath = "iam/security-credentials/" // requestCredList requests a list of credentials from the EC2 service. // If there are no credentials, or there is an error making or receiving the request -func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) { - resp, err := client.GetMetadata(iamSecurityCredsPath) +func requestCredList(ctx aws.Context, client *ec2metadata.EC2Metadata) ([]string, error) { + resp, err := client.GetMetadataWithContext(ctx, iamSecurityCredsPath) if err != nil { return nil, awserr.New("EC2RoleRequestError", "no EC2 instance role found", err) } @@ -154,8 +162,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) { // // If the credentials cannot be found, or there is an error reading the response // and error will be returned. -func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) { - resp, err := client.GetMetadata(sdkuri.PathJoin(iamSecurityCredsPath, credsName)) +func requestCred(ctx aws.Context, client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) { + resp, err := client.GetMetadataWithContext(ctx, sdkuri.PathJoin(iamSecurityCredsPath, credsName)) if err != nil { return ec2RoleCredRespBody{}, awserr.New("EC2RoleRequestError", diff --git a/aws/credentials/endpointcreds/provider.go b/aws/credentials/endpointcreds/provider.go index 1a7af53a4d..785f30d8e6 100644 --- a/aws/credentials/endpointcreds/provider.go +++ b/aws/credentials/endpointcreds/provider.go @@ -116,7 +116,13 @@ func (p *Provider) IsExpired() bool { // Retrieve will attempt to request the credentials from the endpoint the Provider // was configured for. And error will be returned if the retrieval fails. func (p *Provider) Retrieve() (credentials.Value, error) { - resp, err := p.getCredentials() + return p.RetrieveWithContext(aws.BackgroundContext()) +} + +// RetrieveWithContext will attempt to request the credentials from the endpoint the Provider +// was configured for. And error will be returned if the retrieval fails. +func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { + resp, err := p.getCredentials(ctx) if err != nil { return credentials.Value{ProviderName: ProviderName}, awserr.New("CredentialsEndpointError", "failed to load credentials", err) @@ -148,7 +154,7 @@ type errorOutput struct { Message string `json:"message"` } -func (p *Provider) getCredentials() (*getCredentialsOutput, error) { +func (p *Provider) getCredentials(ctx aws.Context) (*getCredentialsOutput, error) { op := &request.Operation{ Name: "GetCredentials", HTTPMethod: "GET", @@ -156,6 +162,7 @@ func (p *Provider) getCredentials() (*getCredentialsOutput, error) { out := &getCredentialsOutput{} req := p.Client.NewRequest(op, nil, out) + req.SetContext(ctx) req.HTTPRequest.Header.Set("Accept", "application/json") if authToken := p.AuthorizationToken; len(authToken) != 0 { req.HTTPRequest.Header.Set("Authorization", authToken) diff --git a/aws/credentials/stscreds/assume_role_provider.go b/aws/credentials/stscreds/assume_role_provider.go index 9f37f44bcf..73d9763c9b 100644 --- a/aws/credentials/stscreds/assume_role_provider.go +++ b/aws/credentials/stscreds/assume_role_provider.go @@ -87,6 +87,7 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/internal/sdkrand" "github.com/aws/aws-sdk-go/service/sts" ) @@ -118,6 +119,10 @@ type AssumeRoler interface { AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) } +type assumeRolerWithContext interface { + AssumeRoleWithContext(aws.Context, *sts.AssumeRoleInput, ...request.Option) (*sts.AssumeRoleOutput, error) +} + // DefaultDuration is the default amount of time in minutes that the credentials // will be valid for. var DefaultDuration = time.Duration(15) * time.Minute @@ -265,6 +270,11 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(* // Retrieve generates a new set of temporary credentials using STS. func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) { + return p.RetrieveWithContext(aws.BackgroundContext()) +} + +// RetrieveWithContext generates a new set of temporary credentials using STS. +func (p *AssumeRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { // Apply defaults where parameters are not set. if p.RoleSessionName == "" { // Try to work out a role name that will hopefully end up unique. @@ -304,7 +314,15 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) { } } - roleOutput, err := p.Client.AssumeRole(input) + var roleOutput *sts.AssumeRoleOutput + var err error + + if c, ok := p.Client.(assumeRolerWithContext); ok { + roleOutput, err = c.AssumeRoleWithContext(ctx, input) + } else { + roleOutput, err = p.Client.AssumeRole(input) + } + if err != nil { return credentials.Value{ProviderName: ProviderName}, err } diff --git a/aws/credentials/stscreds/assume_role_provider_test.go b/aws/credentials/stscreds/assume_role_provider_test.go index aea76b60e1..d34102de3b 100644 --- a/aws/credentials/stscreds/assume_role_provider_test.go +++ b/aws/credentials/stscreds/assume_role_provider_test.go @@ -6,6 +6,8 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/sts" ) @@ -29,6 +31,16 @@ func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, }, nil } +type stubSTSWithContext struct { + stubSTS + called chan struct{} +} + +func (s *stubSTSWithContext) AssumeRoleWithContext(context credentials.Context, input *sts.AssumeRoleInput, option ...request.Option) (*sts.AssumeRoleOutput, error) { + <-s.called + return s.stubSTS.AssumeRole(input) +} + func TestAssumeRoleProvider(t *testing.T) { stub := &stubSTS{} p := &AssumeRoleProvider{ @@ -223,3 +235,32 @@ func TestAssumeRoleProvider_WithTags(t *testing.T) { t.Errorf("expect error") } } + +func TestAssumeRoleProvider_RetrieveWithContext(t *testing.T) { + stub := &stubSTSWithContext{ + called: make(chan struct{}), + } + p := &AssumeRoleProvider{ + Client: stub, + RoleARN: "roleARN", + } + + go func() { + stub.called <- struct{}{} + }() + + creds, err := p.RetrieveWithContext(aws.BackgroundContext()) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + + if e, a := "roleARN", creds.AccessKeyID; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "assumedSessionToken", creds.SessionToken; e != a { + t.Errorf("expect %v, got %v", e, a) + } +} diff --git a/aws/credentials/stscreds/web_identity_provider.go b/aws/credentials/stscreds/web_identity_provider.go index b20b633948..fae6f2f651 100644 --- a/aws/credentials/stscreds/web_identity_provider.go +++ b/aws/credentials/stscreds/web_identity_provider.go @@ -64,6 +64,13 @@ func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, p // 'WebIdentityTokenFilePath' specified destination and if that is empty an // error will be returned. func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) { + return p.RetrieveWithContext(aws.BackgroundContext()) +} + +// RetrieveWithContext will attempt to assume a role from a token which is located at +// 'WebIdentityTokenFilePath' specified destination and if that is empty an +// error will be returned. +func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { b, err := ioutil.ReadFile(p.tokenFilePath) if err != nil { errMsg := fmt.Sprintf("unable to read file at %s", p.tokenFilePath) @@ -81,6 +88,9 @@ func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) { RoleSessionName: &sessionName, WebIdentityToken: aws.String(string(b)), }) + + req.SetContext(ctx) + // InvalidIdentityToken error is a temporary error that can occur // when assuming an Role with a JWT web identity token. req.RetryErrorCodes = append(req.RetryErrorCodes, sts.ErrCodeInvalidIdentityTokenException) diff --git a/aws/ec2metadata/api.go b/aws/ec2metadata/api.go index 12897eef62..a716c021cf 100644 --- a/aws/ec2metadata/api.go +++ b/aws/ec2metadata/api.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/internal/sdkuri" @@ -15,7 +16,7 @@ import ( // getToken uses the duration to return a token for EC2 metadata service, // or an error if the request failed. -func (c *EC2Metadata) getToken(duration time.Duration) (tokenOutput, error) { +func (c *EC2Metadata) getToken(ctx aws.Context, duration time.Duration) (tokenOutput, error) { op := &request.Operation{ Name: "GetToken", HTTPMethod: "PUT", @@ -24,6 +25,7 @@ func (c *EC2Metadata) getToken(duration time.Duration) (tokenOutput, error) { var output tokenOutput req := c.NewRequest(op, nil, &output) + req.SetContext(ctx) // remove the fetch token handler from the request handlers to avoid infinite recursion req.Handlers.Sign.RemoveByName(fetchTokenHandlerName) @@ -50,6 +52,13 @@ func (c *EC2Metadata) getToken(duration time.Duration) (tokenOutput, error) { // instance metadata service. The content will be returned as a string, or // error if the request failed. func (c *EC2Metadata) GetMetadata(p string) (string, error) { + return c.GetMetadataWithContext(aws.BackgroundContext(), p) +} + +// GetMetadataWithContext uses the path provided to request information from the EC2 +// instance metadata service. The content will be returned as a string, or +// error if the request failed. +func (c *EC2Metadata) GetMetadataWithContext(ctx aws.Context, p string) (string, error) { op := &request.Operation{ Name: "GetMetadata", HTTPMethod: "GET", @@ -59,6 +68,8 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) { req := c.NewRequest(op, nil, output) + req.SetContext(ctx) + err := req.Send() return output.Content, err } @@ -67,6 +78,13 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) { // there is no user-data setup for the EC2 instance a "NotFoundError" error // code will be returned. func (c *EC2Metadata) GetUserData() (string, error) { + return c.GetUserDataWithContext(aws.BackgroundContext()) +} + +// GetUserDataWithContext returns the userdata that was configured for the service. If +// there is no user-data setup for the EC2 instance a "NotFoundError" error +// code will be returned. +func (c *EC2Metadata) GetUserDataWithContext(ctx aws.Context) (string, error) { op := &request.Operation{ Name: "GetUserData", HTTPMethod: "GET", @@ -75,6 +93,7 @@ func (c *EC2Metadata) GetUserData() (string, error) { output := &metadataOutput{} req := c.NewRequest(op, nil, output) + req.SetContext(ctx) err := req.Send() return output.Content, err @@ -84,6 +103,13 @@ func (c *EC2Metadata) GetUserData() (string, error) { // instance metadata service for dynamic data. The content will be returned // as a string, or error if the request failed. func (c *EC2Metadata) GetDynamicData(p string) (string, error) { + return c.GetDynamicDataWithContext(aws.BackgroundContext(), p) +} + +// GetDynamicDataWithContext uses the path provided to request information from the EC2 +// instance metadata service for dynamic data. The content will be returned +// as a string, or error if the request failed. +func (c *EC2Metadata) GetDynamicDataWithContext(ctx aws.Context, p string) (string, error) { op := &request.Operation{ Name: "GetDynamicData", HTTPMethod: "GET", @@ -92,6 +118,7 @@ func (c *EC2Metadata) GetDynamicData(p string) (string, error) { output := &metadataOutput{} req := c.NewRequest(op, nil, output) + req.SetContext(ctx) err := req.Send() return output.Content, err @@ -101,7 +128,14 @@ func (c *EC2Metadata) GetDynamicData(p string) (string, error) { // instance. Error is returned if the request fails or is unable to parse // the response. func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument, error) { - resp, err := c.GetDynamicData("instance-identity/document") + return c.GetInstanceIdentityDocumentWithContext(aws.BackgroundContext()) +} + +// GetInstanceIdentityDocumentWithContext retrieves an identity document describing an +// instance. Error is returned if the request fails or is unable to parse +// the response. +func (c *EC2Metadata) GetInstanceIdentityDocumentWithContext(ctx aws.Context) (EC2InstanceIdentityDocument, error) { + resp, err := c.GetDynamicDataWithContext(ctx, "instance-identity/document") if err != nil { return EC2InstanceIdentityDocument{}, awserr.New("EC2MetadataRequestError", @@ -120,7 +154,12 @@ func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument // IAMInfo retrieves IAM info from the metadata API func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) { - resp, err := c.GetMetadata("iam/info") + return c.IAMInfoWithContext(aws.BackgroundContext()) +} + +// IAMInfoWithContext retrieves IAM info from the metadata API +func (c *EC2Metadata) IAMInfoWithContext(ctx aws.Context) (EC2IAMInfo, error) { + resp, err := c.GetMetadataWithContext(ctx, "iam/info") if err != nil { return EC2IAMInfo{}, awserr.New("EC2MetadataRequestError", @@ -145,7 +184,12 @@ func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) { // Region returns the region the instance is running in. func (c *EC2Metadata) Region() (string, error) { - ec2InstanceIdentityDocument, err := c.GetInstanceIdentityDocument() + return c.RegionWithContext(aws.BackgroundContext()) +} + +// RegionWithContext returns the region the instance is running in. +func (c *EC2Metadata) RegionWithContext(ctx aws.Context) (string, error) { + ec2InstanceIdentityDocument, err := c.GetInstanceIdentityDocumentWithContext(ctx) if err != nil { return "", err } @@ -162,7 +206,14 @@ func (c *EC2Metadata) Region() (string, error) { // Can be used to determine if application is running within an EC2 Instance and // the metadata service is available. func (c *EC2Metadata) Available() bool { - if _, err := c.GetMetadata("instance-id"); err != nil { + return c.AvailableWithContext(aws.BackgroundContext()) +} + +// AvailableWithContext returns if the application has access to the EC2 Metadata service. +// Can be used to determine if application is running within an EC2 Instance and +// the metadata service is available. +func (c *EC2Metadata) AvailableWithContext(ctx aws.Context) bool { + if _, err := c.GetMetadataWithContext(ctx, "instance-id"); err != nil { return false } diff --git a/aws/ec2metadata/token_provider.go b/aws/ec2metadata/token_provider.go index 663372a915..d0a3a020d8 100644 --- a/aws/ec2metadata/token_provider.go +++ b/aws/ec2metadata/token_provider.go @@ -46,7 +46,7 @@ func (t *tokenProvider) fetchTokenHandler(r *request.Request) { return } - output, err := t.client.getToken(t.configuredTTL) + output, err := t.client.getToken(r.Context(), t.configuredTTL) if err != nil {