From 59061923b2d5b63298473bab3c3025d6183898e9 Mon Sep 17 00:00:00 2001
From: Bracken Dawson
Date: Thu, 29 Jul 2021 12:29:58 +0100
Subject: [PATCH] Remove github.com/dgrijalva/jwt-go

By updating github.com/Azure/go-autorest/autorest to v0.11.19
(10e0b31633f168ce1a329dcbdd0ab9842e533fb5)

Backport of github.com/distribution/distribution#3459

Signed-off-by: Bracken Dawson
---
 vendor.conf | 6 +-
 vendor/github.com/Azure/go-autorest/README.md | 18 +-
 .../Azure/go-autorest/autorest/adal/README.md | 16 +-
 .../Azure/go-autorest/autorest/adal/config.go | 88 +-
 .../go-autorest/autorest/adal/devicetoken.go | 39 +-
 .../Azure/go-autorest/autorest/adal/go.mod | 13 +
 .../autorest/adal/go_mod_tidy_hack.go | 24 +
 .../go-autorest/autorest/adal/persist.go | 62 +
 .../Azure/go-autorest/autorest/adal/sender.go | 40 +-
 .../Azure/go-autorest/autorest/adal/token.go | 921 ++++++++++++---
 .../go-autorest/autorest/adal/token_1.13.go | 75 ++
 .../go-autorest/autorest/adal/token_legacy.go | 74 ++
 .../go-autorest/autorest/adal/version.go | 45 +
 .../go-autorest/autorest/authorization.go | 134 ++-
 .../go-autorest/autorest/authorization_sas.go | 66 ++
 .../autorest/authorization_storage.go | 307 +++++
 .../Azure/go-autorest/autorest/azure/async.go | 1052 ++++++++++++----
 .../Azure/go-autorest/autorest/azure/azure.go | 177 ++-
 .../autorest/azure/environments.go | 154 ++-
 .../Azure/go-autorest/autorest/azure/rp.go | 20 +-
 .../Azure/go-autorest/autorest/client.go | 104 +-
 .../Azure/go-autorest/autorest/date/go.mod | 5 +
 .../autorest/date/go_mod_tidy_hack.go | 24 +
 .../Azure/go-autorest/autorest/error.go | 5 +
 .../Azure/go-autorest/autorest/go.mod | 12 +
 .../go-autorest/autorest/go_mod_tidy_hack.go | 24 +
 .../Azure/go-autorest/autorest/preparer.go | 117 +-
 .../Azure/go-autorest/autorest/responder.go | 19 +
 .../Azure/go-autorest/autorest/sender.go | 206 +++-
 .../Azure/go-autorest/autorest/utility.go | 45 +-
 .../go-autorest/autorest/utility_1.13.go | 29 +
 .../go-autorest/autorest/utility_legacy.go | 31 +
 .../Azure/go-autorest/autorest/version.go | 23 +-
 vendor/github.com/Azure/go-autorest/doc.go | 18 +
 .../Azure/go-autorest/logger/go.mod | 5 +
 .../go-autorest/logger/go_mod_tidy_hack.go | 24 +
 .../Azure/go-autorest/logger/logger.go | 337 ++++++
 .../Azure/go-autorest/tracing/go.mod | 5 +
 .../go-autorest/tracing/go_mod_tidy_hack.go | 24 +
 .../Azure/go-autorest/tracing/tracing.go | 67 ++
 .../jwt-go/LICENSE | 0
 .../jwt-go/README.md | 39 +-
 .../jwt-go/claims.go | 16 +-
 .../jwt-go/doc.go | 0
 .../jwt-go/ecdsa.go | 1 +
 .../jwt-go/ecdsa_utils.go | 4 +-
 .../jwt-go/errors.go | 0
 .../jwt-go/hmac.go | 3 +-
 .../jwt-go/map_claims.go | 10 +-
 .../jwt-go/none.go | 0
 .../jwt-go/parser.go | 113 +-
 .../jwt-go/rsa.go | 5 +-
 .../jwt-go/rsa_pss.go | 38 +-
 .../jwt-go/rsa_utils.go | 34 +-
 .../jwt-go/signing_method.go | 0
 .../jwt-go/token.go | 0
 vendor/golang.org/x/crypto/README | 3 -
 vendor/golang.org/x/crypto/README.md | 21 +
 vendor/golang.org/x/crypto/bcrypt/bcrypt.go | 11 +-
 vendor/golang.org/x/crypto/blowfish/cipher.go | 12 +-
 vendor/golang.org/x/crypto/blowfish/const.go | 2 +-
 vendor/golang.org/x/crypto/go.mod | 8 +
 vendor/golang.org/x/crypto/ocsp/ocsp.go | 393 ++++--
 .../x/crypto/otr/libotr_test_helper.c | 154 ++-
 vendor/golang.org/x/crypto/otr/otr.go | 45 +-
 .../golang.org/x/crypto/pkcs12/bmp-string.go | 50 +
 vendor/golang.org/x/crypto/pkcs12/crypto.go | 131 ++
 vendor/golang.org/x/crypto/pkcs12/errors.go | 23 +
 .../x/crypto/pkcs12/internal/rc2/rc2.go | 271 +++++
 vendor/golang.org/x/crypto/pkcs12/mac.go | 45 +
 vendor/golang.org/x/crypto/pkcs12/pbkdf.go
| 170 +++ vendor/golang.org/x/crypto/pkcs12/pkcs12.go | 360 ++++++ vendor/golang.org/x/crypto/pkcs12/safebags.go | 57 + vendor/golang.org/x/crypto/ssh/test/doc.go | 7 + .../x/crypto/ssh/test/sshd_test_pw.c | 173 +++ 75 files changed, 5677 insertions(+), 977 deletions(-) create mode 100644 vendor/github.com/Azure/go-autorest/autorest/adal/go.mod create mode 100644 vendor/github.com/Azure/go-autorest/autorest/adal/go_mod_tidy_hack.go create mode 100644 vendor/github.com/Azure/go-autorest/autorest/adal/token_1.13.go create mode 100644 vendor/github.com/Azure/go-autorest/autorest/adal/token_legacy.go create mode 100644 vendor/github.com/Azure/go-autorest/autorest/adal/version.go create mode 100644 vendor/github.com/Azure/go-autorest/autorest/authorization_sas.go create mode 100644 vendor/github.com/Azure/go-autorest/autorest/authorization_storage.go create mode 100644 vendor/github.com/Azure/go-autorest/autorest/date/go.mod create mode 100644 vendor/github.com/Azure/go-autorest/autorest/date/go_mod_tidy_hack.go create mode 100644 vendor/github.com/Azure/go-autorest/autorest/go.mod create mode 100644 vendor/github.com/Azure/go-autorest/autorest/go_mod_tidy_hack.go create mode 100644 vendor/github.com/Azure/go-autorest/autorest/utility_1.13.go create mode 100644 vendor/github.com/Azure/go-autorest/autorest/utility_legacy.go create mode 100644 vendor/github.com/Azure/go-autorest/doc.go create mode 100644 vendor/github.com/Azure/go-autorest/logger/go.mod create mode 100644 vendor/github.com/Azure/go-autorest/logger/go_mod_tidy_hack.go create mode 100644 vendor/github.com/Azure/go-autorest/logger/logger.go create mode 100644 vendor/github.com/Azure/go-autorest/tracing/go.mod create mode 100644 vendor/github.com/Azure/go-autorest/tracing/go_mod_tidy_hack.go create mode 100644 vendor/github.com/Azure/go-autorest/tracing/tracing.go rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/LICENSE (100%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/README.md (63%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/claims.go (93%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/doc.go (100%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/ecdsa.go (97%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/ecdsa_utils.go (93%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/errors.go (100%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/hmac.go (96%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/map_claims.go (94%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/none.go (100%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/parser.go (68%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/rsa.go (91%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/rsa_pss.go (71%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/rsa_utils.go (62%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/signing_method.go (100%) rename vendor/github.com/{dgrijalva => form3tech-oss}/jwt-go/token.go (100%) delete mode 100644 vendor/golang.org/x/crypto/README create mode 100644 vendor/golang.org/x/crypto/README.md create mode 100644 vendor/golang.org/x/crypto/go.mod create mode 100644 vendor/golang.org/x/crypto/pkcs12/bmp-string.go create mode 100644 vendor/golang.org/x/crypto/pkcs12/crypto.go create mode 100644 vendor/golang.org/x/crypto/pkcs12/errors.go create mode 100644 vendor/golang.org/x/crypto/pkcs12/internal/rc2/rc2.go create mode 100644 
vendor/golang.org/x/crypto/pkcs12/mac.go create mode 100644 vendor/golang.org/x/crypto/pkcs12/pbkdf.go create mode 100644 vendor/golang.org/x/crypto/pkcs12/pkcs12.go create mode 100644 vendor/golang.org/x/crypto/pkcs12/safebags.go create mode 100644 vendor/golang.org/x/crypto/ssh/test/doc.go create mode 100644 vendor/golang.org/x/crypto/ssh/test/sshd_test_pw.c diff --git a/vendor.conf b/vendor.conf index a249caf269..cb85088265 100644 --- a/vendor.conf +++ b/vendor.conf @@ -1,5 +1,5 @@ github.com/Azure/azure-sdk-for-go 4650843026a7fdec254a8d9cf893693a254edd0b -github.com/Azure/go-autorest eaa7994b2278094c904d31993d26f56324db3052 +github.com/Azure/go-autorest 10e0b31633f168ce1a329dcbdd0ab9842e533fb5 github.com/sirupsen/logrus 3d4380f53a34dcdc95f0c1db702615992b38d9a4 github.com/aws/aws-sdk-go f831d5a0822a1ad72420ab18c6269bca1ddaf490 github.com/bshuster-repo/logrus-logstash-hook d2c0ecc1836d91814e15e23bb5dc309c3ef51f4a @@ -8,9 +8,9 @@ github.com/bugsnag/bugsnag-go b1d153021fcd90ca3f080db36bec96dc690fb274 github.com/bugsnag/osext 0dd3f918b21bec95ace9dc86c7e70266cfc5c702 github.com/bugsnag/panicwrap e2c28503fcd0675329da73bf48b33404db873782 github.com/denverdino/aliyungo afedced274aa9a7fcdd47ac97018f0f8db4e5de2 -github.com/dgrijalva/jwt-go a601269ab70c205d26370c16f7c81e9017c14e04 github.com/docker/go-metrics 399ea8c73916000c64c2c76e8da00ca82f8387ab github.com/docker/libtrust fa567046d9b14f6aa788882a950d69651d230b21 +github.com/form3tech-oss/jwt-go 9162a5abdbc046b7c8b03ee90052cee67e25caa7 github.com/garyburd/redigo 535138d7bcd717d6531c701ef5933d98b1866257 github.com/go-ini/ini 2ba15ac2dc9cdf88c110ec2dc0ced7fa45f5678c github.com/golang/protobuf 8d92cf5fc15a4382f8964b08e1f42a75c0591aa3 @@ -35,7 +35,7 @@ github.com/xenolf/lego a9d8cec0e6563575e5868a005359ac97911b5985 github.com/yvasiyarov/go-metrics 57bccd1ccd43f94bb17fdd8bf3007059b802f85e github.com/yvasiyarov/gorelic a9bba5b9ab508a086f9a12b8c51fab68478e2128 github.com/yvasiyarov/newrelic_platform_go b21fdbd4370f3717f3bbd2bf41c223bc273068e6 -golang.org/x/crypto c10c31b5e94b6f7a0283272dc2bb27163dcea24b +golang.org/x/crypto 7f63de1d35b0f77fa2b9faea3e7deb402a2383c8 golang.org/x/net 4876518f9e71663000c348837735820161a42df7 golang.org/x/oauth2 045497edb6234273d67dbc25da3f2ddbc4c4cacf golang.org/x/time a4bde12657593d5e90d0533a3e4fd95e635124cb diff --git a/vendor/github.com/Azure/go-autorest/README.md b/vendor/github.com/Azure/go-autorest/README.md index 9a7b13a359..de1e19a44d 100644 --- a/vendor/github.com/Azure/go-autorest/README.md +++ b/vendor/github.com/Azure/go-autorest/README.md @@ -1,7 +1,7 @@ # go-autorest [![GoDoc](https://godoc.org/github.com/Azure/go-autorest/autorest?status.png)](https://godoc.org/github.com/Azure/go-autorest/autorest) -[![Build Status](https://travis-ci.org/Azure/go-autorest.svg?branch=master)](https://travis-ci.org/Azure/go-autorest) +[![Build Status](https://dev.azure.com/azure-sdk/public/_apis/build/status/go/Azure.go-autorest?branchName=master)](https://dev.azure.com/azure-sdk/public/_build/latest?definitionId=625&branchName=master) [![Go Report Card](https://goreportcard.com/badge/Azure/go-autorest)](https://goreportcard.com/report/Azure/go-autorest) Package go-autorest provides an HTTP request client for use with [Autorest](https://github.com/Azure/autorest.go)-generated API client packages. 
@@ -135,6 +135,22 @@ go get github.com/Azure/go-autorest/autorest/date go get github.com/Azure/go-autorest/autorest/to ``` +### Using with Go Modules +In [v12.0.1](https://github.com/Azure/go-autorest/pull/386), this repository introduced the following modules. + +- autorest/adal +- autorest/azure/auth +- autorest/azure/cli +- autorest/date +- autorest/mocks +- autorest/to +- autorest/validation +- autorest +- logger +- tracing + +Tagging cumulative SDK releases as a whole (e.g. `v12.3.0`) is still enabled to support consumers of this repo that have not yet migrated to modules. + ## License See LICENSE file. diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/README.md b/vendor/github.com/Azure/go-autorest/autorest/adal/README.md index 7b0c4bc4d2..fec416a9c4 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/adal/README.md +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/README.md @@ -135,7 +135,7 @@ resource := "https://management.core.windows.net/" applicationSecret := "APPLICATION_SECRET" spt, err := adal.NewServicePrincipalToken( - oauthConfig, + *oauthConfig, appliationID, applicationSecret, resource, @@ -170,7 +170,7 @@ if err != nil { } spt, err := adal.NewServicePrincipalTokenFromCertificate( - oauthConfig, + *oauthConfig, applicationID, certificate, rsaPrivateKey, @@ -195,7 +195,7 @@ oauthClient := &http.Client{} // Acquire the device code deviceCode, err := adal.InitiateDeviceAuth( oauthClient, - oauthConfig, + *oauthConfig, applicationID, resource) if err != nil { @@ -212,7 +212,7 @@ if err != nil { } spt, err := adal.NewServicePrincipalTokenFromManualToken( - oauthConfig, + *oauthConfig, applicationID, resource, *token, @@ -227,7 +227,7 @@ if (err == nil) { ```Go spt, err := adal.NewServicePrincipalTokenFromUsernamePassword( - oauthConfig, + *oauthConfig, applicationID, username, password, @@ -243,11 +243,11 @@ if (err == nil) { ``` Go spt, err := adal.NewServicePrincipalTokenFromAuthorizationCode( - oauthConfig, + *oauthConfig, applicationID, clientSecret, - authorizationCode, - redirectURI, + authorizationCode, + redirectURI, resource, callbacks...) diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/config.go b/vendor/github.com/Azure/go-autorest/autorest/adal/config.go index f570d540a6..fa5964742f 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/adal/config.go +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/config.go @@ -15,21 +15,22 @@ package adal // limitations under the License. import ( + "errors" "fmt" "net/url" ) const ( - activeDirectoryAPIVersion = "1.0" + activeDirectoryEndpointTemplate = "%s/oauth2/%s%s" ) // OAuthConfig represents the endpoints needed // in OAuth operations type OAuthConfig struct { - AuthorityEndpoint url.URL - AuthorizeEndpoint url.URL - TokenEndpoint url.URL - DeviceCodeEndpoint url.URL + AuthorityEndpoint url.URL `json:"authorityEndpoint"` + AuthorizeEndpoint url.URL `json:"authorizeEndpoint"` + TokenEndpoint url.URL `json:"tokenEndpoint"` + DeviceCodeEndpoint url.URL `json:"deviceCodeEndpoint"` } // IsZero returns true if the OAuthConfig object is zero-initialized. @@ -46,11 +47,24 @@ func validateStringParam(param, name string) error { // NewOAuthConfig returns an OAuthConfig with tenant specific urls func NewOAuthConfig(activeDirectoryEndpoint, tenantID string) (*OAuthConfig, error) { + apiVer := "1.0" + return NewOAuthConfigWithAPIVersion(activeDirectoryEndpoint, tenantID, &apiVer) +} + +// NewOAuthConfigWithAPIVersion returns an OAuthConfig with tenant specific urls. 
+// If apiVersion is not nil the "api-version" query parameter will be appended to the endpoint URLs with the specified value. +func NewOAuthConfigWithAPIVersion(activeDirectoryEndpoint, tenantID string, apiVersion *string) (*OAuthConfig, error) { if err := validateStringParam(activeDirectoryEndpoint, "activeDirectoryEndpoint"); err != nil { return nil, err } + api := "" // it's legal for tenantID to be empty so don't validate it - const activeDirectoryEndpointTemplate = "%s/oauth2/%s?api-version=%s" + if apiVersion != nil { + if err := validateStringParam(*apiVersion, "apiVersion"); err != nil { + return nil, err + } + api = fmt.Sprintf("?api-version=%s", *apiVersion) + } u, err := url.Parse(activeDirectoryEndpoint) if err != nil { return nil, err @@ -59,15 +73,15 @@ func NewOAuthConfig(activeDirectoryEndpoint, tenantID string) (*OAuthConfig, err if err != nil { return nil, err } - authorizeURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "authorize", activeDirectoryAPIVersion)) + authorizeURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "authorize", api)) if err != nil { return nil, err } - tokenURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "token", activeDirectoryAPIVersion)) + tokenURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "token", api)) if err != nil { return nil, err } - deviceCodeURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "devicecode", activeDirectoryAPIVersion)) + deviceCodeURL, err := u.Parse(fmt.Sprintf(activeDirectoryEndpointTemplate, tenantID, "devicecode", api)) if err != nil { return nil, err } @@ -79,3 +93,59 @@ func NewOAuthConfig(activeDirectoryEndpoint, tenantID string) (*OAuthConfig, err DeviceCodeEndpoint: *deviceCodeURL, }, nil } + +// MultiTenantOAuthConfig provides endpoints for primary and aulixiary tenant IDs. +type MultiTenantOAuthConfig interface { + PrimaryTenant() *OAuthConfig + AuxiliaryTenants() []*OAuthConfig +} + +// OAuthOptions contains optional OAuthConfig creation arguments. +type OAuthOptions struct { + APIVersion string +} + +func (c OAuthOptions) apiVersion() string { + if c.APIVersion != "" { + return fmt.Sprintf("?api-version=%s", c.APIVersion) + } + return "1.0" +} + +// NewMultiTenantOAuthConfig creates an object that support multitenant OAuth configuration. +// See https://docs.microsoft.com/en-us/azure/azure-resource-manager/authenticate-multi-tenant for more information. 
+func NewMultiTenantOAuthConfig(activeDirectoryEndpoint, primaryTenantID string, auxiliaryTenantIDs []string, options OAuthOptions) (MultiTenantOAuthConfig, error) { + if len(auxiliaryTenantIDs) == 0 || len(auxiliaryTenantIDs) > 3 { + return nil, errors.New("must specify one to three auxiliary tenants") + } + mtCfg := multiTenantOAuthConfig{ + cfgs: make([]*OAuthConfig, len(auxiliaryTenantIDs)+1), + } + apiVer := options.apiVersion() + pri, err := NewOAuthConfigWithAPIVersion(activeDirectoryEndpoint, primaryTenantID, &apiVer) + if err != nil { + return nil, fmt.Errorf("failed to create OAuthConfig for primary tenant: %v", err) + } + mtCfg.cfgs[0] = pri + for i := range auxiliaryTenantIDs { + aux, err := NewOAuthConfig(activeDirectoryEndpoint, auxiliaryTenantIDs[i]) + if err != nil { + return nil, fmt.Errorf("failed to create OAuthConfig for tenant '%s': %v", auxiliaryTenantIDs[i], err) + } + mtCfg.cfgs[i+1] = aux + } + return mtCfg, nil +} + +type multiTenantOAuthConfig struct { + // first config in the slice is the primary tenant + cfgs []*OAuthConfig +} + +func (m multiTenantOAuthConfig) PrimaryTenant() *OAuthConfig { + return m.cfgs[0] +} + +func (m multiTenantOAuthConfig) AuxiliaryTenants() []*OAuthConfig { + return m.cfgs[1:] +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/devicetoken.go b/vendor/github.com/Azure/go-autorest/autorest/adal/devicetoken.go index b38f4c2458..9daa4b58b8 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/adal/devicetoken.go +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/devicetoken.go @@ -24,6 +24,7 @@ package adal */ import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -101,7 +102,14 @@ type deviceToken struct { // InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode // that can be used with CheckForUserCompletion or WaitForUserCompletion. +// Deprecated: use InitiateDeviceAuthWithContext() instead. func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) { + return InitiateDeviceAuthWithContext(context.Background(), sender, oauthConfig, clientID, resource) +} + +// InitiateDeviceAuthWithContext initiates a device auth flow. It returns a DeviceCode +// that can be used with CheckForUserCompletion or WaitForUserCompletion. +func InitiateDeviceAuthWithContext(ctx context.Context, sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) { v := url.Values{ "client_id": []string{clientID}, "resource": []string{resource}, @@ -117,7 +125,7 @@ func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resour req.ContentLength = int64(len(s)) req.Header.Set(contentType, mimeTypeFormPost) - resp, err := sender.Do(req) + resp, err := sender.Do(req.WithContext(ctx)) if err != nil { return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error()) } @@ -151,7 +159,14 @@ func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resour // CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint // to see if the device flow has: been completed, timed out, or otherwise failed +// Deprecated: use CheckForUserCompletionWithContext() instead. 
func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) { + return CheckForUserCompletionWithContext(context.Background(), sender, code) +} + +// CheckForUserCompletionWithContext takes a DeviceCode and checks with the Azure AD OAuth endpoint +// to see if the device flow has: been completed, timed out, or otherwise failed +func CheckForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) { v := url.Values{ "client_id": []string{code.ClientID}, "code": []string{*code.DeviceCode}, @@ -169,7 +184,7 @@ func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) { req.ContentLength = int64(len(s)) req.Header.Set(contentType, mimeTypeFormPost) - resp, err := sender.Do(req) + resp, err := sender.Do(req.WithContext(ctx)) if err != nil { return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error()) } @@ -207,18 +222,29 @@ func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) { case "code_expired": return nil, ErrDeviceCodeExpired default: + // return a more meaningful error message if available + if token.ErrorDescription != nil { + return nil, fmt.Errorf("%s %s: %s", logPrefix, *token.Error, *token.ErrorDescription) + } return nil, ErrDeviceGeneric } } // WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs. // This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'. +// Deprecated: use WaitForUserCompletionWithContext() instead. func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) { + return WaitForUserCompletionWithContext(context.Background(), sender, code) +} + +// WaitForUserCompletionWithContext calls CheckForUserCompletion repeatedly until a token is granted or an error +// state occurs. This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'. +func WaitForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) { intervalDuration := time.Duration(*code.Interval) * time.Second waitDuration := intervalDuration for { - token, err := CheckForUserCompletion(sender, code) + token, err := CheckForUserCompletionWithContext(ctx, sender, code) if err == nil { return token, nil @@ -237,6 +263,11 @@ func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) { return nil, fmt.Errorf("%s Error waiting for user to complete device flow. 
Server told us to slow_down too much", logPrefix) } - time.Sleep(waitDuration) + select { + case <-time.After(waitDuration): + // noop + case <-ctx.Done(): + return nil, ctx.Err() + } } } diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/go.mod b/vendor/github.com/Azure/go-autorest/autorest/adal/go.mod new file mode 100644 index 0000000000..8c5d36ca61 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/go.mod @@ -0,0 +1,13 @@ +module github.com/Azure/go-autorest/autorest/adal + +go 1.12 + +require ( + github.com/Azure/go-autorest v14.2.0+incompatible + github.com/Azure/go-autorest/autorest/date v0.3.0 + github.com/Azure/go-autorest/autorest/mocks v0.4.1 + github.com/Azure/go-autorest/logger v0.2.1 + github.com/Azure/go-autorest/tracing v0.6.0 + github.com/form3tech-oss/jwt-go v3.2.2+incompatible + golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0 +) diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/go_mod_tidy_hack.go b/vendor/github.com/Azure/go-autorest/autorest/adal/go_mod_tidy_hack.go new file mode 100644 index 0000000000..7551b79235 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/go_mod_tidy_hack.go @@ -0,0 +1,24 @@ +// +build modhack + +package adal + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file, and the github.com/Azure/go-autorest import, won't actually become part of +// the resultant binary. + +// Necessary for safely adding multi-module repo. +// See: https://github.com/golang/go/wiki/Modules#is-it-possible-to-add-a-module-to-a-multi-module-repository +import _ "github.com/Azure/go-autorest" diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/persist.go b/vendor/github.com/Azure/go-autorest/autorest/adal/persist.go index 9e15f2751f..2a974a39b3 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/adal/persist.go +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/persist.go @@ -15,11 +15,24 @@ package adal // limitations under the License. import ( + "crypto/rsa" + "crypto/x509" "encoding/json" + "errors" "fmt" "io/ioutil" "os" "path/filepath" + + "golang.org/x/crypto/pkcs12" +) + +var ( + // ErrMissingCertificate is returned when no local certificate is found in the provided PFX data. + ErrMissingCertificate = errors.New("adal: certificate missing") + + // ErrMissingPrivateKey is returned when no private key is found in the provided PFX data. + ErrMissingPrivateKey = errors.New("adal: private key missing") ) // LoadToken restores a Token object from a file located at 'path'. @@ -71,3 +84,52 @@ func SaveToken(path string, mode os.FileMode, token Token) error { } return nil } + +// DecodePfxCertificateData extracts the x509 certificate and RSA private key from the provided PFX data. +// The PFX data must contain a private key along with a certificate whose public key matches that of the +// private key or an error is returned. +// If the private key is not password protected pass the empty string for password. 
+func DecodePfxCertificateData(pfxData []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) { + blocks, err := pkcs12.ToPEM(pfxData, password) + if err != nil { + return nil, nil, err + } + // first extract the private key + var priv *rsa.PrivateKey + for _, block := range blocks { + if block.Type == "PRIVATE KEY" { + priv, err = x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, nil, err + } + break + } + } + if priv == nil { + return nil, nil, ErrMissingPrivateKey + } + // now find the certificate with the matching public key of our private key + var cert *x509.Certificate + for _, block := range blocks { + if block.Type == "CERTIFICATE" { + pcert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, nil, err + } + certKey, ok := pcert.PublicKey.(*rsa.PublicKey) + if !ok { + // keep looking + continue + } + if priv.E == certKey.E && priv.N.Cmp(certKey.N) == 0 { + // found a match + cert = pcert + break + } + } + } + if cert == nil { + return nil, nil, ErrMissingCertificate + } + return cert, priv, nil +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/sender.go b/vendor/github.com/Azure/go-autorest/autorest/adal/sender.go index 0e5ad14d39..1826a68dc8 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/adal/sender.go +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/sender.go @@ -15,7 +15,12 @@ package adal // limitations under the License. import ( + "crypto/tls" "net/http" + "net/http/cookiejar" + "sync" + + "github.com/Azure/go-autorest/tracing" ) const ( @@ -23,6 +28,10 @@ const ( mimeTypeFormPost = "application/x-www-form-urlencoded" ) +// DO NOT ACCESS THIS DIRECTLY. go through sender() +var defaultSender Sender +var defaultSenderInit = &sync.Once{} + // Sender is the interface that wraps the Do method to send HTTP requests. // // The standard http.Client conforms to this interface. @@ -38,14 +47,14 @@ func (sf SenderFunc) Do(r *http.Request) (*http.Response, error) { return sf(r) } -// SendDecorator takes and possibily decorates, by wrapping, a Sender. Decorators may affect the +// SendDecorator takes and possibly decorates, by wrapping, a Sender. Decorators may affect the // http.Request and pass it along or, first, pass the http.Request along then react to the // http.Response result. type SendDecorator func(Sender) Sender // CreateSender creates, decorates, and returns, as a Sender, the default http.Client. func CreateSender(decorators ...SendDecorator) Sender { - return DecorateSender(&http.Client{}, decorators...) + return DecorateSender(sender(), decorators...) } // DecorateSender accepts a Sender and a, possibly empty, set of SendDecorators, which is applies to @@ -58,3 +67,30 @@ func DecorateSender(s Sender, decorators ...SendDecorator) Sender { } return s } + +func sender() Sender { + // note that we can't init defaultSender in init() since it will + // execute before calling code has had a chance to enable tracing + defaultSenderInit.Do(func() { + // Use behaviour compatible with DefaultTransport, but require TLS minimum version. 
+ defaultTransport := http.DefaultTransport.(*http.Transport) + transport := &http.Transport{ + Proxy: defaultTransport.Proxy, + DialContext: defaultTransport.DialContext, + MaxIdleConns: defaultTransport.MaxIdleConns, + IdleConnTimeout: defaultTransport.IdleConnTimeout, + TLSHandshakeTimeout: defaultTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + } + var roundTripper http.RoundTripper = transport + if tracing.IsEnabled() { + roundTripper = tracing.NewTransport(transport) + } + j, _ := cookiejar.New(nil) + defaultSender = &http.Client{Jar: j, Transport: roundTripper} + }) + return defaultSender +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/token.go b/vendor/github.com/Azure/go-autorest/autorest/adal/token.go index 67c5a0b0ba..1d0241f8e4 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/adal/token.go +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/token.go @@ -22,18 +22,22 @@ import ( "crypto/x509" "encoding/base64" "encoding/json" + "errors" "fmt" + "io" "io/ioutil" - "net" + "math" "net/http" "net/url" + "os" "strconv" "strings" "sync" "time" "github.com/Azure/go-autorest/autorest/date" - "github.com/dgrijalva/jwt-go" + "github.com/Azure/go-autorest/logger" + "github.com/form3tech-oss/jwt-go" ) const ( @@ -59,6 +63,30 @@ const ( // msiEndpoint is the well known endpoint for getting MSI authentications tokens msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" + + // the API version to use for the MSI endpoint + msiAPIVersion = "2018-02-01" + + // the default number of attempts to refresh an MSI authentication token + defaultMaxMSIRefreshAttempts = 5 + + // asMSIEndpointEnv is the environment variable used to store the endpoint on App Service and Functions + msiEndpointEnv = "MSI_ENDPOINT" + + // asMSISecretEnv is the environment variable used to store the request secret on App Service and Functions + msiSecretEnv = "MSI_SECRET" + + // the API version to use for the legacy App Service MSI endpoint + appServiceAPIVersion2017 = "2017-09-01" + + // secret header used when authenticating against app service MSI endpoint + secretHeader = "Secret" + + // the format for expires_on in UTC with AM/PM + expiresOnDateFormatPM = "1/2/2006 15:04:05 PM +00:00" + + // the format for expires_on in UTC without AM/PM + expiresOnDateFormat = "1/2/2006 15:04:05 +00:00" ) // OAuthTokenProvider is an interface which should be implemented by an access token retriever @@ -66,6 +94,12 @@ type OAuthTokenProvider interface { OAuthToken() string } +// MultitenantOAuthTokenProvider provides tokens used for multi-tenant authorization. +type MultitenantOAuthTokenProvider interface { + PrimaryOAuthToken() string + AuxiliaryOAuthTokens() []string +} + // TokenRefreshError is an interface used by errors returned during token refresh. type TokenRefreshError interface { error @@ -90,19 +124,31 @@ type RefresherWithContext interface { // a successful token refresh type TokenRefreshCallback func(Token) error +// TokenRefresh is a type representing a custom callback to refresh a token +type TokenRefresh func(ctx context.Context, resource string) (*Token, error) + // Token encapsulates the access token used to authorize Azure requests. 
+// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response type Token struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` - ExpiresIn string `json:"expires_in"` - ExpiresOn string `json:"expires_on"` - NotBefore string `json:"not_before"` + ExpiresIn json.Number `json:"expires_in"` + ExpiresOn json.Number `json:"expires_on"` + NotBefore json.Number `json:"not_before"` Resource string `json:"resource"` Type string `json:"token_type"` } +func newToken() Token { + return Token{ + ExpiresIn: "0", + ExpiresOn: "0", + NotBefore: "0", + } +} + // IsZero returns true if the token object is zero-initialized. func (t Token) IsZero() bool { return t == Token{} @@ -110,12 +156,12 @@ func (t Token) IsZero() bool { // Expires returns the time.Time when the Token expires. func (t Token) Expires() time.Time { - s, err := strconv.Atoi(t.ExpiresOn) + s, err := t.ExpiresOn.Float64() if err != nil { s = -3600 } - expiration := date.NewUnixTimeFromSeconds(float64(s)) + expiration := date.NewUnixTimeFromSeconds(s) return time.Time(expiration).UTC() } @@ -136,6 +182,12 @@ func (t *Token) OAuthToken() string { return t.AccessToken } +// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form +// that is submitted when acquiring an oAuth token. +type ServicePrincipalSecret interface { + SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error +} + // ServicePrincipalNoSecret represents a secret type that contains no secret // meaning it is not valid for fetching a fresh token. This is used by Manual type ServicePrincipalNoSecret struct { @@ -147,15 +199,19 @@ func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePr return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token") } -// ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form -// that is submitted when acquiring an oAuth token. -type ServicePrincipalSecret interface { - SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error +// MarshalJSON implements the json.Marshaler interface. +func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) { + type tokenType struct { + Type string `json:"type"` + } + return json.Marshal(tokenType{ + Type: "ServicePrincipalNoSecret", + }) } // ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization. type ServicePrincipalTokenSecret struct { - ClientSecret string + ClientSecret string `json:"value"` } // SetAuthenticationValues is a method of the interface ServicePrincipalSecret. @@ -165,49 +221,24 @@ func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *Ser return nil } +// MarshalJSON implements the json.Marshaler interface. +func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) { + type tokenType struct { + Type string `json:"type"` + Value string `json:"value"` + } + return json.Marshal(tokenType{ + Type: "ServicePrincipalTokenSecret", + Value: tokenSecret.ClientSecret, + }) +} + // ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs. 
type ServicePrincipalCertificateSecret struct { Certificate *x509.Certificate PrivateKey *rsa.PrivateKey } -// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension. -type ServicePrincipalMSISecret struct { -} - -// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth. -type ServicePrincipalUsernamePasswordSecret struct { - Username string - Password string -} - -// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth. -type ServicePrincipalAuthorizationCodeSecret struct { - ClientSecret string - AuthorizationCode string - RedirectURI string -} - -// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. -func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { - v.Set("code", secret.AuthorizationCode) - v.Set("client_secret", secret.ClientSecret) - v.Set("redirect_uri", secret.RedirectURI) - return nil -} - -// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. -func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { - v.Set("username", secret.Username) - v.Set("password", secret.Password) - return nil -} - -// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. -func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { - return nil -} - // SignJwt returns the JWT signed with the certificate's private key. func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) { hasher := sha1.New() @@ -227,13 +258,15 @@ func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalTo token := jwt.New(jwt.SigningMethodRS256) token.Header["x5t"] = thumbprint + x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)} + token.Header["x5c"] = x5c token.Claims = jwt.MapClaims{ - "aud": spt.oauthConfig.TokenEndpoint.String(), - "iss": spt.clientID, - "sub": spt.clientID, + "aud": spt.inner.OauthConfig.TokenEndpoint.String(), + "iss": spt.inner.ClientID, + "sub": spt.inner.ClientID, "jti": base64.URLEncoding.EncodeToString(jti), "nbf": time.Now().Unix(), - "exp": time.Now().Add(time.Hour * 24).Unix(), + "exp": time.Now().Add(24 * time.Hour).Unix(), } signedString, err := token.SignedString(secret.PrivateKey) @@ -253,19 +286,165 @@ func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *Se return nil } +// MarshalJSON implements the json.Marshaler interface. +func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) { + return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported") +} + +// ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension. +type ServicePrincipalMSISecret struct { + msiType msiType + clientResourceID string +} + +// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. +func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { + return nil +} + +// MarshalJSON implements the json.Marshaler interface. 
+func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) { + return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported") +} + +// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth. +type ServicePrincipalUsernamePasswordSecret struct { + Username string `json:"username"` + Password string `json:"password"` +} + +// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. +func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { + v.Set("username", secret.Username) + v.Set("password", secret.Password) + return nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) { + type tokenType struct { + Type string `json:"type"` + Username string `json:"username"` + Password string `json:"password"` + } + return json.Marshal(tokenType{ + Type: "ServicePrincipalUsernamePasswordSecret", + Username: secret.Username, + Password: secret.Password, + }) +} + +// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth. +type ServicePrincipalAuthorizationCodeSecret struct { + ClientSecret string `json:"value"` + AuthorizationCode string `json:"authCode"` + RedirectURI string `json:"redirect"` +} + +// SetAuthenticationValues is a method of the interface ServicePrincipalSecret. +func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { + v.Set("code", secret.AuthorizationCode) + v.Set("client_secret", secret.ClientSecret) + v.Set("redirect_uri", secret.RedirectURI) + return nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) { + type tokenType struct { + Type string `json:"type"` + Value string `json:"value"` + AuthCode string `json:"authCode"` + Redirect string `json:"redirect"` + } + return json.Marshal(tokenType{ + Type: "ServicePrincipalAuthorizationCodeSecret", + Value: secret.ClientSecret, + AuthCode: secret.AuthorizationCode, + Redirect: secret.RedirectURI, + }) +} + // ServicePrincipalToken encapsulates a Token created for a Service Principal. type ServicePrincipalToken struct { - token Token - secret ServicePrincipalSecret - oauthConfig OAuthConfig - clientID string - resource string - autoRefresh bool - refreshLock *sync.RWMutex - refreshWithin time.Duration - sender Sender + inner servicePrincipalToken + refreshLock *sync.RWMutex + sender Sender + customRefreshFunc TokenRefresh + refreshCallbacks []TokenRefreshCallback + // MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token. + // Settings this to a value less than 1 will use the default value. + MaxMSIRefreshAttempts int +} + +// MarshalTokenJSON returns the marshalled inner token. +func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) { + return json.Marshal(spt.inner.Token) +} + +// SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks. +func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) { + spt.refreshCallbacks = callbacks +} + +// SetCustomRefreshFunc sets a custom refresh function used to refresh the token. 
+func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) { + spt.customRefreshFunc = customRefreshFunc +} + +// MarshalJSON implements the json.Marshaler interface. +func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) { + return json.Marshal(spt.inner) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error { + // need to determine the token type + raw := map[string]interface{}{} + err := json.Unmarshal(data, &raw) + if err != nil { + return err + } + secret := raw["secret"].(map[string]interface{}) + switch secret["type"] { + case "ServicePrincipalNoSecret": + spt.inner.Secret = &ServicePrincipalNoSecret{} + case "ServicePrincipalTokenSecret": + spt.inner.Secret = &ServicePrincipalTokenSecret{} + case "ServicePrincipalCertificateSecret": + return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported") + case "ServicePrincipalMSISecret": + return errors.New("unmarshalling ServicePrincipalMSISecret is not supported") + case "ServicePrincipalUsernamePasswordSecret": + spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{} + case "ServicePrincipalAuthorizationCodeSecret": + spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{} + default: + return fmt.Errorf("unrecognized token type '%s'", secret["type"]) + } + err = json.Unmarshal(data, &spt.inner) + if err != nil { + return err + } + // Don't override the refreshLock or the sender if those have been already set. + if spt.refreshLock == nil { + spt.refreshLock = &sync.RWMutex{} + } + if spt.sender == nil { + spt.sender = sender() + } + return nil +} - refreshCallbacks []TokenRefreshCallback +// internal type used for marshalling/unmarshalling +type servicePrincipalToken struct { + Token Token `json:"token"` + Secret ServicePrincipalSecret `json:"secret"` + OauthConfig OAuthConfig `json:"oauth"` + ClientID string `json:"clientID"` + Resource string `json:"resource"` + AutoRefresh bool `json:"autoRefresh"` + RefreshWithin time.Duration `json:"refreshWithin"` } func validateOAuthConfig(oac OAuthConfig) error { @@ -290,14 +469,17 @@ func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, reso return nil, fmt.Errorf("parameter 'secret' cannot be nil") } spt := &ServicePrincipalToken{ - oauthConfig: oauthConfig, - secret: secret, - clientID: id, - resource: resource, - autoRefresh: true, + inner: servicePrincipalToken{ + Token: newToken(), + OauthConfig: oauthConfig, + Secret: secret, + ClientID: id, + Resource: resource, + AutoRefresh: true, + RefreshWithin: defaultRefresh, + }, refreshLock: &sync.RWMutex{}, - refreshWithin: defaultRefresh, - sender: &http.Client{}, + sender: sender(), refreshCallbacks: callbacks, } return spt, nil @@ -327,7 +509,39 @@ func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID s return nil, err } - spt.token = token + spt.inner.Token = token + + return spt, nil +} + +// NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret +func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { + if err := validateOAuthConfig(oauthConfig); err != nil { + return nil, err + } + if err := validateStringParam(clientID, "clientID"); err != nil { + return nil, err + } + if err := validateStringParam(resource, 
"resource"); err != nil { + return nil, err + } + if secret == nil { + return nil, fmt.Errorf("parameter 'secret' cannot be nil") + } + if token.IsZero() { + return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized") + } + spt, err := NewServicePrincipalTokenWithSecret( + oauthConfig, + clientID, + resource, + secret, + callbacks...) + if err != nil { + return nil, err + } + + spt.inner.Token = token return spt, nil } @@ -451,64 +665,193 @@ func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clie ) } +type msiType int + +const ( + msiTypeUnavailable msiType = iota + msiTypeAppServiceV20170901 + msiTypeCloudShell + msiTypeIMDS +) + +func (m msiType) String() string { + switch m { + case msiTypeUnavailable: + return "unavailable" + case msiTypeAppServiceV20170901: + return "AppServiceV20170901" + case msiTypeCloudShell: + return "CloudShell" + case msiTypeIMDS: + return "IMDS" + default: + return fmt.Sprintf("unhandled MSI type %d", m) + } +} + +// returns the MSI type and endpoint, or an error +func getMSIType() (msiType, string, error) { + if endpointEnvVar := os.Getenv(msiEndpointEnv); endpointEnvVar != "" { + // if the env var MSI_ENDPOINT is set + if secretEnvVar := os.Getenv(msiSecretEnv); secretEnvVar != "" { + // if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the msiType is AppService + return msiTypeAppServiceV20170901, endpointEnvVar, nil + } + // if ONLY the env var MSI_ENDPOINT is set the msiType is CloudShell + return msiTypeCloudShell, endpointEnvVar, nil + } else if msiAvailableHook(context.Background(), sender()) { + // if MSI_ENDPOINT is NOT set AND the IMDS endpoint is available the msiType is IMDS. This will timeout after 500 milliseconds + return msiTypeIMDS, msiEndpoint, nil + } else { + // if MSI_ENDPOINT is NOT set and IMDS endpoint is not available Managed Identity is not available + return msiTypeUnavailable, "", errors.New("MSI not available") + } +} + // GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines. +// NOTE: this always returns the IMDS endpoint, it does not work for app services or cloud shell. +// Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint. func GetMSIVMEndpoint() (string, error) { return msiEndpoint, nil } +// GetMSIAppServiceEndpoint get the MSI endpoint for App Service and Functions. +// It will return an error when not running in an app service/functions environment. +// Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint. +func GetMSIAppServiceEndpoint() (string, error) { + msiType, endpoint, err := getMSIType() + if err != nil { + return "", err + } + switch msiType { + case msiTypeAppServiceV20170901: + return endpoint, nil + default: + return "", fmt.Errorf("%s is not app service environment", msiType) + } +} + +// GetMSIEndpoint get the appropriate MSI endpoint depending on the runtime environment +// Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint. +func GetMSIEndpoint() (string, error) { + _, endpoint, err := getMSIType() + return endpoint, err +} + // NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension. // It will use the system assigned identity when creating the token. +// msiEndpoint - empty string, or pass a non-empty string to override the default value. +// Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead. 
func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { - return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...) + return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", "", callbacks...) } // NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension. -// It will use the specified user assigned identity when creating the token. +// It will use the clientID of specified user assigned identity when creating the token. +// msiEndpoint - empty string, or pass a non-empty string to override the default value. +// Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead. func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { - return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...) + if err := validateStringParam(userAssignedID, "userAssignedID"); err != nil { + return nil, err + } + return newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, "", callbacks...) } -func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { - if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil { +// NewServicePrincipalTokenFromMSIWithIdentityResourceID creates a ServicePrincipalToken via the MSI VM Extension. +// It will use the azure resource id of user assigned identity when creating the token. +// msiEndpoint - empty string, or pass a non-empty string to override the default value. +// Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead. +func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { + if err := validateStringParam(identityResourceID, "identityResourceID"); err != nil { return nil, err } + return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", identityResourceID, callbacks...) +} + +// ManagedIdentityOptions contains optional values for configuring managed identity authentication. +type ManagedIdentityOptions struct { + // ClientID is the user-assigned identity to use during authentication. + // It is mutually exclusive with IdentityResourceID. + ClientID string + + // IdentityResourceID is the resource ID of the user-assigned identity to use during authentication. + // It is mutually exclusive with ClientID. + IdentityResourceID string +} + +// NewServicePrincipalTokenFromManagedIdentity creates a ServicePrincipalToken using a managed identity. +// It supports the following managed identity environments. +// - App Service Environment (API version 2017-09-01 only) +// - Cloud shell +// - IMDS with a system or user assigned identity +func NewServicePrincipalTokenFromManagedIdentity(resource string, options *ManagedIdentityOptions, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { + if options == nil { + options = &ManagedIdentityOptions{} + } + return newServicePrincipalTokenFromMSI("", resource, options.ClientID, options.IdentityResourceID, callbacks...) 
+} + +func newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { if err := validateStringParam(resource, "resource"); err != nil { return nil, err } - if userAssignedID != nil { - if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil { - return nil, err - } + if userAssignedID != "" && identityResourceID != "" { + return nil, errors.New("cannot specify userAssignedID and identityResourceID") } - // We set the oauth config token endpoint to be MSI's endpoint - msiEndpointURL, err := url.Parse(msiEndpoint) + msiType, endpoint, err := getMSIType() if err != nil { + logger.Instance.Writef(logger.LogError, "Error determining managed identity environment: %v\n", err) return nil, err } - - v := url.Values{} - v.Set("resource", resource) - v.Set("api-version", "2018-02-01") - if userAssignedID != nil { - v.Set("client_id", *userAssignedID) + logger.Instance.Writef(logger.LogInfo, "Managed identity environment is %s, endpoint is %s\n", msiType, endpoint) + if msiEndpoint != "" { + endpoint = msiEndpoint + logger.Instance.Writef(logger.LogInfo, "Managed identity custom endpoint is %s\n", endpoint) + } + msiEndpointURL, err := url.Parse(endpoint) + if err != nil { + return nil, err + } + // cloud shell sends its data in the request body + if msiType != msiTypeCloudShell { + v := url.Values{} + v.Set("resource", resource) + clientIDParam := "client_id" + switch msiType { + case msiTypeAppServiceV20170901: + clientIDParam = "clientid" + v.Set("api-version", appServiceAPIVersion2017) + break + case msiTypeIMDS: + v.Set("api-version", msiAPIVersion) + } + if userAssignedID != "" { + v.Set(clientIDParam, userAssignedID) + } else if identityResourceID != "" { + v.Set("mi_res_id", identityResourceID) + } + msiEndpointURL.RawQuery = v.Encode() } - msiEndpointURL.RawQuery = v.Encode() spt := &ServicePrincipalToken{ - oauthConfig: OAuthConfig{ - TokenEndpoint: *msiEndpointURL, + inner: servicePrincipalToken{ + Token: newToken(), + OauthConfig: OAuthConfig{ + TokenEndpoint: *msiEndpointURL, + }, + Secret: &ServicePrincipalMSISecret{ + msiType: msiType, + clientResourceID: identityResourceID, + }, + Resource: resource, + AutoRefresh: true, + RefreshWithin: defaultRefresh, + ClientID: userAssignedID, }, - secret: &ServicePrincipalMSISecret{}, - resource: resource, - autoRefresh: true, - refreshLock: &sync.RWMutex{}, - refreshWithin: defaultRefresh, - sender: &http.Client{}, - refreshCallbacks: callbacks, - } - - if userAssignedID != nil { - spt.clientID = *userAssignedID + refreshLock: &sync.RWMutex{}, + sender: sender(), + refreshCallbacks: callbacks, + MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts, } return spt, nil @@ -543,12 +886,13 @@ func (spt *ServicePrincipalToken) EnsureFresh() error { // EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by // RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. 
func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error { - if spt.autoRefresh && spt.token.WillExpireIn(spt.refreshWithin) { - // take the write lock then check to see if the token was already refreshed + // must take the read lock when initially checking the token's expiration + if spt.inner.AutoRefresh && spt.Token().WillExpireIn(spt.inner.RefreshWithin) { + // take the write lock then check again to see if the token was already refreshed spt.refreshLock.Lock() defer spt.refreshLock.Unlock() - if spt.token.WillExpireIn(spt.refreshWithin) { - return spt.refreshInternal(ctx, spt.resource) + if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) { + return spt.refreshInternal(ctx, spt.inner.Resource) } } return nil @@ -558,7 +902,7 @@ func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) er func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { if spt.refreshCallbacks != nil { for _, callback := range spt.refreshCallbacks { - err := callback(spt.token) + err := callback(spt.inner.Token) if err != nil { return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err) } @@ -568,27 +912,27 @@ func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { } // Refresh obtains a fresh token for the Service Principal. -// This method is not safe for concurrent use and should be syncrhonized. +// This method is safe for concurrent use. func (spt *ServicePrincipalToken) Refresh() error { return spt.RefreshWithContext(context.Background()) } // RefreshWithContext obtains a fresh token for the Service Principal. -// This method is not safe for concurrent use and should be syncrhonized. +// This method is safe for concurrent use. func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error { spt.refreshLock.Lock() defer spt.refreshLock.Unlock() - return spt.refreshInternal(ctx, spt.resource) + return spt.refreshInternal(ctx, spt.inner.Resource) } // RefreshExchange refreshes the token, but for a different resource. -// This method is not safe for concurrent use and should be syncrhonized. +// This method is safe for concurrent use. func (spt *ServicePrincipalToken) RefreshExchange(resource string) error { return spt.RefreshExchangeWithContext(context.Background(), resource) } // RefreshExchangeWithContext refreshes the token, but for a different resource. -// This method is not safe for concurrent use and should be syncrhonized. +// This method is safe for concurrent use. 
func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error { spt.refreshLock.Lock() defer spt.refreshLock.Unlock() @@ -596,7 +940,7 @@ func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context } func (spt *ServicePrincipalToken) getGrantType() string { - switch spt.secret.(type) { + switch spt.inner.Secret.(type) { case *ServicePrincipalUsernamePasswordSecret: return OAuthGrantTypeUserPass case *ServicePrincipalAuthorizationCodeSecret: @@ -606,31 +950,72 @@ func (spt *ServicePrincipalToken) getGrantType() string { } } -func isIMDS(u url.URL) bool { - imds, err := url.Parse(msiEndpoint) - if err != nil { - return false - } - return u.Host == imds.Host && u.Path == imds.Path -} - func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error { - req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), nil) + if spt.customRefreshFunc != nil { + token, err := spt.customRefreshFunc(ctx, resource) + if err != nil { + return err + } + spt.inner.Token = *token + return spt.InvokeRefreshCallbacks(spt.inner.Token) + } + req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil) if err != nil { return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err) } + req.Header.Add("User-Agent", UserAgent()) req = req.WithContext(ctx) - if !isIMDS(spt.oauthConfig.TokenEndpoint) { + var resp *http.Response + authBodyFilter := func(b []byte) []byte { + if logger.Level() != logger.LogAuth { + return []byte("**REDACTED** authentication body") + } + return b + } + if msiSecret, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok { + switch msiSecret.msiType { + case msiTypeAppServiceV20170901: + req.Method = http.MethodGet + req.Header.Set("secret", os.Getenv(msiSecretEnv)) + break + case msiTypeCloudShell: + req.Header.Set("Metadata", "true") + data := url.Values{} + data.Set("resource", spt.inner.Resource) + if spt.inner.ClientID != "" { + data.Set("client_id", spt.inner.ClientID) + } else if msiSecret.clientResourceID != "" { + data.Set("msi_res_id", msiSecret.clientResourceID) + } + req.Body = ioutil.NopCloser(strings.NewReader(data.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + break + case msiTypeIMDS: + req.Method = http.MethodGet + req.Header.Set("Metadata", "true") + break + } + logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter}) + resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts) + } else { v := url.Values{} - v.Set("client_id", spt.clientID) + v.Set("client_id", spt.inner.ClientID) v.Set("resource", resource) - if spt.token.RefreshToken != "" { + if spt.inner.Token.RefreshToken != "" { v.Set("grant_type", OAuthGrantTypeRefreshToken) - v.Set("refresh_token", spt.token.RefreshToken) + v.Set("refresh_token", spt.inner.Token.RefreshToken) + // web apps must specify client_secret when refreshing tokens + // see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens + if spt.getGrantType() == OAuthGrantTypeAuthorizationCode { + err := spt.inner.Secret.SetAuthenticationValues(spt, &v) + if err != nil { + return err + } + } } else { v.Set("grant_type", spt.getGrantType()) - err := spt.secret.SetAuthenticationValues(spt, &v) + err := spt.inner.Secret.SetAuthenticationValues(spt, &v) if err != nil { return err } @@ -641,31 +1026,26 @@ func (spt *ServicePrincipalToken) 
refreshInternal(ctx context.Context, resource req.ContentLength = int64(len(s)) req.Header.Set(contentType, mimeTypeFormPost) req.Body = body - } - - if _, ok := spt.secret.(*ServicePrincipalMSISecret); ok { - req.Method = http.MethodGet - req.Header.Set(metadataHeader, "true") - } - - var resp *http.Response - if isIMDS(spt.oauthConfig.TokenEndpoint) { - resp, err = retry(spt.sender, req) - } else { + logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter}) resp, err = spt.sender.Do(req) } + + // don't return a TokenRefreshError here; this will allow retry logic to apply if err != nil { - return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil) + return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err) + } else if resp == nil { + return fmt.Errorf("adal: received nil response and error") } + logger.Instance.WriteResponse(resp, logger.Filter{Body: authBodyFilter}) defer resp.Body.Close() rb, err := ioutil.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { if err != nil { - return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp) + return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v Endpoint %s", resp.StatusCode, err, req.URL.String()), resp) } - return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp) + return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s Endpoint %s", resp.StatusCode, string(rb), req.URL.String()), resp) } // for the following error cases don't return a TokenRefreshError. the operation succeeded @@ -678,18 +1058,65 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource if len(strings.Trim(string(rb), " ")) == 0 { return fmt.Errorf("adal: Empty service principal token received during refresh") } - var token Token + token := struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + + // AAD returns expires_in as a string, ADFS returns it as an int + ExpiresIn json.Number `json:"expires_in"` + // expires_on can be in two formats, a UTC time stamp or the number of seconds. + ExpiresOn string `json:"expires_on"` + NotBefore json.Number `json:"not_before"` + + Resource string `json:"resource"` + Type string `json:"token_type"` + }{} + // return a TokenRefreshError in the follow error cases as the token is in an unexpected format err = json.Unmarshal(rb, &token) if err != nil { - return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)) + return newTokenRefreshError(fmt.Sprintf("adal: Failed to unmarshal the service principal token during refresh. 
Error = '%v' JSON = '%s'", err, string(rb)), resp) } + expiresOn := json.Number("") + // ADFS doesn't include the expires_on field + if token.ExpiresOn != "" { + if expiresOn, err = parseExpiresOn(token.ExpiresOn); err != nil { + return newTokenRefreshError(fmt.Sprintf("adal: failed to parse expires_on: %v value '%s'", err, token.ExpiresOn), resp) + } + } + spt.inner.Token.AccessToken = token.AccessToken + spt.inner.Token.RefreshToken = token.RefreshToken + spt.inner.Token.ExpiresIn = token.ExpiresIn + spt.inner.Token.ExpiresOn = expiresOn + spt.inner.Token.NotBefore = token.NotBefore + spt.inner.Token.Resource = token.Resource + spt.inner.Token.Type = token.Type + + return spt.InvokeRefreshCallbacks(spt.inner.Token) +} - spt.token = token - - return spt.InvokeRefreshCallbacks(token) +// converts expires_on to the number of seconds +func parseExpiresOn(s string) (json.Number, error) { + // convert the expiration date to the number of seconds from now + timeToDuration := func(t time.Time) json.Number { + dur := t.Sub(time.Now().UTC()) + return json.Number(strconv.FormatInt(int64(dur.Round(time.Second).Seconds()), 10)) + } + if _, err := strconv.ParseInt(s, 10, 64); err == nil { + // this is the number of seconds case, no conversion required + return json.Number(s), nil + } else if eo, err := time.Parse(expiresOnDateFormatPM, s); err == nil { + return timeToDuration(eo), nil + } else if eo, err := time.Parse(expiresOnDateFormat, s); err == nil { + return timeToDuration(eo), nil + } else { + // unknown format + return json.Number(""), err + } } -func retry(sender Sender, req *http.Request) (resp *http.Response, err error) { +// retry logic specific to retrieving a token from the IMDS endpoint +func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) { + // copied from client.go due to circular dependency retries := []int{ http.StatusRequestTimeout, // 408 http.StatusTooManyRequests, // 429 @@ -698,8 +1125,10 @@ func retry(sender Sender, req *http.Request) (resp *http.Response, err error) { http.StatusServiceUnavailable, // 503 http.StatusGatewayTimeout, // 504 } - // Extra retry status codes requered - retries = append(retries, http.StatusNotFound, + // extra retry status codes specific to IMDS + retries = append(retries, + http.StatusNotFound, + http.StatusGone, // all remaining 5xx http.StatusNotImplemented, http.StatusHTTPVersionNotSupported, @@ -709,56 +1138,55 @@ func retry(sender Sender, req *http.Request) (resp *http.Response, err error) { http.StatusNotExtended, http.StatusNetworkAuthenticationRequired) + // see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance + + const maxDelay time.Duration = 60 * time.Second + attempt := 0 - maxAttempts := 5 + delay := time.Duration(0) + + // maxAttempts is user-specified, ensure that its value is greater than zero else no request will be made + if maxAttempts < 1 { + maxAttempts = defaultMaxMSIRefreshAttempts + } for attempt < maxAttempts { + if resp != nil && resp.Body != nil { + io.Copy(ioutil.Discard, resp.Body) + resp.Body.Close() + } resp, err = sender.Do(req) - // retry on temporary network errors, e.g. transient network failures. - if (err != nil && !isTemporaryNetworkError(err)) || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) { + // we want to retry if err is not nil or the status code is in the list of retry codes + if err == nil && !responseHasStatusCode(resp, retries...) 
{ return } - if !delay(resp, req.Context().Done()) { - select { - case <-time.After(time.Second): - attempt++ - case <-req.Context().Done(): - err = req.Context().Err() - return - } + // perform exponential backoff with a cap. + // must increment attempt before calculating delay. + attempt++ + // the base value of 2 is the "delta backoff" as specified in the guidance doc + delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second) + if delay > maxDelay { + delay = maxDelay } - } - return -} -func isTemporaryNetworkError(err error) bool { - if netErr, ok := err.(net.Error); ok && netErr.Temporary() { - return true - } - return false -} - -func containsInt(ints []int, n int) bool { - for _, i := range ints { - if i == n { - return true + select { + case <-time.After(delay): + // intentionally left blank + case <-req.Context().Done(): + err = req.Context().Err() + return } } - return false + return } -func delay(resp *http.Response, cancel <-chan struct{}) bool { - if resp == nil { - return false - } - retryAfter, _ := strconv.Atoi(resp.Header.Get("Retry-After")) - if resp.StatusCode == http.StatusTooManyRequests && retryAfter > 0 { - select { - case <-time.After(time.Duration(retryAfter) * time.Second): - return true - case <-cancel: - return false +func responseHasStatusCode(resp *http.Response, codes ...int) bool { + if resp != nil { + for _, i := range codes { + if i == resp.StatusCode { + return true + } } } return false @@ -766,13 +1194,13 @@ func delay(resp *http.Response, cancel <-chan struct{}) bool { // SetAutoRefresh enables or disables automatic refreshing of stale tokens. func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) { - spt.autoRefresh = autoRefresh + spt.inner.AutoRefresh = autoRefresh } // SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will // refresh the token. func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) { - spt.refreshWithin = d + spt.inner.RefreshWithin = d return } @@ -784,12 +1212,125 @@ func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s } func (spt *ServicePrincipalToken) OAuthToken() string { spt.refreshLock.RLock() defer spt.refreshLock.RUnlock() - return spt.token.OAuthToken() + return spt.inner.Token.OAuthToken() } // Token returns a copy of the current token. func (spt *ServicePrincipalToken) Token() Token { spt.refreshLock.RLock() defer spt.refreshLock.RUnlock() - return spt.token + return spt.inner.Token +} + +// MultiTenantServicePrincipalToken contains tokens for multi-tenant authorization. +type MultiTenantServicePrincipalToken struct { + PrimaryToken *ServicePrincipalToken + AuxiliaryTokens []*ServicePrincipalToken +} + +// PrimaryOAuthToken returns the primary authorization token. +func (mt *MultiTenantServicePrincipalToken) PrimaryOAuthToken() string { + return mt.PrimaryToken.OAuthToken() +} + +// AuxiliaryOAuthTokens returns one to three auxiliary authorization tokens. +func (mt *MultiTenantServicePrincipalToken) AuxiliaryOAuthTokens() []string { + tokens := make([]string, len(mt.AuxiliaryTokens)) + for i := range mt.AuxiliaryTokens { + tokens[i] = mt.AuxiliaryTokens[i].OAuthToken() + } + return tokens +} + +// NewMultiTenantServicePrincipalToken creates a new MultiTenantServicePrincipalToken with the specified credentials and resource. 
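To make the retry arithmetic in retryForIMDS above concrete: with the default five attempts and a delta backoff base of 2, the cumulative waits are 2s, 6s, 14s, 30s and 60s, the last capped by maxDelay. A standalone sketch of the same calculation:

    package main

    import (
        "fmt"
        "math"
        "time"
    )

    func main() {
        const maxDelay = 60 * time.Second
        delay := time.Duration(0)
        for attempt := 1; attempt <= 5; attempt++ {
            // same "delta backoff" of 2 used by retryForIMDS
            delay += time.Duration(math.Pow(2, float64(attempt))) * time.Second
            if delay > maxDelay {
                delay = maxDelay
            }
            fmt.Printf("attempt %d: wait %v before retrying\n", attempt, delay)
        }
    }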
+func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig, clientID string, secret string, resource string) (*MultiTenantServicePrincipalToken, error) { + if err := validateStringParam(clientID, "clientID"); err != nil { + return nil, err + } + if err := validateStringParam(secret, "secret"); err != nil { + return nil, err + } + if err := validateStringParam(resource, "resource"); err != nil { + return nil, err + } + auxTenants := multiTenantCfg.AuxiliaryTenants() + m := MultiTenantServicePrincipalToken{ + AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)), + } + primary, err := NewServicePrincipalToken(*multiTenantCfg.PrimaryTenant(), clientID, secret, resource) + if err != nil { + return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err) + } + m.PrimaryToken = primary + for i := range auxTenants { + aux, err := NewServicePrincipalToken(*auxTenants[i], clientID, secret, resource) + if err != nil { + return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err) + } + m.AuxiliaryTokens[i] = aux + } + return &m, nil +} + +// NewMultiTenantServicePrincipalTokenFromCertificate creates a new MultiTenantServicePrincipalToken with the specified certificate credentials and resource. +func NewMultiTenantServicePrincipalTokenFromCertificate(multiTenantCfg MultiTenantOAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string) (*MultiTenantServicePrincipalToken, error) { + if err := validateStringParam(clientID, "clientID"); err != nil { + return nil, err + } + if err := validateStringParam(resource, "resource"); err != nil { + return nil, err + } + if certificate == nil { + return nil, fmt.Errorf("parameter 'certificate' cannot be nil") + } + if privateKey == nil { + return nil, fmt.Errorf("parameter 'privateKey' cannot be nil") + } + auxTenants := multiTenantCfg.AuxiliaryTenants() + m := MultiTenantServicePrincipalToken{ + AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)), + } + primary, err := NewServicePrincipalTokenWithSecret( + *multiTenantCfg.PrimaryTenant(), + clientID, + resource, + &ServicePrincipalCertificateSecret{ + PrivateKey: privateKey, + Certificate: certificate, + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err) + } + m.PrimaryToken = primary + for i := range auxTenants { + aux, err := NewServicePrincipalTokenWithSecret( + *auxTenants[i], + clientID, + resource, + &ServicePrincipalCertificateSecret{ + PrivateKey: privateKey, + Certificate: certificate, + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err) + } + m.AuxiliaryTokens[i] = aux + } + return &m, nil +} + +// MSIAvailable returns true if the MSI endpoint is available for authentication. 
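A hedged usage sketch for the multi-tenant constructor above. It assumes adal.NewMultiTenantOAuthConfig and adal.OAuthOptions from the config.go changes in this same patch, and it uses azure.PublicCloud endpoints; tenant IDs, client ID and secret are placeholders.

    package example

    import (
        "github.com/Azure/go-autorest/autorest/adal"
        "github.com/Azure/go-autorest/autorest/azure"
    )

    func newMultiTenantToken(primaryTenant, auxTenant, clientID, secret string) (*adal.MultiTenantServicePrincipalToken, error) {
        cfg, err := adal.NewMultiTenantOAuthConfig(
            azure.PublicCloud.ActiveDirectoryEndpoint,
            primaryTenant,
            []string{auxTenant}, // up to three auxiliary tenants
            adal.OAuthOptions{},
        )
        if err != nil {
            return nil, err
        }
        return adal.NewMultiTenantServicePrincipalToken(cfg, clientID, secret, azure.PublicCloud.ResourceManagerEndpoint)
    }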
+func MSIAvailable(ctx context.Context, sender Sender) bool { + resp, err := getMSIEndpoint(ctx, sender) + if err == nil { + resp.Body.Close() + } + return err == nil +} + +// used for testing purposes +var msiAvailableHook = func(ctx context.Context, sender Sender) bool { + return MSIAvailable(ctx, sender) } diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/token_1.13.go b/vendor/github.com/Azure/go-autorest/autorest/adal/token_1.13.go new file mode 100644 index 0000000000..953f755028 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/token_1.13.go @@ -0,0 +1,75 @@ +// +build go1.13 + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adal + +import ( + "context" + "fmt" + "net/http" + "time" +) + +func getMSIEndpoint(ctx context.Context, sender Sender) (*http.Response, error) { + tempCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + // http.NewRequestWithContext() was added in Go 1.13 + req, _ := http.NewRequestWithContext(tempCtx, http.MethodGet, msiEndpoint, nil) + q := req.URL.Query() + q.Add("api-version", msiAPIVersion) + req.URL.RawQuery = q.Encode() + return sender.Do(req) +} + +// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by +// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. +func (mt *MultiTenantServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error { + if err := mt.PrimaryToken.EnsureFreshWithContext(ctx); err != nil { + return fmt.Errorf("failed to refresh primary token: %w", err) + } + for _, aux := range mt.AuxiliaryTokens { + if err := aux.EnsureFreshWithContext(ctx); err != nil { + return fmt.Errorf("failed to refresh auxiliary token: %w", err) + } + } + return nil +} + +// RefreshWithContext obtains a fresh token for the Service Principal. +func (mt *MultiTenantServicePrincipalToken) RefreshWithContext(ctx context.Context) error { + if err := mt.PrimaryToken.RefreshWithContext(ctx); err != nil { + return fmt.Errorf("failed to refresh primary token: %w", err) + } + for _, aux := range mt.AuxiliaryTokens { + if err := aux.RefreshWithContext(ctx); err != nil { + return fmt.Errorf("failed to refresh auxiliary token: %w", err) + } + } + return nil +} + +// RefreshExchangeWithContext refreshes the token, but for a different resource. 
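MSIAvailable above takes any adal.Sender; a plain *http.Client satisfies that interface, and the probe itself is bounded by the 500ms timeout in getMSIEndpoint. A small sketch of using it to pick an auth flow:

    package example

    import (
        "context"
        "net/http"

        "github.com/Azure/go-autorest/autorest/adal"
    )

    func haveManagedIdentity(ctx context.Context) bool {
        return adal.MSIAvailable(ctx, &http.Client{})
    }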
+func (mt *MultiTenantServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error { + if err := mt.PrimaryToken.RefreshExchangeWithContext(ctx, resource); err != nil { + return fmt.Errorf("failed to refresh primary token: %w", err) + } + for _, aux := range mt.AuxiliaryTokens { + if err := aux.RefreshExchangeWithContext(ctx, resource); err != nil { + return fmt.Errorf("failed to refresh auxiliary token: %w", err) + } + } + return nil +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/token_legacy.go b/vendor/github.com/Azure/go-autorest/autorest/adal/token_legacy.go new file mode 100644 index 0000000000..729bfbd0ab --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/token_legacy.go @@ -0,0 +1,74 @@ +// +build !go1.13 + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adal + +import ( + "context" + "net/http" + "time" +) + +func getMSIEndpoint(ctx context.Context, sender Sender) (*http.Response, error) { + tempCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + req, _ := http.NewRequest(http.MethodGet, msiEndpoint, nil) + req = req.WithContext(tempCtx) + q := req.URL.Query() + q.Add("api-version", msiAPIVersion) + req.URL.RawQuery = q.Encode() + return sender.Do(req) +} + +// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by +// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. +func (mt *MultiTenantServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error { + if err := mt.PrimaryToken.EnsureFreshWithContext(ctx); err != nil { + return err + } + for _, aux := range mt.AuxiliaryTokens { + if err := aux.EnsureFreshWithContext(ctx); err != nil { + return err + } + } + return nil +} + +// RefreshWithContext obtains a fresh token for the Service Principal. +func (mt *MultiTenantServicePrincipalToken) RefreshWithContext(ctx context.Context) error { + if err := mt.PrimaryToken.RefreshWithContext(ctx); err != nil { + return err + } + for _, aux := range mt.AuxiliaryTokens { + if err := aux.RefreshWithContext(ctx); err != nil { + return err + } + } + return nil +} + +// RefreshExchangeWithContext refreshes the token, but for a different resource. 
+func (mt *MultiTenantServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error { + if err := mt.PrimaryToken.RefreshExchangeWithContext(ctx, resource); err != nil { + return err + } + for _, aux := range mt.AuxiliaryTokens { + if err := aux.RefreshExchangeWithContext(ctx, resource); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/adal/version.go b/vendor/github.com/Azure/go-autorest/autorest/adal/version.go new file mode 100644 index 0000000000..c867b34843 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/adal/version.go @@ -0,0 +1,45 @@ +package adal + +import ( + "fmt" + "runtime" +) + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +const number = "v1.0.0" + +var ( + ua = fmt.Sprintf("Go/%s (%s-%s) go-autorest/adal/%s", + runtime.Version(), + runtime.GOARCH, + runtime.GOOS, + number, + ) +) + +// UserAgent returns a string containing the Go version, system architecture and OS, and the adal version. +func UserAgent() string { + return ua +} + +// AddToUserAgent adds an extension to the current user agent +func AddToUserAgent(extension string) error { + if extension != "" { + ua = fmt.Sprintf("%s %s", ua, extension) + return nil + } + return fmt.Errorf("Extension was empty, User Agent remained as '%s'", ua) +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/authorization.go b/vendor/github.com/Azure/go-autorest/autorest/authorization.go index 77eff45bdd..1226c41115 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/authorization.go +++ b/vendor/github.com/Azure/go-autorest/autorest/authorization.go @@ -15,6 +15,8 @@ package autorest // limitations under the License. import ( + "crypto/tls" + "encoding/base64" "fmt" "net/http" "net/url" @@ -30,6 +32,8 @@ const ( apiKeyAuthorizerHeader = "Ocp-Apim-Subscription-Key" bingAPISdkHeader = "X-BingApis-SDK-Client" golangBingAPISdkHeaderValue = "Go-SDK" + authorization = "Authorization" + basic = "Basic" ) // Authorizer is the interface that provides a PrepareDecorator used to supply request @@ -68,7 +72,7 @@ func NewAPIKeyAuthorizer(headers map[string]interface{}, queryParameters map[str return &APIKeyAuthorizer{headers: headers, queryParameters: queryParameters} } -// WithAuthorization returns a PrepareDecorator that adds an HTTP headers and Query Paramaters +// WithAuthorization returns a PrepareDecorator that adds an HTTP headers and Query Parameters. func (aka *APIKeyAuthorizer) WithAuthorization() PrepareDecorator { return func(p Preparer) Preparer { return DecoratePreparer(p, WithHeaders(aka.headers), WithQueryParameters(aka.queryParameters)) @@ -134,6 +138,11 @@ func (ba *BearerAuthorizer) WithAuthorization() PrepareDecorator { } } +// TokenProvider returns OAuthTokenProvider so that it can be used for authorization outside the REST. 
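The new version.go exposes the adal user-agent string so programs can tag their requests. An illustrative sketch, not part of the vendored diff; the extension string is a placeholder:

    package example

    import (
        "fmt"
        "log"

        "github.com/Azure/go-autorest/autorest/adal"
    )

    func tagUserAgent() {
        if err := adal.AddToUserAgent("registry/2.7"); err != nil {
            log.Println(err)
        }
        // prints something like: Go/go1.16 (amd64-linux) go-autorest/adal/v1.0.0 registry/2.7
        fmt.Println(adal.UserAgent())
    }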
+func (ba *BearerAuthorizer) TokenProvider() adal.OAuthTokenProvider { + return ba.tokenProvider +} + // BearerAuthorizerCallbackFunc is the authentication callback signature. type BearerAuthorizerCallbackFunc func(tenantID, resource string) (*BearerAuthorizer, error) @@ -145,11 +154,11 @@ type BearerAuthorizerCallback struct { // NewBearerAuthorizerCallback creates a bearer authorization callback. The callback // is invoked when the HTTP request is submitted. -func NewBearerAuthorizerCallback(sender Sender, callback BearerAuthorizerCallbackFunc) *BearerAuthorizerCallback { - if sender == nil { - sender = &http.Client{} +func NewBearerAuthorizerCallback(s Sender, callback BearerAuthorizerCallbackFunc) *BearerAuthorizerCallback { + if s == nil { + s = sender(tls.RenegotiateNever) } - return &BearerAuthorizerCallback{sender: sender, callback: callback} + return &BearerAuthorizerCallback{sender: s, callback: callback} } // WithAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose value @@ -167,20 +176,21 @@ func (bacb *BearerAuthorizerCallback) WithAuthorization() PrepareDecorator { removeRequestBody(&rCopy) resp, err := bacb.sender.Do(&rCopy) - if err == nil && resp.StatusCode == 401 { - defer resp.Body.Close() - if hasBearerChallenge(resp) { - bc, err := newBearerChallenge(resp) + if err != nil { + return r, err + } + DrainResponseBody(resp) + if resp.StatusCode == 401 && hasBearerChallenge(resp.Header) { + bc, err := newBearerChallenge(resp.Header) + if err != nil { + return r, err + } + if bacb.callback != nil { + ba, err := bacb.callback(bc.values[tenantID], bc.values["resource"]) if err != nil { return r, err } - if bacb.callback != nil { - ba, err := bacb.callback(bc.values[tenantID], bc.values["resource"]) - if err != nil { - return r, err - } - return Prepare(r, ba.WithAuthorization()) - } + return Prepare(r, ba.WithAuthorization()) } } } @@ -190,8 +200,8 @@ func (bacb *BearerAuthorizerCallback) WithAuthorization() PrepareDecorator { } // returns true if the HTTP response contains a bearer challenge -func hasBearerChallenge(resp *http.Response) bool { - authHeader := resp.Header.Get(bearerChallengeHeader) +func hasBearerChallenge(header http.Header) bool { + authHeader := header.Get(bearerChallengeHeader) if len(authHeader) == 0 || strings.Index(authHeader, bearer) < 0 { return false } @@ -202,8 +212,8 @@ type bearerChallenge struct { values map[string]string } -func newBearerChallenge(resp *http.Response) (bc bearerChallenge, err error) { - challenge := strings.TrimSpace(resp.Header.Get(bearerChallengeHeader)) +func newBearerChallenge(header http.Header) (bc bearerChallenge, err error) { + challenge := strings.TrimSpace(header.Get(bearerChallengeHeader)) trimmedChallenge := challenge[len(bearer)+1:] // challenge is a set of key=value pairs that are comma delimited @@ -257,3 +267,87 @@ func (egta EventGridKeyAuthorizer) WithAuthorization() PrepareDecorator { } return NewAPIKeyAuthorizerWithHeaders(headers).WithAuthorization() } + +// BasicAuthorizer implements basic HTTP authorization by adding the Authorization HTTP header +// with the value "Basic " where is a base64-encoded username:password tuple. +type BasicAuthorizer struct { + userName string + password string +} + +// NewBasicAuthorizer creates a new BasicAuthorizer with the specified username and password. 
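For context, a sketch of wiring the challenge-driven callback above into a client. It assumes the long-standing adal.NewOAuthConfig and adal.NewServicePrincipalToken helpers; clientID and clientSecret are placeholders, and passing a nil sender falls back to the package default shown above.

    package example

    import (
        "github.com/Azure/go-autorest/autorest"
        "github.com/Azure/go-autorest/autorest/adal"
        "github.com/Azure/go-autorest/autorest/azure"
    )

    func newChallengeAuthorizer(clientID, clientSecret string) *autorest.BearerAuthorizerCallback {
        return autorest.NewBearerAuthorizerCallback(nil, func(tenantID, resource string) (*autorest.BearerAuthorizer, error) {
            cfg, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, tenantID)
            if err != nil {
                return nil, err
            }
            spt, err := adal.NewServicePrincipalToken(*cfg, clientID, clientSecret, resource)
            if err != nil {
                return nil, err
            }
            return autorest.NewBearerAuthorizer(spt), nil
        })
    }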
+func NewBasicAuthorizer(userName, password string) *BasicAuthorizer { + return &BasicAuthorizer{ + userName: userName, + password: password, + } +} + +// WithAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose +// value is "Basic " followed by the base64-encoded username:password tuple. +func (ba *BasicAuthorizer) WithAuthorization() PrepareDecorator { + headers := make(map[string]interface{}) + headers[authorization] = basic + " " + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", ba.userName, ba.password))) + + return NewAPIKeyAuthorizerWithHeaders(headers).WithAuthorization() +} + +// MultiTenantServicePrincipalTokenAuthorizer provides authentication across tenants. +type MultiTenantServicePrincipalTokenAuthorizer interface { + WithAuthorization() PrepareDecorator +} + +// NewMultiTenantServicePrincipalTokenAuthorizer crates a BearerAuthorizer using the given token provider +func NewMultiTenantServicePrincipalTokenAuthorizer(tp adal.MultitenantOAuthTokenProvider) MultiTenantServicePrincipalTokenAuthorizer { + return NewMultiTenantBearerAuthorizer(tp) +} + +// MultiTenantBearerAuthorizer implements bearer authorization across multiple tenants. +type MultiTenantBearerAuthorizer struct { + tp adal.MultitenantOAuthTokenProvider +} + +// NewMultiTenantBearerAuthorizer creates a MultiTenantBearerAuthorizer using the given token provider. +func NewMultiTenantBearerAuthorizer(tp adal.MultitenantOAuthTokenProvider) *MultiTenantBearerAuthorizer { + return &MultiTenantBearerAuthorizer{tp: tp} +} + +// WithAuthorization returns a PrepareDecorator that adds an HTTP Authorization header using the +// primary token along with the auxiliary authorization header using the auxiliary tokens. +// +// By default, the token will be automatically refreshed through the Refresher interface. +func (mt *MultiTenantBearerAuthorizer) WithAuthorization() PrepareDecorator { + return func(p Preparer) Preparer { + return PreparerFunc(func(r *http.Request) (*http.Request, error) { + r, err := p.Prepare(r) + if err != nil { + return r, err + } + if refresher, ok := mt.tp.(adal.RefresherWithContext); ok { + err = refresher.EnsureFreshWithContext(r.Context()) + if err != nil { + var resp *http.Response + if tokError, ok := err.(adal.TokenRefreshError); ok { + resp = tokError.Response() + } + return r, NewErrorWithError(err, "azure.multiTenantSPTAuthorizer", "WithAuthorization", resp, + "Failed to refresh one or more Tokens for request to %s", r.URL) + } + } + r, err = Prepare(r, WithHeader(headerAuthorization, fmt.Sprintf("Bearer %s", mt.tp.PrimaryOAuthToken()))) + if err != nil { + return r, err + } + auxTokens := mt.tp.AuxiliaryOAuthTokens() + for i := range auxTokens { + auxTokens[i] = fmt.Sprintf("Bearer %s", auxTokens[i]) + } + return Prepare(r, WithHeader(headerAuxAuthorization, strings.Join(auxTokens, ", "))) + }) + } +} + +// TokenProvider returns the underlying MultitenantOAuthTokenProvider for this authorizer. 
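A hedged sketch of the new BasicAuthorizer; the URL and credentials are placeholders:

    package example

    import (
        "net/http"

        "github.com/Azure/go-autorest/autorest"
    )

    func withBasicAuth(user, password string) (*http.Request, error) {
        ba := autorest.NewBasicAuthorizer(user, password)
        req, err := http.NewRequest(http.MethodGet, "https://example.com/v2/", nil)
        if err != nil {
            return nil, err
        }
        // adds: Authorization: Basic <base64 of "user:password">
        return autorest.Prepare(req, ba.WithAuthorization())
    }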
+func (mt *MultiTenantBearerAuthorizer) TokenProvider() adal.MultitenantOAuthTokenProvider { + return mt.tp +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/authorization_sas.go b/vendor/github.com/Azure/go-autorest/autorest/authorization_sas.go new file mode 100644 index 0000000000..66501493bd --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/authorization_sas.go @@ -0,0 +1,66 @@ +package autorest + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import ( + "fmt" + "net/http" + "strings" +) + +// SASTokenAuthorizer implements an authorization for SAS Token Authentication +// this can be used for interaction with Blob Storage Endpoints +type SASTokenAuthorizer struct { + sasToken string +} + +// NewSASTokenAuthorizer creates a SASTokenAuthorizer using the given credentials +func NewSASTokenAuthorizer(sasToken string) (*SASTokenAuthorizer, error) { + if strings.TrimSpace(sasToken) == "" { + return nil, fmt.Errorf("sasToken cannot be empty") + } + + token := sasToken + if strings.HasPrefix(sasToken, "?") { + token = strings.TrimPrefix(sasToken, "?") + } + + return &SASTokenAuthorizer{ + sasToken: token, + }, nil +} + +// WithAuthorization returns a PrepareDecorator that adds a shared access signature token to the +// URI's query parameters. This can be used for the Blob, Queue, and File Services. +// +// See https://docs.microsoft.com/en-us/rest/api/storageservices/delegate-access-with-shared-access-signature +func (sas *SASTokenAuthorizer) WithAuthorization() PrepareDecorator { + return func(p Preparer) Preparer { + return PreparerFunc(func(r *http.Request) (*http.Request, error) { + r, err := p.Prepare(r) + if err != nil { + return r, err + } + + if r.URL.RawQuery == "" { + r.URL.RawQuery = sas.sasToken + } else if !strings.Contains(r.URL.RawQuery, sas.sasToken) { + r.URL.RawQuery = fmt.Sprintf("%s&%s", r.URL.RawQuery, sas.sasToken) + } + + return Prepare(r) + }) + } +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/authorization_storage.go b/vendor/github.com/Azure/go-autorest/autorest/authorization_storage.go new file mode 100644 index 0000000000..2af5030a1c --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/authorization_storage.go @@ -0,0 +1,307 @@ +package autorest + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
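A hedged sketch of the SAS token authorizer defined just above; the storage URL and token are placeholders. The constructor strips a leading "?", and WithAuthorization appends the token to the query string rather than setting a header:

    package example

    import (
        "net/http"

        "github.com/Azure/go-autorest/autorest"
    )

    func withSASToken(sasToken string) (*http.Request, error) {
        auth, err := autorest.NewSASTokenAuthorizer(sasToken)
        if err != nil {
            return nil, err
        }
        req, err := http.NewRequest(http.MethodGet, "https://myaccount.blob.core.windows.net/container/blob", nil)
        if err != nil {
            return nil, err
        }
        return autorest.Prepare(req, auth.WithAuthorization())
    }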
+ +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "sort" + "strings" + "time" +) + +// SharedKeyType defines the enumeration for the various shared key types. +// See https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key for details on the shared key types. +type SharedKeyType string + +const ( + // SharedKey is used to authorize against blobs, files and queues services. + SharedKey SharedKeyType = "sharedKey" + + // SharedKeyForTable is used to authorize against the table service. + SharedKeyForTable SharedKeyType = "sharedKeyTable" + + // SharedKeyLite is used to authorize against blobs, files and queues services. It's provided for + // backwards compatibility with API versions before 2009-09-19. Prefer SharedKey instead. + SharedKeyLite SharedKeyType = "sharedKeyLite" + + // SharedKeyLiteForTable is used to authorize against the table service. It's provided for + // backwards compatibility with older table API versions. Prefer SharedKeyForTable instead. + SharedKeyLiteForTable SharedKeyType = "sharedKeyLiteTable" +) + +const ( + headerAccept = "Accept" + headerAcceptCharset = "Accept-Charset" + headerContentEncoding = "Content-Encoding" + headerContentLength = "Content-Length" + headerContentMD5 = "Content-MD5" + headerContentLanguage = "Content-Language" + headerIfModifiedSince = "If-Modified-Since" + headerIfMatch = "If-Match" + headerIfNoneMatch = "If-None-Match" + headerIfUnmodifiedSince = "If-Unmodified-Since" + headerDate = "Date" + headerXMSDate = "X-Ms-Date" + headerXMSVersion = "x-ms-version" + headerRange = "Range" +) + +const storageEmulatorAccountName = "devstoreaccount1" + +// SharedKeyAuthorizer implements an authorization for Shared Key +// this can be used for interaction with Blob, File and Queue Storage Endpoints +type SharedKeyAuthorizer struct { + accountName string + accountKey []byte + keyType SharedKeyType +} + +// NewSharedKeyAuthorizer creates a SharedKeyAuthorizer using the provided credentials and shared key type. +func NewSharedKeyAuthorizer(accountName, accountKey string, keyType SharedKeyType) (*SharedKeyAuthorizer, error) { + key, err := base64.StdEncoding.DecodeString(accountKey) + if err != nil { + return nil, fmt.Errorf("malformed storage account key: %v", err) + } + return &SharedKeyAuthorizer{ + accountName: accountName, + accountKey: key, + keyType: keyType, + }, nil +} + +// WithAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose +// value is " " followed by the computed key. +// This can be used for the Blob, Queue, and File Services +// +// from: https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key +// You may use Shared Key authorization to authorize a request made against the +// 2009-09-19 version and later of the Blob and Queue services, +// and version 2014-02-14 and later of the File services. 
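A hedged usage sketch for the shared key authorizer; account name, key and URL are placeholders, and the key must be the base64-encoded storage account key expected by NewSharedKeyAuthorizer above:

    package example

    import (
        "net/http"

        "github.com/Azure/go-autorest/autorest"
    )

    func withSharedKey(accountName, base64Key string) (*http.Request, error) {
        auth, err := autorest.NewSharedKeyAuthorizer(accountName, base64Key, autorest.SharedKey)
        if err != nil {
            return nil, err
        }
        req, err := http.NewRequest(http.MethodGet, "https://"+accountName+".blob.core.windows.net/container?restype=container&comp=list", nil)
        if err != nil {
            return nil, err
        }
        // results in: Authorization: SharedKey <account>:<base64 HMAC-SHA256 signature>
        return autorest.Prepare(req, auth.WithAuthorization())
    }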
+func (sk *SharedKeyAuthorizer) WithAuthorization() PrepareDecorator { + return func(p Preparer) Preparer { + return PreparerFunc(func(r *http.Request) (*http.Request, error) { + r, err := p.Prepare(r) + if err != nil { + return r, err + } + + sk, err := buildSharedKey(sk.accountName, sk.accountKey, r, sk.keyType) + if err != nil { + return r, err + } + return Prepare(r, WithHeader(headerAuthorization, sk)) + }) + } +} + +func buildSharedKey(accName string, accKey []byte, req *http.Request, keyType SharedKeyType) (string, error) { + canRes, err := buildCanonicalizedResource(accName, req.URL.String(), keyType) + if err != nil { + return "", err + } + + if req.Header == nil { + req.Header = http.Header{} + } + + // ensure date is set + if req.Header.Get(headerDate) == "" && req.Header.Get(headerXMSDate) == "" { + date := time.Now().UTC().Format(http.TimeFormat) + req.Header.Set(headerXMSDate, date) + } + canString, err := buildCanonicalizedString(req.Method, req.Header, canRes, keyType) + if err != nil { + return "", err + } + return createAuthorizationHeader(accName, accKey, canString, keyType), nil +} + +func buildCanonicalizedResource(accountName, uri string, keyType SharedKeyType) (string, error) { + errMsg := "buildCanonicalizedResource error: %s" + u, err := url.Parse(uri) + if err != nil { + return "", fmt.Errorf(errMsg, err.Error()) + } + + cr := bytes.NewBufferString("") + if accountName != storageEmulatorAccountName { + cr.WriteString("/") + cr.WriteString(getCanonicalizedAccountName(accountName)) + } + + if len(u.Path) > 0 { + // Any portion of the CanonicalizedResource string that is derived from + // the resource's URI should be encoded exactly as it is in the URI. + // -- https://msdn.microsoft.com/en-gb/library/azure/dd179428.aspx + cr.WriteString(u.EscapedPath()) + } else { + // a slash is required to indicate the root path + cr.WriteString("/") + } + + params, err := url.ParseQuery(u.RawQuery) + if err != nil { + return "", fmt.Errorf(errMsg, err.Error()) + } + + // See https://github.com/Azure/azure-storage-net/blob/master/Lib/Common/Core/Util/AuthenticationUtility.cs#L277 + if keyType == SharedKey { + if len(params) > 0 { + cr.WriteString("\n") + + keys := []string{} + for key := range params { + keys = append(keys, key) + } + sort.Strings(keys) + + completeParams := []string{} + for _, key := range keys { + if len(params[key]) > 1 { + sort.Strings(params[key]) + } + + completeParams = append(completeParams, fmt.Sprintf("%s:%s", key, strings.Join(params[key], ","))) + } + cr.WriteString(strings.Join(completeParams, "\n")) + } + } else { + // search for "comp" parameter, if exists then add it to canonicalizedresource + if v, ok := params["comp"]; ok { + cr.WriteString("?comp=" + v[0]) + } + } + + return string(cr.Bytes()), nil +} + +func getCanonicalizedAccountName(accountName string) string { + // since we may be trying to access a secondary storage account, we need to + // remove the -secondary part of the storage name + return strings.TrimSuffix(accountName, "-secondary") +} + +func buildCanonicalizedString(verb string, headers http.Header, canonicalizedResource string, keyType SharedKeyType) (string, error) { + contentLength := headers.Get(headerContentLength) + if contentLength == "0" { + contentLength = "" + } + date := headers.Get(headerDate) + if v := headers.Get(headerXMSDate); v != "" { + if keyType == SharedKey || keyType == SharedKeyLite { + date = "" + } else { + date = v + } + } + var canString string + switch keyType { + case SharedKey: + canString = 
strings.Join([]string{ + verb, + headers.Get(headerContentEncoding), + headers.Get(headerContentLanguage), + contentLength, + headers.Get(headerContentMD5), + headers.Get(headerContentType), + date, + headers.Get(headerIfModifiedSince), + headers.Get(headerIfMatch), + headers.Get(headerIfNoneMatch), + headers.Get(headerIfUnmodifiedSince), + headers.Get(headerRange), + buildCanonicalizedHeader(headers), + canonicalizedResource, + }, "\n") + case SharedKeyForTable: + canString = strings.Join([]string{ + verb, + headers.Get(headerContentMD5), + headers.Get(headerContentType), + date, + canonicalizedResource, + }, "\n") + case SharedKeyLite: + canString = strings.Join([]string{ + verb, + headers.Get(headerContentMD5), + headers.Get(headerContentType), + date, + buildCanonicalizedHeader(headers), + canonicalizedResource, + }, "\n") + case SharedKeyLiteForTable: + canString = strings.Join([]string{ + date, + canonicalizedResource, + }, "\n") + default: + return "", fmt.Errorf("key type '%s' is not supported", keyType) + } + return canString, nil +} + +func buildCanonicalizedHeader(headers http.Header) string { + cm := make(map[string]string) + + for k := range headers { + headerName := strings.TrimSpace(strings.ToLower(k)) + if strings.HasPrefix(headerName, "x-ms-") { + cm[headerName] = headers.Get(k) + } + } + + if len(cm) == 0 { + return "" + } + + keys := []string{} + for key := range cm { + keys = append(keys, key) + } + + sort.Strings(keys) + + ch := bytes.NewBufferString("") + + for _, key := range keys { + ch.WriteString(key) + ch.WriteRune(':') + ch.WriteString(cm[key]) + ch.WriteRune('\n') + } + + return strings.TrimSuffix(string(ch.Bytes()), "\n") +} + +func createAuthorizationHeader(accountName string, accountKey []byte, canonicalizedString string, keyType SharedKeyType) string { + h := hmac.New(sha256.New, accountKey) + h.Write([]byte(canonicalizedString)) + signature := base64.StdEncoding.EncodeToString(h.Sum(nil)) + var key string + switch keyType { + case SharedKey, SharedKeyForTable: + key = "SharedKey" + case SharedKeyLite, SharedKeyLiteForTable: + key = "SharedKeyLite" + } + return fmt.Sprintf("%s %s:%s", key, getCanonicalizedAccountName(accountName), signature) +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/azure/async.go b/vendor/github.com/Azure/go-autorest/autorest/azure/async.go index a58e5ef3f1..45575eedbf 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/azure/async.go +++ b/vendor/github.com/Azure/go-autorest/autorest/azure/async.go @@ -21,11 +21,13 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" "strings" "time" "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/logger" + "github.com/Azure/go-autorest/tracing" ) const ( @@ -41,87 +43,123 @@ const ( var pollingCodes = [...]int{http.StatusNoContent, http.StatusAccepted, http.StatusCreated, http.StatusOK} +// FutureAPI contains the set of methods on the Future type. +type FutureAPI interface { + // Response returns the last HTTP response. + Response() *http.Response + + // Status returns the last status message of the operation. + Status() string + + // PollingMethod returns the method used to monitor the status of the asynchronous operation. + PollingMethod() PollingMethodType + + // DoneWithContext queries the service to see if the operation has completed. 
+ DoneWithContext(context.Context, autorest.Sender) (bool, error) + + // GetPollingDelay returns a duration the application should wait before checking + // the status of the asynchronous request and true; this value is returned from + // the service via the Retry-After response header. If the header wasn't returned + // then the function returns the zero-value time.Duration and false. + GetPollingDelay() (time.Duration, bool) + + // WaitForCompletionRef will return when one of the following conditions is met: the long + // running operation has completed, the provided context is cancelled, or the client's + // polling duration has been exceeded. It will retry failed polling attempts based on + // the retry value defined in the client up to the maximum retry attempts. + // If no deadline is specified in the context then the client.PollingDuration will be + // used to determine if a default deadline should be used. + // If PollingDuration is greater than zero the value will be used as the context's timeout. + // If PollingDuration is zero then no default deadline will be used. + WaitForCompletionRef(context.Context, autorest.Client) error + + // MarshalJSON implements the json.Marshaler interface. + MarshalJSON() ([]byte, error) + + // MarshalJSON implements the json.Unmarshaler interface. + UnmarshalJSON([]byte) error + + // PollingURL returns the URL used for retrieving the status of the long-running operation. + PollingURL() string + + // GetResult should be called once polling has completed successfully. + // It makes the final GET call to retrieve the resultant payload. + GetResult(autorest.Sender) (*http.Response, error) +} + +var _ FutureAPI = (*Future)(nil) + // Future provides a mechanism to access the status and results of an asynchronous request. // Since futures are stateful they should be passed by value to avoid race conditions. type Future struct { - req *http.Request - resp *http.Response - ps pollingState + pt pollingTracker } -// NewFuture returns a new Future object initialized with the specified request. -func NewFuture(req *http.Request) Future { - return Future{req: req} +// NewFutureFromResponse returns a new Future object initialized +// with the initial response from an asynchronous operation. +func NewFutureFromResponse(resp *http.Response) (Future, error) { + pt, err := createPollingTracker(resp) + return Future{pt: pt}, err } -// Response returns the last HTTP response or nil if there isn't one. +// Response returns the last HTTP response. func (f Future) Response() *http.Response { - return f.resp + if f.pt == nil { + return nil + } + return f.pt.latestResponse() } // Status returns the last status message of the operation. func (f Future) Status() string { - if f.ps.State == "" { - return "Unknown" + if f.pt == nil { + return "" } - return f.ps.State + return f.pt.pollingStatus() } // PollingMethod returns the method used to monitor the status of the asynchronous operation. func (f Future) PollingMethod() PollingMethodType { - return f.ps.PollingMethod + if f.pt == nil { + return PollingUnknown + } + return f.pt.pollingMethod() } -// Done queries the service to see if the operation has completed. -func (f *Future) Done(sender autorest.Sender) (bool, error) { - // exit early if this future has terminated - if f.ps.hasTerminated() { - return true, f.errorInfo() +// DoneWithContext queries the service to see if the operation has completed. 
+func (f *Future) DoneWithContext(ctx context.Context, sender autorest.Sender) (done bool, err error) { + ctx = tracing.StartSpan(ctx, "github.com/Azure/go-autorest/autorest/azure/async.DoneWithContext") + defer func() { + sc := -1 + resp := f.Response() + if resp != nil { + sc = resp.StatusCode + } + tracing.EndSpan(ctx, sc, err) + }() + + if f.pt == nil { + return false, autorest.NewError("Future", "Done", "future is not initialized") } - resp, err := sender.Do(f.req) - f.resp = resp - if err != nil { + if f.pt.hasTerminated() { + return true, f.pt.pollingError() + } + if err := f.pt.pollForStatus(ctx, sender); err != nil { return false, err } - - if !autorest.ResponseHasStatusCode(resp, pollingCodes[:]...) { - // check response body for error content - if resp.Body != nil { - type respErr struct { - ServiceError ServiceError `json:"error"` - } - re := respErr{} - - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return false, err - } - err = json.Unmarshal(b, &re) - if err != nil { - return false, err - } - return false, re.ServiceError - } - - // try to return something meaningful - return false, ServiceError{ - Code: fmt.Sprintf("%v", resp.StatusCode), - Message: resp.Status, - } + if err := f.pt.checkForErrors(); err != nil { + return f.pt.hasTerminated(), err } - - err = updatePollingState(resp, &f.ps) - if err != nil { + if err := f.pt.updatePollingState(f.pt.provisioningStateApplicable()); err != nil { return false, err } - - if f.ps.hasTerminated() { - return true, f.errorInfo() + if err := f.pt.initPollingMethod(); err != nil { + return false, err } - - f.req, err = newPollingRequest(f.ps) - return false, err + if err := f.pt.updatePollingMethod(); err != nil { + return false, err + } + return f.pt.hasTerminated(), f.pt.pollingError() } // GetPollingDelay returns a duration the application should wait before checking @@ -129,11 +167,15 @@ func (f *Future) Done(sender autorest.Sender) (bool, error) { // the service via the Retry-After response header. If the header wasn't returned // then the function returns the zero-value time.Duration and false. func (f Future) GetPollingDelay() (time.Duration, bool) { - if f.resp == nil { + if f.pt == nil { + return 0, false + } + resp := f.pt.latestResponse() + if resp == nil { return 0, false } - retry := f.resp.Header.Get(autorest.HeaderRetryAfter) + retry := resp.Header.Get(autorest.HeaderRetryAfter) if retry == "" { return 0, false } @@ -146,18 +188,44 @@ func (f Future) GetPollingDelay() (time.Duration, bool) { return d, true } -// WaitForCompletion will return when one of the following conditions is met: the long +// WaitForCompletionRef will return when one of the following conditions is met: the long // running operation has completed, the provided context is cancelled, or the client's // polling duration has been exceeded. It will retry failed polling attempts based on // the retry value defined in the client up to the maximum retry attempts. -func (f Future) WaitForCompletion(ctx context.Context, client autorest.Client) error { - ctx, cancel := context.WithTimeout(ctx, client.PollingDuration) - defer cancel() - - done, err := f.Done(client) - for attempts := 0; !done; done, err = f.Done(client) { +// If no deadline is specified in the context then the client.PollingDuration will be +// used to determine if a default deadline should be used. +// If PollingDuration is greater than zero the value will be used as the context's timeout. 
+// If PollingDuration is zero then no default deadline will be used. +func (f *Future) WaitForCompletionRef(ctx context.Context, client autorest.Client) (err error) { + ctx = tracing.StartSpan(ctx, "github.com/Azure/go-autorest/autorest/azure/async.WaitForCompletionRef") + defer func() { + sc := -1 + resp := f.Response() + if resp != nil { + sc = resp.StatusCode + } + tracing.EndSpan(ctx, sc, err) + }() + cancelCtx := ctx + // if the provided context already has a deadline don't override it + _, hasDeadline := ctx.Deadline() + if d := client.PollingDuration; !hasDeadline && d != 0 { + var cancel context.CancelFunc + cancelCtx, cancel = context.WithTimeout(ctx, d) + defer cancel() + } + // if the initial response has a Retry-After, sleep for the specified amount of time before starting to poll + if delay, ok := f.GetPollingDelay(); ok { + logger.Instance.Writeln(logger.LogInfo, "WaitForCompletionRef: initial polling delay") + if delayElapsed := autorest.DelayForBackoff(delay, 0, cancelCtx.Done()); !delayElapsed { + err = cancelCtx.Err() + return + } + } + done, err := f.DoneWithContext(ctx, client) + for attempts := 0; !done; done, err = f.DoneWithContext(ctx, client) { if attempts >= client.RetryAttempts { - return autorest.NewErrorWithError(err, "azure", "WaitForCompletion", f.resp, "the number of retries has been exceeded") + return autorest.NewErrorWithError(err, "Future", "WaitForCompletion", f.pt.latestResponse(), "the number of retries has been exceeded") } // we want delayAttempt to be zero in the non-error case so // that DelayForBackoff doesn't perform exponential back-off @@ -168,330 +236,746 @@ func (f Future) WaitForCompletion(ctx context.Context, client autorest.Client) e var ok bool delay, ok = f.GetPollingDelay() if !ok { + logger.Instance.Writeln(logger.LogInfo, "WaitForCompletionRef: Using client polling delay") delay = client.PollingDelay } } else { // there was an error polling for status so perform exponential // back-off based on the number of attempts using the client's retry // duration. update attempts after delayAttempt to avoid off-by-one. + logger.Instance.Writef(logger.LogError, "WaitForCompletionRef: %s\n", err) delayAttempt = attempts delay = client.RetryDuration attempts++ } // wait until the delay elapses or the context is cancelled - delayElapsed := autorest.DelayForBackoff(delay, delayAttempt, ctx.Done()) + delayElapsed := autorest.DelayForBackoff(delay, delayAttempt, cancelCtx.Done()) if !delayElapsed { - return autorest.NewErrorWithError(ctx.Err(), "azure", "WaitForCompletion", f.resp, "context has been cancelled") + return autorest.NewErrorWithError(cancelCtx.Err(), "Future", "WaitForCompletion", f.pt.latestResponse(), "context has been cancelled") } } - return err -} - -// if the operation failed the polling state will contain -// error information and implements the error interface -func (f *Future) errorInfo() error { - if !f.ps.hasSucceeded() { - return f.ps - } - return nil + return } // MarshalJSON implements the json.Marshaler interface. func (f Future) MarshalJSON() ([]byte, error) { - return json.Marshal(&f.ps) + return json.Marshal(f.pt) } // UnmarshalJSON implements the json.Unmarshaler interface. 
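Putting the reworked Future together, a hedged sketch of the typical calling pattern: create the future from the initial long-running-operation response, wait for it to reach a terminal state, then fetch the final payload. The client and resp values are placeholders.

    package example

    import (
        "context"
        "net/http"

        "github.com/Azure/go-autorest/autorest"
        "github.com/Azure/go-autorest/autorest/azure"
    )

    func waitForLRO(ctx context.Context, client autorest.Client, resp *http.Response) (*http.Response, error) {
        future, err := azure.NewFutureFromResponse(resp)
        if err != nil {
            return nil, err
        }
        // polls until the operation terminates, honouring Retry-After and the
        // client's polling and retry settings
        if err := future.WaitForCompletionRef(ctx, client); err != nil {
            return nil, err
        }
        // final GET for the resultant payload (returns the last response if the
        // payload was delivered inline)
        return future.GetResult(client)
    }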
func (f *Future) UnmarshalJSON(data []byte) error { - err := json.Unmarshal(data, &f.ps) + // unmarshal into JSON object to determine the tracker type + obj := map[string]interface{}{} + err := json.Unmarshal(data, &obj) if err != nil { return err } - f.req, err = newPollingRequest(f.ps) - return err + if obj["method"] == nil { + return autorest.NewError("Future", "UnmarshalJSON", "missing 'method' property") + } + method := obj["method"].(string) + switch strings.ToUpper(method) { + case http.MethodDelete: + f.pt = &pollingTrackerDelete{} + case http.MethodPatch: + f.pt = &pollingTrackerPatch{} + case http.MethodPost: + f.pt = &pollingTrackerPost{} + case http.MethodPut: + f.pt = &pollingTrackerPut{} + default: + return autorest.NewError("Future", "UnmarshalJSON", "unsupoorted method '%s'", method) + } + // now unmarshal into the tracker + return json.Unmarshal(data, &f.pt) } // PollingURL returns the URL used for retrieving the status of the long-running operation. -// For LROs that use the Location header the final URL value is used to retrieve the result. func (f Future) PollingURL() string { - return f.ps.URI + if f.pt == nil { + return "" + } + return f.pt.pollingURL() +} + +// GetResult should be called once polling has completed successfully. +// It makes the final GET call to retrieve the resultant payload. +func (f Future) GetResult(sender autorest.Sender) (*http.Response, error) { + if f.pt.finalGetURL() == "" { + // we can end up in this situation if the async operation returns a 200 + // with no polling URLs. in that case return the response which should + // contain the JSON payload (only do this for successful terminal cases). + if lr := f.pt.latestResponse(); lr != nil && f.pt.hasSucceeded() { + return lr, nil + } + return nil, autorest.NewError("Future", "GetResult", "missing URL for retrieving result") + } + req, err := http.NewRequest(http.MethodGet, f.pt.finalGetURL(), nil) + if err != nil { + return nil, err + } + resp, err := sender.Do(req) + if err == nil && resp.Body != nil { + // copy the body and close it so callers don't have to + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return resp, err + } + resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + } + return resp, err } -// DoPollForAsynchronous returns a SendDecorator that polls if the http.Response is for an Azure -// long-running operation. It will delay between requests for the duration specified in the -// RetryAfter header or, if the header is absent, the passed delay. Polling may be canceled by -// closing the optional channel on the http.Request. -func DoPollForAsynchronous(delay time.Duration) autorest.SendDecorator { - return func(s autorest.Sender) autorest.Sender { - return autorest.SenderFunc(func(r *http.Request) (resp *http.Response, err error) { - resp, err = s.Do(r) - if err != nil { - return resp, err - } - if !autorest.ResponseHasStatusCode(resp, pollingCodes[:]...) 
{ - return resp, nil - } +type pollingTracker interface { + // these methods can differ per tracker - ps := pollingState{} - for err == nil { - err = updatePollingState(resp, &ps) - if err != nil { - break - } - if ps.hasTerminated() { - if !ps.hasSucceeded() { - err = ps - } - break - } - - r, err = newPollingRequest(ps) - if err != nil { - return resp, err - } - r = r.WithContext(resp.Request.Context()) - - delay = autorest.GetRetryAfter(resp, delay) - resp, err = autorest.SendWithSender(s, r, - autorest.AfterDelay(delay)) + // checks the response headers and status code to determine the polling mechanism + updatePollingMethod() error + + // checks the response for tracker-specific error conditions + checkForErrors() error + + // returns true if provisioning state should be checked + provisioningStateApplicable() bool + + // methods common to all trackers + + // initializes a tracker's polling URL and method, called for each iteration. + // these values can be overridden by each polling tracker as required. + initPollingMethod() error + + // initializes the tracker's internal state, call this when the tracker is created + initializeState() error + + // makes an HTTP request to check the status of the LRO + pollForStatus(ctx context.Context, sender autorest.Sender) error + + // updates internal tracker state, call this after each call to pollForStatus + updatePollingState(provStateApl bool) error + + // returns the error response from the service, can be nil + pollingError() error + + // returns the polling method being used + pollingMethod() PollingMethodType + + // returns the state of the LRO as returned from the service + pollingStatus() string + + // returns the URL used for polling status + pollingURL() string + + // returns the URL used for the final GET to retrieve the resource + finalGetURL() string + + // returns true if the LRO is in a terminal state + hasTerminated() bool + + // returns true if the LRO is in a failed terminal state + hasFailed() bool + + // returns true if the LRO is in a successful terminal state + hasSucceeded() bool + + // returns the cached HTTP response after a call to pollForStatus(), can be nil + latestResponse() *http.Response +} + +type pollingTrackerBase struct { + // resp is the last response, either from the submission of the LRO or from polling + resp *http.Response + + // method is the HTTP verb, this is needed for deserialization + Method string `json:"method"` + + // rawBody is the raw JSON response body + rawBody map[string]interface{} + + // denotes if polling is using async-operation or location header + Pm PollingMethodType `json:"pollingMethod"` + + // the URL to poll for status + URI string `json:"pollingURI"` + + // the state of the LRO as returned from the service + State string `json:"lroState"` + + // the URL to GET for the final result + FinalGetURI string `json:"resultURI"` + + // used to hold an error object returned from the service + Err *ServiceError `json:"error,omitempty"` +} + +func (pt *pollingTrackerBase) initializeState() error { + // determine the initial polling state based on response body and/or HTTP status + // code. this is applicable to the initial LRO response, not polling responses! 
+ pt.Method = pt.resp.Request.Method + if err := pt.updateRawBody(); err != nil { + return err + } + switch pt.resp.StatusCode { + case http.StatusOK: + if ps := pt.getProvisioningState(); ps != nil { + pt.State = *ps + if pt.hasFailed() { + pt.updateErrorFromResponse() + return pt.pollingError() } + } else { + pt.State = operationSucceeded + } + case http.StatusCreated: + if ps := pt.getProvisioningState(); ps != nil { + pt.State = *ps + } else { + pt.State = operationInProgress + } + case http.StatusAccepted: + pt.State = operationInProgress + case http.StatusNoContent: + pt.State = operationSucceeded + default: + pt.State = operationFailed + pt.updateErrorFromResponse() + return pt.pollingError() + } + return pt.initPollingMethod() +} - return resp, err - }) +func (pt pollingTrackerBase) getProvisioningState() *string { + if pt.rawBody != nil && pt.rawBody["properties"] != nil { + p := pt.rawBody["properties"].(map[string]interface{}) + if ps := p["provisioningState"]; ps != nil { + s := ps.(string) + return &s + } } + return nil } -func getAsyncOperation(resp *http.Response) string { - return resp.Header.Get(http.CanonicalHeaderKey(headerAsyncOperation)) +func (pt *pollingTrackerBase) updateRawBody() error { + pt.rawBody = map[string]interface{}{} + if pt.resp.ContentLength != 0 { + defer pt.resp.Body.Close() + b, err := ioutil.ReadAll(pt.resp.Body) + if err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "updateRawBody", nil, "failed to read response body") + } + // put the body back so it's available to other callers + pt.resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + // observed in 204 responses over HTTP/2.0; the content length is -1 but body is empty + if len(b) == 0 { + return nil + } + if err = json.Unmarshal(b, &pt.rawBody); err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "updateRawBody", nil, "failed to unmarshal response body") + } + } + return nil } -func hasSucceeded(state string) bool { - return strings.EqualFold(state, operationSucceeded) +func (pt *pollingTrackerBase) pollForStatus(ctx context.Context, sender autorest.Sender) error { + req, err := http.NewRequest(http.MethodGet, pt.URI, nil) + if err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "pollForStatus", nil, "failed to create HTTP request") + } + + req = req.WithContext(ctx) + preparer := autorest.CreatePreparer(autorest.GetPrepareDecorators(ctx)...) + req, err = preparer.Prepare(req) + if err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "pollForStatus", nil, "failed preparing HTTP request") + } + pt.resp, err = sender.Do(req) + if err != nil { + return autorest.NewErrorWithError(err, "pollingTrackerBase", "pollForStatus", nil, "failed to send HTTP request") + } + if autorest.ResponseHasStatusCode(pt.resp, pollingCodes[:]...) { + // reset the service error on success case + pt.Err = nil + err = pt.updateRawBody() + } else { + // check response body for error content + pt.updateErrorFromResponse() + err = pt.pollingError() + } + return err } -func hasTerminated(state string) bool { - return strings.EqualFold(state, operationCanceled) || strings.EqualFold(state, operationFailed) || strings.EqualFold(state, operationSucceeded) +// attempts to unmarshal a ServiceError type from the response body. +// if that fails then make a best attempt at creating something meaningful. +// NOTE: this assumes that the async operation has failed. 
+func (pt *pollingTrackerBase) updateErrorFromResponse() { + var err error + if pt.resp.ContentLength != 0 { + type respErr struct { + ServiceError *ServiceError `json:"error"` + } + re := respErr{} + defer pt.resp.Body.Close() + var b []byte + if b, err = ioutil.ReadAll(pt.resp.Body); err != nil { + goto Default + } + // put the body back so it's available to other callers + pt.resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + if len(b) == 0 { + goto Default + } + if err = json.Unmarshal(b, &re); err != nil { + goto Default + } + // unmarshalling the error didn't yield anything, try unwrapped error + if re.ServiceError == nil { + err = json.Unmarshal(b, &re.ServiceError) + if err != nil { + goto Default + } + } + // the unmarshaller will ensure re.ServiceError is non-nil + // even if there was no content unmarshalled so check the code. + if re.ServiceError.Code != "" { + pt.Err = re.ServiceError + return + } + } +Default: + se := &ServiceError{ + Code: pt.pollingStatus(), + Message: "The async operation failed.", + } + if err != nil { + se.InnerError = make(map[string]interface{}) + se.InnerError["unmarshalError"] = err.Error() + } + // stick the response body into the error object in hopes + // it contains something useful to help diagnose the failure. + if len(pt.rawBody) > 0 { + se.AdditionalInfo = []map[string]interface{}{ + pt.rawBody, + } + } + pt.Err = se } -func hasFailed(state string) bool { - return strings.EqualFold(state, operationFailed) +func (pt *pollingTrackerBase) updatePollingState(provStateApl bool) error { + if pt.Pm == PollingAsyncOperation && pt.rawBody["status"] != nil { + pt.State = pt.rawBody["status"].(string) + } else { + if pt.resp.StatusCode == http.StatusAccepted { + pt.State = operationInProgress + } else if provStateApl { + if ps := pt.getProvisioningState(); ps != nil { + pt.State = *ps + } else { + pt.State = operationSucceeded + } + } else { + return autorest.NewError("pollingTrackerBase", "updatePollingState", "the response from the async operation has an invalid status code") + } + } + // if the operation has failed update the error state + if pt.hasFailed() { + pt.updateErrorFromResponse() + } + return nil } -type provisioningTracker interface { - state() string - hasSucceeded() bool - hasTerminated() bool +func (pt pollingTrackerBase) pollingError() error { + if pt.Err == nil { + return nil + } + return pt.Err } -type operationResource struct { - // Note: - // The specification states services should return the "id" field. However some return it as - // "operationId". 
- ID string `json:"id"` - OperationID string `json:"operationId"` - Name string `json:"name"` - Status string `json:"status"` - Properties map[string]interface{} `json:"properties"` - OperationError ServiceError `json:"error"` - StartTime date.Time `json:"startTime"` - EndTime date.Time `json:"endTime"` - PercentComplete float64 `json:"percentComplete"` +func (pt pollingTrackerBase) pollingMethod() PollingMethodType { + return pt.Pm } -func (or operationResource) state() string { - return or.Status +func (pt pollingTrackerBase) pollingStatus() string { + return pt.State } -func (or operationResource) hasSucceeded() bool { - return hasSucceeded(or.state()) +func (pt pollingTrackerBase) pollingURL() string { + return pt.URI } -func (or operationResource) hasTerminated() bool { - return hasTerminated(or.state()) +func (pt pollingTrackerBase) finalGetURL() string { + return pt.FinalGetURI } -type provisioningProperties struct { - ProvisioningState string `json:"provisioningState"` +func (pt pollingTrackerBase) hasTerminated() bool { + return strings.EqualFold(pt.State, operationCanceled) || strings.EqualFold(pt.State, operationFailed) || strings.EqualFold(pt.State, operationSucceeded) } -type provisioningStatus struct { - Properties provisioningProperties `json:"properties,omitempty"` - ProvisioningError ServiceError `json:"error,omitempty"` +func (pt pollingTrackerBase) hasFailed() bool { + return strings.EqualFold(pt.State, operationCanceled) || strings.EqualFold(pt.State, operationFailed) } -func (ps provisioningStatus) state() string { - return ps.Properties.ProvisioningState +func (pt pollingTrackerBase) hasSucceeded() bool { + return strings.EqualFold(pt.State, operationSucceeded) } -func (ps provisioningStatus) hasSucceeded() bool { - return hasSucceeded(ps.state()) +func (pt pollingTrackerBase) latestResponse() *http.Response { + return pt.resp } -func (ps provisioningStatus) hasTerminated() bool { - return hasTerminated(ps.state()) +// error checking common to all trackers +func (pt pollingTrackerBase) baseCheckForErrors() error { + // for Azure-AsyncOperations the response body cannot be nil or empty + if pt.Pm == PollingAsyncOperation { + if pt.resp.Body == nil || pt.resp.ContentLength == 0 { + return autorest.NewError("pollingTrackerBase", "baseCheckForErrors", "for Azure-AsyncOperation response body cannot be nil") + } + if pt.rawBody["status"] == nil { + return autorest.NewError("pollingTrackerBase", "baseCheckForErrors", "missing status property in Azure-AsyncOperation response body") + } + } + return nil } -func (ps provisioningStatus) hasProvisioningError() bool { - // code and message are required fields so only check them - return len(ps.ProvisioningError.Code) > 0 || - len(ps.ProvisioningError.Message) > 0 +// default initialization of polling URL/method. each verb tracker will update this as required. +func (pt *pollingTrackerBase) initPollingMethod() error { + if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + return nil + } + if lh, err := getURLFromLocationHeader(pt.resp); err != nil { + return err + } else if lh != "" { + pt.URI = lh + pt.Pm = PollingLocation + return nil + } + // it's ok if we didn't find a polling header, this will be handled elsewhere + return nil } -// PollingMethodType defines a type used for enumerating polling mechanisms. 
-type PollingMethodType string
+// DELETE
 
-const (
-	// PollingAsyncOperation indicates the polling method uses the Azure-AsyncOperation header.
-	PollingAsyncOperation PollingMethodType = "AsyncOperation"
+type pollingTrackerDelete struct {
+	pollingTrackerBase
+}
 
-	// PollingLocation indicates the polling method uses the Location header.
-	PollingLocation PollingMethodType = "Location"
+func (pt *pollingTrackerDelete) updatePollingMethod() error {
+	// for 201 the Location header is required
+	if pt.resp.StatusCode == http.StatusCreated {
+		if lh, err := getURLFromLocationHeader(pt.resp); err != nil {
+			return err
+		} else if lh == "" {
+			return autorest.NewError("pollingTrackerDelete", "updateHeaders", "missing Location header in 201 response")
+		} else {
+			pt.URI = lh
+		}
+		pt.Pm = PollingLocation
+		pt.FinalGetURI = pt.URI
+	}
+	// for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary
+	if pt.resp.StatusCode == http.StatusAccepted {
+		ao, err := getURLFromAsyncOpHeader(pt.resp)
+		if err != nil {
+			return err
+		} else if ao != "" {
+			pt.URI = ao
+			pt.Pm = PollingAsyncOperation
+		}
+		// if the Location header is invalid and we already have a polling URL
+		// then we don't care if the Location header URL is malformed.
+		if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" {
+			return err
+		} else if lh != "" {
+			if ao == "" {
+				pt.URI = lh
+				pt.Pm = PollingLocation
+			}
+			// when both headers are returned we use the value in the Location header for the final GET
+			pt.FinalGetURI = lh
+		}
+		// make sure a polling URL was found
+		if pt.URI == "" {
+			return autorest.NewError("pollingTrackerDelete", "updateHeaders", "didn't get any suitable polling URLs in 202 response")
+		}
+	}
+	return nil
+}
-
-	// PollingUnknown indicates an unknown polling method and is the default value.
- PollingUnknown PollingMethodType = "" -) +func (pt pollingTrackerDelete) checkForErrors() error { + return pt.baseCheckForErrors() +} -type pollingState struct { - PollingMethod PollingMethodType `json:"pollingMethod"` - URI string `json:"uri"` - State string `json:"state"` - ServiceError *ServiceError `json:"error,omitempty"` +func (pt pollingTrackerDelete) provisioningStateApplicable() bool { + return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusNoContent } -func (ps pollingState) hasSucceeded() bool { - return hasSucceeded(ps.State) +// PATCH + +type pollingTrackerPatch struct { + pollingTrackerBase } -func (ps pollingState) hasTerminated() bool { - return hasTerminated(ps.State) +func (pt *pollingTrackerPatch) updatePollingMethod() error { + // by default we can use the original URL for polling and final GET + if pt.URI == "" { + pt.URI = pt.resp.Request.URL.String() + } + if pt.FinalGetURI == "" { + pt.FinalGetURI = pt.resp.Request.URL.String() + } + if pt.Pm == PollingUnknown { + pt.Pm = PollingRequestURI + } + // for 201 it's permissible for no headers to be returned + if pt.resp.StatusCode == http.StatusCreated { + if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + } + // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary + // note the absence of the "final GET" mechanism for PATCH + if pt.resp.StatusCode == http.StatusAccepted { + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + if ao == "" { + if lh, err := getURLFromLocationHeader(pt.resp); err != nil { + return err + } else if lh == "" { + return autorest.NewError("pollingTrackerPatch", "updateHeaders", "didn't get any suitable polling URLs in 202 response") + } else { + pt.URI = lh + pt.Pm = PollingLocation + } + } + } + return nil } -func (ps pollingState) hasFailed() bool { - return hasFailed(ps.State) +func (pt pollingTrackerPatch) checkForErrors() error { + return pt.baseCheckForErrors() } -func (ps pollingState) Error() string { - s := fmt.Sprintf("Long running operation terminated with status '%s'", ps.State) - if ps.ServiceError != nil { - s = fmt.Sprintf("%s: %+v", s, *ps.ServiceError) - } - return s +func (pt pollingTrackerPatch) provisioningStateApplicable() bool { + return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusCreated } -// updatePollingState maps the operation status -- retrieved from either a provisioningState -// field, the status field of an OperationResource, or inferred from the HTTP status code -- -// into a well-known states. Since the process begins from the initial request, the state -// always comes from either a the provisioningState returned or is inferred from the HTTP -// status code. Subsequent requests will read an Azure OperationResource object if the -// service initially returned the Azure-AsyncOperation header. The responseFormat field notes -// the expected response format. -func updatePollingState(resp *http.Response, ps *pollingState) error { - // Determine the response shape - // -- The first response will always be a provisioningStatus response; only the polling requests, - // depending on the header returned, may be something otherwise. 
- var pt provisioningTracker - if ps.PollingMethod == PollingAsyncOperation { - pt = &operationResource{} - } else { - pt = &provisioningStatus{} - } +// POST - // If this is the first request (that is, the polling response shape is unknown), determine how - // to poll and what to expect - if ps.PollingMethod == PollingUnknown { - req := resp.Request - if req == nil { - return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Original HTTP request is missing") - } +type pollingTrackerPost struct { + pollingTrackerBase +} - // Prefer the Azure-AsyncOperation header - ps.URI = getAsyncOperation(resp) - if ps.URI != "" { - ps.PollingMethod = PollingAsyncOperation +func (pt *pollingTrackerPost) updatePollingMethod() error { + // 201 requires Location header + if pt.resp.StatusCode == http.StatusCreated { + if lh, err := getURLFromLocationHeader(pt.resp); err != nil { + return err + } else if lh == "" { + return autorest.NewError("pollingTrackerPost", "updateHeaders", "missing Location header in 201 response") } else { - ps.PollingMethod = PollingLocation + pt.URI = lh + pt.FinalGetURI = lh + pt.Pm = PollingLocation } - - // Else, use the Location header - if ps.URI == "" { - ps.URI = autorest.GetLocation(resp) + } + // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary + if pt.resp.StatusCode == http.StatusAccepted { + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation } - - // Lastly, requests against an existing resource, use the last request URI - if ps.URI == "" { - m := strings.ToUpper(req.Method) - if m == http.MethodPatch || m == http.MethodPut || m == http.MethodGet { - ps.URI = req.URL.String() + // if the Location header is invalid and we already have a polling URL + // then we don't care if the Location header URL is malformed. 
+ if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { + return err + } else if lh != "" { + if ao == "" { + pt.URI = lh + pt.Pm = PollingLocation } + // when both headers are returned we use the value in the Location header for the final GET + pt.FinalGetURI = lh + } + // make sure a polling URL was found + if pt.URI == "" { + return autorest.NewError("pollingTrackerPost", "updateHeaders", "didn't get any suitable polling URLs in 202 response") } } + return nil +} - // Read and interpret the response (saving the Body in case no polling is necessary) - b := &bytes.Buffer{} - err := autorest.Respond(resp, - autorest.ByCopying(b), - autorest.ByUnmarshallingJSON(pt), - autorest.ByClosing()) - resp.Body = ioutil.NopCloser(b) - if err != nil { - return err - } +func (pt pollingTrackerPost) checkForErrors() error { + return pt.baseCheckForErrors() +} - // Interpret the results - // -- Terminal states apply regardless - // -- Unknown states are per-service inprogress states - // -- Otherwise, infer state from HTTP status code - if pt.hasTerminated() { - ps.State = pt.state() - } else if pt.state() != "" { - ps.State = operationInProgress - } else { - switch resp.StatusCode { - case http.StatusAccepted: - ps.State = operationInProgress +func (pt pollingTrackerPost) provisioningStateApplicable() bool { + return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusNoContent +} - case http.StatusNoContent, http.StatusCreated, http.StatusOK: - ps.State = operationSucceeded +// PUT - default: - ps.State = operationFailed - } - } +type pollingTrackerPut struct { + pollingTrackerBase +} - if strings.EqualFold(ps.State, operationInProgress) && ps.URI == "" { - return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Unable to obtain polling URI for %s %s", resp.Request.Method, resp.Request.URL) +func (pt *pollingTrackerPut) updatePollingMethod() error { + // by default we can use the original URL for polling and final GET + if pt.URI == "" { + pt.URI = pt.resp.Request.URL.String() } - - // For failed operation, check for error code and message in - // -- Operation resource - // -- Response - // -- Otherwise, Unknown - if ps.hasFailed() { - if or, ok := pt.(*operationResource); ok { - ps.ServiceError = &or.OperationError - } else if p, ok := pt.(*provisioningStatus); ok && p.hasProvisioningError() { - ps.ServiceError = &p.ProvisioningError - } else { - ps.ServiceError = &ServiceError{ - Code: "Unknown", - Message: "None", + if pt.FinalGetURI == "" { + pt.FinalGetURI = pt.resp.Request.URL.String() + } + if pt.Pm == PollingUnknown { + pt.Pm = PollingRequestURI + } + // for 201 it's permissible for no headers to be returned + if pt.resp.StatusCode == http.StatusCreated { + if ao, err := getURLFromAsyncOpHeader(pt.resp); err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + } + // for 202 prefer the Azure-AsyncOperation header but fall back to Location if necessary + if pt.resp.StatusCode == http.StatusAccepted { + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } else if ao != "" { + pt.URI = ao + pt.Pm = PollingAsyncOperation + } + // if the Location header is invalid and we already have a polling URL + // then we don't care if the Location header URL is malformed. 
+ if lh, err := getURLFromLocationHeader(pt.resp); err != nil && pt.URI == "" { + return err + } else if lh != "" { + if ao == "" { + pt.URI = lh + pt.Pm = PollingLocation } } + // make sure a polling URL was found + if pt.URI == "" { + return autorest.NewError("pollingTrackerPut", "updateHeaders", "didn't get any suitable polling URLs in 202 response") + } } return nil } -func newPollingRequest(ps pollingState) (*http.Request, error) { - reqPoll, err := autorest.Prepare(&http.Request{}, - autorest.AsGet(), - autorest.WithBaseURL(ps.URI)) +func (pt pollingTrackerPut) checkForErrors() error { + err := pt.baseCheckForErrors() + if err != nil { + return err + } + // if there are no LRO headers then the body cannot be empty + ao, err := getURLFromAsyncOpHeader(pt.resp) + if err != nil { + return err + } + lh, err := getURLFromLocationHeader(pt.resp) if err != nil { - return nil, autorest.NewErrorWithError(err, "azure", "newPollingRequest", nil, "Failure creating poll request to %s", ps.URI) + return err + } + if ao == "" && lh == "" && len(pt.rawBody) == 0 { + return autorest.NewError("pollingTrackerPut", "checkForErrors", "the response did not contain a body") + } + return nil +} + +func (pt pollingTrackerPut) provisioningStateApplicable() bool { + return pt.resp.StatusCode == http.StatusOK || pt.resp.StatusCode == http.StatusCreated +} + +// creates a polling tracker based on the verb of the original request +func createPollingTracker(resp *http.Response) (pollingTracker, error) { + var pt pollingTracker + switch strings.ToUpper(resp.Request.Method) { + case http.MethodDelete: + pt = &pollingTrackerDelete{pollingTrackerBase: pollingTrackerBase{resp: resp}} + case http.MethodPatch: + pt = &pollingTrackerPatch{pollingTrackerBase: pollingTrackerBase{resp: resp}} + case http.MethodPost: + pt = &pollingTrackerPost{pollingTrackerBase: pollingTrackerBase{resp: resp}} + case http.MethodPut: + pt = &pollingTrackerPut{pollingTrackerBase: pollingTrackerBase{resp: resp}} + default: + return nil, autorest.NewError("azure", "createPollingTracker", "unsupported HTTP method %s", resp.Request.Method) } + if err := pt.initializeState(); err != nil { + return pt, err + } + // this initializes the polling header values, we do this during creation in case the + // initial response send us invalid values; this way the API call will return a non-nil + // error (not doing this means the error shows up in Future.Done) + return pt, pt.updatePollingMethod() +} + +// gets the polling URL from the Azure-AsyncOperation header. +// ensures the URL is well-formed and absolute. +func getURLFromAsyncOpHeader(resp *http.Response) (string, error) { + s := resp.Header.Get(http.CanonicalHeaderKey(headerAsyncOperation)) + if s == "" { + return "", nil + } + if !isValidURL(s) { + return "", autorest.NewError("azure", "getURLFromAsyncOpHeader", "invalid polling URL '%s'", s) + } + return s, nil +} + +// gets the polling URL from the Location header. +// ensures the URL is well-formed and absolute. 
+func getURLFromLocationHeader(resp *http.Response) (string, error) { + s := resp.Header.Get(http.CanonicalHeaderKey(autorest.HeaderLocation)) + if s == "" { + return "", nil + } + if !isValidURL(s) { + return "", autorest.NewError("azure", "getURLFromLocationHeader", "invalid polling URL '%s'", s) + } + return s, nil +} - return reqPoll, nil +// verify that the URL is valid and absolute +func isValidURL(s string) bool { + u, err := url.Parse(s) + return err == nil && u.IsAbs() } +// PollingMethodType defines a type used for enumerating polling mechanisms. +type PollingMethodType string + +const ( + // PollingAsyncOperation indicates the polling method uses the Azure-AsyncOperation header. + PollingAsyncOperation PollingMethodType = "AsyncOperation" + + // PollingLocation indicates the polling method uses the Location header. + PollingLocation PollingMethodType = "Location" + + // PollingRequestURI indicates the polling method uses the original request URI. + PollingRequestURI PollingMethodType = "RequestURI" + + // PollingUnknown indicates an unknown polling method and is the default value. + PollingUnknown PollingMethodType = "" +) + // AsyncOpIncompleteError is the type that's returned from a future that has not completed. type AsyncOpIncompleteError struct { // FutureType is the name of the type composed of a azure.Future. diff --git a/vendor/github.com/Azure/go-autorest/autorest/azure/azure.go b/vendor/github.com/Azure/go-autorest/autorest/azure/azure.go index 18d029526f..b6c6314f07 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/azure/azure.go +++ b/vendor/github.com/Azure/go-autorest/autorest/azure/azure.go @@ -17,6 +17,7 @@ package azure // limitations under the License. import ( + "bytes" "encoding/json" "fmt" "io/ioutil" @@ -36,6 +37,9 @@ const ( // should be included in the response. HeaderReturnClientID = "x-ms-return-client-request-id" + // HeaderContentType is the type of the content in the HTTP response. + HeaderContentType = "Content-Type" + // HeaderRequestID is the Azure extension header of the service generated request ID returned // in the response. HeaderRequestID = "x-ms-request-id" @@ -44,11 +48,12 @@ const ( // ServiceError encapsulates the error response from an Azure service. // It adhears to the OData v4 specification for error responses. type ServiceError struct { - Code string `json:"code"` - Message string `json:"message"` - Target *string `json:"target"` - Details []map[string]interface{} `json:"details"` - InnerError map[string]interface{} `json:"innererror"` + Code string `json:"code"` + Message string `json:"message"` + Target *string `json:"target"` + Details []map[string]interface{} `json:"details"` + InnerError map[string]interface{} `json:"innererror"` + AdditionalInfo []map[string]interface{} `json:"additionalInfo"` } func (se ServiceError) Error() string { @@ -74,56 +79,98 @@ func (se ServiceError) Error() string { result += fmt.Sprintf(" InnerError=%v", string(d)) } + if se.AdditionalInfo != nil { + d, err := json.Marshal(se.AdditionalInfo) + if err != nil { + result += fmt.Sprintf(" AdditionalInfo=%v", se.AdditionalInfo) + } + result += fmt.Sprintf(" AdditionalInfo=%v", string(d)) + } + return result } // UnmarshalJSON implements the json.Unmarshaler interface for the ServiceError type. func (se *ServiceError) UnmarshalJSON(b []byte) error { - // per the OData v4 spec the details field must be an array of JSON objects. 
- // unfortunately not all services adhear to the spec and just return a single - // object instead of an array with one object. so we have to perform some - // shenanigans to accommodate both cases. // http://docs.oasis-open.org/odata/odata-json-format/v4.0/os/odata-json-format-v4.0-os.html#_Toc372793091 - type serviceError1 struct { - Code string `json:"code"` - Message string `json:"message"` - Target *string `json:"target"` - Details []map[string]interface{} `json:"details"` - InnerError map[string]interface{} `json:"innererror"` + type serviceErrorInternal struct { + Code string `json:"code"` + Message string `json:"message"` + Target *string `json:"target,omitempty"` + AdditionalInfo []map[string]interface{} `json:"additionalInfo,omitempty"` + // not all services conform to the OData v4 spec. + // the following fields are where we've seen discrepancies + + // spec calls for []map[string]interface{} but have seen map[string]interface{} + Details interface{} `json:"details,omitempty"` + + // spec calls for map[string]interface{} but have seen []map[string]interface{} and string + InnerError interface{} `json:"innererror,omitempty"` } - type serviceError2 struct { - Code string `json:"code"` - Message string `json:"message"` - Target *string `json:"target"` - Details map[string]interface{} `json:"details"` - InnerError map[string]interface{} `json:"innererror"` + sei := serviceErrorInternal{} + if err := json.Unmarshal(b, &sei); err != nil { + return err } - se1 := serviceError1{} - err := json.Unmarshal(b, &se1) - if err == nil { - se.populate(se1.Code, se1.Message, se1.Target, se1.Details, se1.InnerError) - return nil + // copy the fields we know to be correct + se.AdditionalInfo = sei.AdditionalInfo + se.Code = sei.Code + se.Message = sei.Message + se.Target = sei.Target + + // converts an []interface{} to []map[string]interface{} + arrayOfObjs := func(v interface{}) ([]map[string]interface{}, bool) { + arrayOf, ok := v.([]interface{}) + if !ok { + return nil, false + } + final := []map[string]interface{}{} + for _, item := range arrayOf { + as, ok := item.(map[string]interface{}) + if !ok { + return nil, false + } + final = append(final, as) + } + return final, true } - se2 := serviceError2{} - err = json.Unmarshal(b, &se2) - if err == nil { - se.populate(se2.Code, se2.Message, se2.Target, nil, se2.InnerError) - se.Details = append(se.Details, se2.Details) - return nil + // convert the remaining fields, falling back to raw JSON if necessary + + if c, ok := arrayOfObjs(sei.Details); ok { + se.Details = c + } else if c, ok := sei.Details.(map[string]interface{}); ok { + se.Details = []map[string]interface{}{c} + } else if sei.Details != nil { + // stuff into Details + se.Details = []map[string]interface{}{ + {"raw": sei.Details}, + } } - return err -} -func (se *ServiceError) populate(code, message string, target *string, details []map[string]interface{}, inner map[string]interface{}) { - se.Code = code - se.Message = message - se.Target = target - se.Details = details - se.InnerError = inner + if c, ok := sei.InnerError.(map[string]interface{}); ok { + se.InnerError = c + } else if c, ok := arrayOfObjs(sei.InnerError); ok { + // if there's only one error extract it + if len(c) == 1 { + se.InnerError = c[0] + } else { + // multiple errors, stuff them into the value + se.InnerError = map[string]interface{}{ + "multi": c, + } + } + } else if c, ok := sei.InnerError.(string); ok { + se.InnerError = map[string]interface{}{"error": c} + } else if sei.InnerError != nil { + // stuff into 
InnerError + se.InnerError = map[string]interface{}{ + "raw": sei.InnerError, + } + } + return nil } // RequestError describes an error response returned by Azure service. @@ -131,7 +178,7 @@ type RequestError struct { autorest.DetailedError // The error returned by the Azure service. - ServiceError *ServiceError `json:"error"` + ServiceError *ServiceError `json:"error" xml:"Error"` // The request id (from the x-ms-request-id-header) of the request. RequestID string @@ -158,8 +205,13 @@ type Resource struct { ResourceName string } +// String function returns a string in form of azureResourceID +func (r Resource) String() string { + return fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/%s/%s/%s", r.SubscriptionID, r.ResourceGroup, r.Provider, r.ResourceType, r.ResourceName) +} + // ParseResourceID parses a resource ID into a ResourceDetails struct. -// See https://docs.microsoft.com/en-us/azure/azure-resource-manager/resource-group-template-functions-resource#return-value-4. +// See https://docs.microsoft.com/en-us/azure/azure-resource-manager/templates/template-functions-resource?tabs=json#resourceid. func ParseResourceID(resourceID string) (Resource, error) { const resourceIDPatternText = `(?i)subscriptions/(.+)/resourceGroups/(.+)/providers/(.+?)/(.+?)/(.+)` @@ -273,22 +325,57 @@ func WithErrorUnlessStatusCode(codes ...int) autorest.RespondDecorator { var e RequestError defer resp.Body.Close() + encodedAs := autorest.EncodedAsJSON + if strings.Contains(resp.Header.Get("Content-Type"), "xml") { + encodedAs = autorest.EncodedAsXML + } + // Copy and replace the Body in case it does not contain an error object. // This will leave the Body available to the caller. - b, decodeErr := autorest.CopyAndDecode(autorest.EncodedAsJSON, resp.Body, &e) + b, decodeErr := autorest.CopyAndDecode(encodedAs, resp.Body, &e) resp.Body = ioutil.NopCloser(&b) if decodeErr != nil { return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), decodeErr) - } else if e.ServiceError == nil { + } + if e.ServiceError == nil { // Check if error is unwrapped ServiceError - if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil || e.ServiceError.Message == "" { + decoder := autorest.NewDecoder(encodedAs, bytes.NewReader(b.Bytes())) + if err := decoder.Decode(&e.ServiceError); err != nil { + return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), err) + } + + // for example, should the API return the literal value `null` as the response + if e.ServiceError == nil { e.ServiceError = &ServiceError{ Code: "Unknown", Message: "Unknown service error", + Details: []map[string]interface{}{ + { + "HttpResponse.Body": b.String(), + }, + }, } } } + if e.ServiceError != nil && e.ServiceError.Message == "" { + // if we're here it means the returned error wasn't OData v4 compliant. + // try to unmarshal the body in hopes of getting something. 
+ rawBody := map[string]interface{}{} + decoder := autorest.NewDecoder(encodedAs, bytes.NewReader(b.Bytes())) + if err := decoder.Decode(&rawBody); err != nil { + return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), err) + } + + e.ServiceError = &ServiceError{ + Code: "Unknown", + Message: "Unknown service error", + } + if len(rawBody) > 0 { + e.ServiceError.Details = []map[string]interface{}{rawBody} + } + } + e.Response = resp e.RequestID = ExtractRequestID(resp) if e.StatusCode == nil { e.StatusCode = resp.StatusCode diff --git a/vendor/github.com/Azure/go-autorest/autorest/azure/environments.go b/vendor/github.com/Azure/go-autorest/autorest/azure/environments.go index 7e41f7fd99..3b61a2b6e9 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/azure/environments.go +++ b/vendor/github.com/Azure/go-autorest/autorest/azure/environments.go @@ -22,9 +22,14 @@ import ( "strings" ) -// EnvironmentFilepathName captures the name of the environment variable containing the path to the file -// to be used while populating the Azure Environment. -const EnvironmentFilepathName = "AZURE_ENVIRONMENT_FILEPATH" +const ( + // EnvironmentFilepathName captures the name of the environment variable containing the path to the file + // to be used while populating the Azure Environment. + EnvironmentFilepathName = "AZURE_ENVIRONMENT_FILEPATH" + + // NotAvailable is used for endpoints and resource IDs that are not available for a given cloud. + NotAvailable = "N/A" +) var environments = map[string]Environment{ "AZURECHINACLOUD": ChinaCloud, @@ -33,28 +38,48 @@ var environments = map[string]Environment{ "AZUREUSGOVERNMENTCLOUD": USGovernmentCloud, } +// ResourceIdentifier contains a set of Azure resource IDs. +type ResourceIdentifier struct { + Graph string `json:"graph"` + KeyVault string `json:"keyVault"` + Datalake string `json:"datalake"` + Batch string `json:"batch"` + OperationalInsights string `json:"operationalInsights"` + OSSRDBMS string `json:"ossRDBMS"` + Storage string `json:"storage"` + Synapse string `json:"synapse"` + ServiceBus string `json:"serviceBus"` +} + // Environment represents a set of endpoints for each of Azure's Clouds. 
type Environment struct { - Name string `json:"name"` - ManagementPortalURL string `json:"managementPortalURL"` - PublishSettingsURL string `json:"publishSettingsURL"` - ServiceManagementEndpoint string `json:"serviceManagementEndpoint"` - ResourceManagerEndpoint string `json:"resourceManagerEndpoint"` - ActiveDirectoryEndpoint string `json:"activeDirectoryEndpoint"` - GalleryEndpoint string `json:"galleryEndpoint"` - KeyVaultEndpoint string `json:"keyVaultEndpoint"` - GraphEndpoint string `json:"graphEndpoint"` - ServiceBusEndpoint string `json:"serviceBusEndpoint"` - BatchManagementEndpoint string `json:"batchManagementEndpoint"` - StorageEndpointSuffix string `json:"storageEndpointSuffix"` - SQLDatabaseDNSSuffix string `json:"sqlDatabaseDNSSuffix"` - TrafficManagerDNSSuffix string `json:"trafficManagerDNSSuffix"` - KeyVaultDNSSuffix string `json:"keyVaultDNSSuffix"` - ServiceBusEndpointSuffix string `json:"serviceBusEndpointSuffix"` - ServiceManagementVMDNSSuffix string `json:"serviceManagementVMDNSSuffix"` - ResourceManagerVMDNSSuffix string `json:"resourceManagerVMDNSSuffix"` - ContainerRegistryDNSSuffix string `json:"containerRegistryDNSSuffix"` - TokenAudience string `json:"tokenAudience"` + Name string `json:"name"` + ManagementPortalURL string `json:"managementPortalURL"` + PublishSettingsURL string `json:"publishSettingsURL"` + ServiceManagementEndpoint string `json:"serviceManagementEndpoint"` + ResourceManagerEndpoint string `json:"resourceManagerEndpoint"` + ActiveDirectoryEndpoint string `json:"activeDirectoryEndpoint"` + GalleryEndpoint string `json:"galleryEndpoint"` + KeyVaultEndpoint string `json:"keyVaultEndpoint"` + GraphEndpoint string `json:"graphEndpoint"` + ServiceBusEndpoint string `json:"serviceBusEndpoint"` + BatchManagementEndpoint string `json:"batchManagementEndpoint"` + StorageEndpointSuffix string `json:"storageEndpointSuffix"` + CosmosDBDNSSuffix string `json:"cosmosDBDNSSuffix"` + MariaDBDNSSuffix string `json:"mariaDBDNSSuffix"` + MySQLDatabaseDNSSuffix string `json:"mySqlDatabaseDNSSuffix"` + PostgresqlDatabaseDNSSuffix string `json:"postgresqlDatabaseDNSSuffix"` + SQLDatabaseDNSSuffix string `json:"sqlDatabaseDNSSuffix"` + TrafficManagerDNSSuffix string `json:"trafficManagerDNSSuffix"` + KeyVaultDNSSuffix string `json:"keyVaultDNSSuffix"` + ServiceBusEndpointSuffix string `json:"serviceBusEndpointSuffix"` + ServiceManagementVMDNSSuffix string `json:"serviceManagementVMDNSSuffix"` + ResourceManagerVMDNSSuffix string `json:"resourceManagerVMDNSSuffix"` + ContainerRegistryDNSSuffix string `json:"containerRegistryDNSSuffix"` + TokenAudience string `json:"tokenAudience"` + APIManagementHostNameSuffix string `json:"apiManagementHostNameSuffix"` + SynapseEndpointSuffix string `json:"synapseEndpointSuffix"` + ResourceIdentifiers ResourceIdentifier `json:"resourceIdentifiers"` } var ( @@ -72,6 +97,10 @@ var ( ServiceBusEndpoint: "https://servicebus.windows.net/", BatchManagementEndpoint: "https://batch.core.windows.net/", StorageEndpointSuffix: "core.windows.net", + CosmosDBDNSSuffix: "documents.azure.com", + MariaDBDNSSuffix: "mariadb.database.azure.com", + MySQLDatabaseDNSSuffix: "mysql.database.azure.com", + PostgresqlDatabaseDNSSuffix: "postgres.database.azure.com", SQLDatabaseDNSSuffix: "database.windows.net", TrafficManagerDNSSuffix: "trafficmanager.net", KeyVaultDNSSuffix: "vault.azure.net", @@ -80,6 +109,19 @@ var ( ResourceManagerVMDNSSuffix: "cloudapp.azure.com", ContainerRegistryDNSSuffix: "azurecr.io", TokenAudience: "https://management.azure.com/", 
+ APIManagementHostNameSuffix: "azure-api.net", + SynapseEndpointSuffix: "dev.azuresynapse.net", + ResourceIdentifiers: ResourceIdentifier{ + Graph: "https://graph.windows.net/", + KeyVault: "https://vault.azure.net", + Datalake: "https://datalake.azure.net/", + Batch: "https://batch.core.windows.net/", + OperationalInsights: "https://api.loganalytics.io", + OSSRDBMS: "https://ossrdbms-aad.database.windows.net", + Storage: "https://storage.azure.com/", + Synapse: "https://dev.azuresynapse.net", + ServiceBus: "https://servicebus.azure.net/", + }, } // USGovernmentCloud is the cloud environment for the US Government @@ -96,14 +138,31 @@ var ( ServiceBusEndpoint: "https://servicebus.usgovcloudapi.net/", BatchManagementEndpoint: "https://batch.core.usgovcloudapi.net/", StorageEndpointSuffix: "core.usgovcloudapi.net", + CosmosDBDNSSuffix: "documents.azure.us", + MariaDBDNSSuffix: "mariadb.database.usgovcloudapi.net", + MySQLDatabaseDNSSuffix: "mysql.database.usgovcloudapi.net", + PostgresqlDatabaseDNSSuffix: "postgres.database.usgovcloudapi.net", SQLDatabaseDNSSuffix: "database.usgovcloudapi.net", TrafficManagerDNSSuffix: "usgovtrafficmanager.net", KeyVaultDNSSuffix: "vault.usgovcloudapi.net", ServiceBusEndpointSuffix: "servicebus.usgovcloudapi.net", ServiceManagementVMDNSSuffix: "usgovcloudapp.net", - ResourceManagerVMDNSSuffix: "cloudapp.windowsazure.us", - ContainerRegistryDNSSuffix: "azurecr.io", + ResourceManagerVMDNSSuffix: "cloudapp.usgovcloudapi.net", + ContainerRegistryDNSSuffix: "azurecr.us", TokenAudience: "https://management.usgovcloudapi.net/", + APIManagementHostNameSuffix: "azure-api.us", + SynapseEndpointSuffix: NotAvailable, + ResourceIdentifiers: ResourceIdentifier{ + Graph: "https://graph.windows.net/", + KeyVault: "https://vault.usgovcloudapi.net", + Datalake: NotAvailable, + Batch: "https://batch.core.usgovcloudapi.net/", + OperationalInsights: "https://api.loganalytics.us", + OSSRDBMS: "https://ossrdbms-aad.database.usgovcloudapi.net", + Storage: "https://storage.azure.com/", + Synapse: NotAvailable, + ServiceBus: "https://servicebus.azure.net/", + }, } // ChinaCloud is the cloud environment operated in China @@ -120,14 +179,31 @@ var ( ServiceBusEndpoint: "https://servicebus.chinacloudapi.cn/", BatchManagementEndpoint: "https://batch.chinacloudapi.cn/", StorageEndpointSuffix: "core.chinacloudapi.cn", + CosmosDBDNSSuffix: "documents.azure.cn", + MariaDBDNSSuffix: "mariadb.database.chinacloudapi.cn", + MySQLDatabaseDNSSuffix: "mysql.database.chinacloudapi.cn", + PostgresqlDatabaseDNSSuffix: "postgres.database.chinacloudapi.cn", SQLDatabaseDNSSuffix: "database.chinacloudapi.cn", TrafficManagerDNSSuffix: "trafficmanager.cn", KeyVaultDNSSuffix: "vault.azure.cn", ServiceBusEndpointSuffix: "servicebus.chinacloudapi.cn", ServiceManagementVMDNSSuffix: "chinacloudapp.cn", - ResourceManagerVMDNSSuffix: "cloudapp.azure.cn", - ContainerRegistryDNSSuffix: "azurecr.io", + ResourceManagerVMDNSSuffix: "cloudapp.chinacloudapi.cn", + ContainerRegistryDNSSuffix: "azurecr.cn", TokenAudience: "https://management.chinacloudapi.cn/", + APIManagementHostNameSuffix: "azure-api.cn", + SynapseEndpointSuffix: "dev.azuresynapse.azure.cn", + ResourceIdentifiers: ResourceIdentifier{ + Graph: "https://graph.chinacloudapi.cn/", + KeyVault: "https://vault.azure.cn", + Datalake: NotAvailable, + Batch: "https://batch.chinacloudapi.cn/", + OperationalInsights: NotAvailable, + OSSRDBMS: "https://ossrdbms-aad.database.chinacloudapi.cn", + Storage: "https://storage.azure.com/", + Synapse: 
"https://dev.azuresynapse.net", + ServiceBus: "https://servicebus.azure.net/", + }, } // GermanCloud is the cloud environment operated in Germany @@ -144,14 +220,31 @@ var ( ServiceBusEndpoint: "https://servicebus.cloudapi.de/", BatchManagementEndpoint: "https://batch.cloudapi.de/", StorageEndpointSuffix: "core.cloudapi.de", + CosmosDBDNSSuffix: "documents.microsoftazure.de", + MariaDBDNSSuffix: "mariadb.database.cloudapi.de", + MySQLDatabaseDNSSuffix: "mysql.database.cloudapi.de", + PostgresqlDatabaseDNSSuffix: "postgres.database.cloudapi.de", SQLDatabaseDNSSuffix: "database.cloudapi.de", TrafficManagerDNSSuffix: "azuretrafficmanager.de", KeyVaultDNSSuffix: "vault.microsoftazure.de", ServiceBusEndpointSuffix: "servicebus.cloudapi.de", ServiceManagementVMDNSSuffix: "azurecloudapp.de", ResourceManagerVMDNSSuffix: "cloudapp.microsoftazure.de", - ContainerRegistryDNSSuffix: "azurecr.io", + ContainerRegistryDNSSuffix: NotAvailable, TokenAudience: "https://management.microsoftazure.de/", + APIManagementHostNameSuffix: NotAvailable, + SynapseEndpointSuffix: NotAvailable, + ResourceIdentifiers: ResourceIdentifier{ + Graph: "https://graph.cloudapi.de/", + KeyVault: "https://vault.microsoftazure.de", + Datalake: NotAvailable, + Batch: "https://batch.cloudapi.de/", + OperationalInsights: NotAvailable, + OSSRDBMS: "https://ossrdbms-aad.database.cloudapi.de", + Storage: "https://storage.azure.com/", + Synapse: NotAvailable, + ServiceBus: "https://servicebus.azure.net/", + }, } ) @@ -189,3 +282,8 @@ func EnvironmentFromFile(location string) (unmarshaled Environment, err error) { return } + +// SetEnvironment updates the environment map with the specified values. +func SetEnvironment(name string, env Environment) { + environments[strings.ToUpper(name)] = env +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/azure/rp.go b/vendor/github.com/Azure/go-autorest/autorest/azure/rp.go index 65ad0afc82..c6d39f6866 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/azure/rp.go +++ b/vendor/github.com/Azure/go-autorest/autorest/azure/rp.go @@ -47,11 +47,15 @@ func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator { if resp.StatusCode != http.StatusConflict || client.SkipResourceProviderRegistration { return resp, err } + var re RequestError - err = autorest.Respond( - resp, - autorest.ByUnmarshallingJSON(&re), - ) + if strings.Contains(r.Header.Get("Content-Type"), "xml") { + // XML errors (e.g. 
Storage Data Plane) only return the inner object + err = autorest.Respond(resp, autorest.ByUnmarshallingXML(&re.ServiceError)) + } else { + err = autorest.Respond(resp, autorest.ByUnmarshallingJSON(&re)) + } + if err != nil { return resp, err } @@ -64,7 +68,7 @@ func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator { } } } - return resp, fmt.Errorf("failed request: %s", err) + return resp, err }) } } @@ -140,8 +144,8 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError } // poll for registered provisioning state - now := time.Now() - for err == nil && time.Since(now) < client.PollingDuration { + registrationStartTime := time.Now() + for err == nil && (client.PollingDuration == 0 || (client.PollingDuration != 0 && time.Since(registrationStartTime) < client.PollingDuration)) { // taken from the resources SDK // https://github.com/Azure/azure-sdk-for-go/blob/9f366792afa3e0ddaecdc860e793ba9d75e76c27/arm/resources/resources/providers.go#L45 preparer := autorest.CreatePreparer( @@ -183,7 +187,7 @@ func register(client autorest.Client, originalReq *http.Request, re RequestError return originalReq.Context().Err() } } - if !(time.Since(now) < client.PollingDuration) { + if client.PollingDuration != 0 && !(time.Since(registrationStartTime) < client.PollingDuration) { return errors.New("polling for resource provider registration has exceeded the polling duration") } return err diff --git a/vendor/github.com/Azure/go-autorest/autorest/client.go b/vendor/github.com/Azure/go-autorest/autorest/client.go index 4e92dcad07..bb5f9396e9 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/client.go +++ b/vendor/github.com/Azure/go-autorest/autorest/client.go @@ -16,19 +16,22 @@ package autorest import ( "bytes" + "crypto/tls" + "errors" "fmt" "io" "io/ioutil" "log" "net/http" - "net/http/cookiejar" - "runtime" + "strings" "time" + + "github.com/Azure/go-autorest/logger" ) const ( // DefaultPollingDelay is a reasonable delay between polling requests. - DefaultPollingDelay = 60 * time.Second + DefaultPollingDelay = 30 * time.Second // DefaultPollingDuration is a reasonable total polling duration. DefaultPollingDuration = 15 * time.Minute @@ -41,15 +44,6 @@ const ( ) var ( - // defaultUserAgent builds a string containing the Go version, system archityecture and OS, - // and the go-autorest version. - defaultUserAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s", - runtime.Version(), - runtime.GOARCH, - runtime.GOOS, - Version(), - ) - // StatusCodesForRetry are a defined group of status code for which the client will retry StatusCodesForRetry = []int{ http.StatusRequestTimeout, // 408 @@ -78,6 +72,22 @@ type Response struct { *http.Response `json:"-"` } +// IsHTTPStatus returns true if the returned HTTP status code matches the provided status code. +// If there was no response (i.e. the underlying http.Response is nil) the return value is false. +func (r Response) IsHTTPStatus(statusCode int) bool { + if r.Response == nil { + return false + } + return r.Response.StatusCode == statusCode +} + +// HasHTTPStatus returns true if the returned HTTP status code matches one of the provided status codes. +// If there was no response (i.e. the underlying http.Response is nil) or not status codes are provided +// the return value is false. +func (r Response) HasHTTPStatus(statusCodes ...int) bool { + return ResponseHasStatusCode(r.Response, statusCodes...) 
+} + // LoggingInspector implements request and response inspectors that log the full request and // response to a supplied log. type LoggingInspector struct { @@ -153,9 +163,11 @@ type Client struct { PollingDelay time.Duration // PollingDuration sets the maximum polling time after which an error is returned. + // Setting this to zero will use the provided context to control the duration. PollingDuration time.Duration - // RetryAttempts sets the default number of retry attempts for client. + // RetryAttempts sets the total number of times the client will attempt to make an HTTP request. + // Set the value to 1 to disable retries. DO NOT set the value to less than 1. RetryAttempts int // RetryDuration sets the delay duration for retries. @@ -169,19 +181,42 @@ type Client struct { // Set to true to skip attempted registration of resource providers (false by default). SkipResourceProviderRegistration bool + + // SendDecorators can be used to override the default chain of SendDecorators. + // This can be used to specify things like a custom retry SendDecorator. + // Set this to an empty slice to use no SendDecorators. + SendDecorators []SendDecorator } // NewClientWithUserAgent returns an instance of a Client with the UserAgent set to the passed // string. func NewClientWithUserAgent(ua string) Client { + return newClient(ua, tls.RenegotiateNever) +} + +// ClientOptions contains various Client configuration options. +type ClientOptions struct { + // UserAgent is an optional user-agent string to append to the default user agent. + UserAgent string + + // Renegotiation is an optional setting to control client-side TLS renegotiation. + Renegotiation tls.RenegotiationSupport +} + +// NewClientWithOptions returns an instance of a Client with the specified values. +func NewClientWithOptions(options ClientOptions) Client { + return newClient(options.UserAgent, options.Renegotiation) +} + +func newClient(ua string, renegotiation tls.RenegotiationSupport) Client { c := Client{ PollingDelay: DefaultPollingDelay, PollingDuration: DefaultPollingDuration, RetryAttempts: DefaultRetryAttempts, RetryDuration: DefaultRetryDuration, - UserAgent: defaultUserAgent, + UserAgent: UserAgent(), } - c.Sender = c.sender() + c.Sender = c.sender(renegotiation) c.AddToUserAgent(ua) return c } @@ -216,17 +251,28 @@ func (c Client) Do(r *http.Request) (*http.Response, error) { } return resp, NewErrorWithError(err, "autorest/Client", "Do", nil, "Preparing request failed") } - - resp, err := SendWithSender(c.sender(), r) + logger.Instance.WriteRequest(r, logger.Filter{ + Header: func(k string, v []string) (bool, []string) { + // remove the auth token from the log + if strings.EqualFold(k, "Authorization") || strings.EqualFold(k, "Ocp-Apim-Subscription-Key") { + v = []string{"**REDACTED**"} + } + return true, v + }, + }) + resp, err := SendWithSender(c.sender(tls.RenegotiateNever), r) + if resp == nil && err == nil { + err = errors.New("autorest: received nil response and error") + } + logger.Instance.WriteResponse(resp, logger.Filter{}) Respond(resp, c.ByInspecting()) return resp, err } // sender returns the Sender to which to send requests. 
-func (c Client) sender() Sender { +func (c Client) sender(renengotiation tls.RenegotiationSupport) Sender { if c.Sender == nil { - j, _ := cookiejar.New(nil) - return &http.Client{Jar: j} + return sender(renengotiation) } return c.Sender } @@ -262,3 +308,21 @@ func (c Client) ByInspecting() RespondDecorator { } return c.ResponseInspector } + +// Send sends the provided http.Request using the client's Sender or the default sender. +// It returns the http.Response and possible error. It also accepts a, possibly empty, +// default set of SendDecorators used when sending the request. +// SendDecorators have the following precedence: +// 1. In a request's context via WithSendDecorators() +// 2. Specified on the client in SendDecorators +// 3. The default values specified in this method +func (c Client) Send(req *http.Request, decorators ...SendDecorator) (*http.Response, error) { + if c.SendDecorators != nil { + decorators = c.SendDecorators + } + inCtx := req.Context().Value(ctxSendDecorators{}) + if sd, ok := inCtx.([]SendDecorator); ok { + decorators = sd + } + return SendWithSender(c, req, decorators...) +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/date/go.mod b/vendor/github.com/Azure/go-autorest/autorest/date/go.mod new file mode 100644 index 0000000000..f88ecc4022 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/date/go.mod @@ -0,0 +1,5 @@ +module github.com/Azure/go-autorest/autorest/date + +go 1.12 + +require github.com/Azure/go-autorest v14.2.0+incompatible diff --git a/vendor/github.com/Azure/go-autorest/autorest/date/go_mod_tidy_hack.go b/vendor/github.com/Azure/go-autorest/autorest/date/go_mod_tidy_hack.go new file mode 100644 index 0000000000..4e05432071 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/date/go_mod_tidy_hack.go @@ -0,0 +1,24 @@ +// +build modhack + +package date + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file, and the github.com/Azure/go-autorest import, won't actually become part of +// the resultant binary. + +// Necessary for safely adding multi-module repo. +// See: https://github.com/golang/go/wiki/Modules#is-it-possible-to-add-a-module-to-a-multi-module-repository +import _ "github.com/Azure/go-autorest" diff --git a/vendor/github.com/Azure/go-autorest/autorest/error.go b/vendor/github.com/Azure/go-autorest/autorest/error.go index f724f33327..35098eda8e 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/error.go +++ b/vendor/github.com/Azure/go-autorest/autorest/error.go @@ -96,3 +96,8 @@ func (e DetailedError) Error() string { } return fmt.Sprintf("%s#%s: %s: StatusCode=%d -- Original Error: %v", e.PackageType, e.Method, e.Message, e.StatusCode, e.Original) } + +// Unwrap returns the original error. 
+func (e DetailedError) Unwrap() error { + return e.Original +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/go.mod b/vendor/github.com/Azure/go-autorest/autorest/go.mod new file mode 100644 index 0000000000..fd0b2c0c32 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/go.mod @@ -0,0 +1,12 @@ +module github.com/Azure/go-autorest/autorest + +go 1.12 + +require ( + github.com/Azure/go-autorest v14.2.0+incompatible + github.com/Azure/go-autorest/autorest/adal v0.9.13 + github.com/Azure/go-autorest/autorest/mocks v0.4.1 + github.com/Azure/go-autorest/logger v0.2.1 + github.com/Azure/go-autorest/tracing v0.6.0 + golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0 +) diff --git a/vendor/github.com/Azure/go-autorest/autorest/go_mod_tidy_hack.go b/vendor/github.com/Azure/go-autorest/autorest/go_mod_tidy_hack.go new file mode 100644 index 0000000000..da65e1041e --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/go_mod_tidy_hack.go @@ -0,0 +1,24 @@ +// +build modhack + +package autorest + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file, and the github.com/Azure/go-autorest import, won't actually become part of +// the resultant binary. + +// Necessary for safely adding multi-module repo. +// See: https://github.com/golang/go/wiki/Modules#is-it-possible-to-add-a-module-to-a-multi-module-repository +import _ "github.com/Azure/go-autorest" diff --git a/vendor/github.com/Azure/go-autorest/autorest/preparer.go b/vendor/github.com/Azure/go-autorest/autorest/preparer.go index 6d67bd7337..98574a4155 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/preparer.go +++ b/vendor/github.com/Azure/go-autorest/autorest/preparer.go @@ -16,7 +16,9 @@ package autorest import ( "bytes" + "context" "encoding/json" + "encoding/xml" "fmt" "io" "io/ioutil" @@ -31,11 +33,33 @@ const ( mimeTypeOctetStream = "application/octet-stream" mimeTypeFormPost = "application/x-www-form-urlencoded" - headerAuthorization = "Authorization" - headerContentType = "Content-Type" - headerUserAgent = "User-Agent" + headerAuthorization = "Authorization" + headerAuxAuthorization = "x-ms-authorization-auxiliary" + headerContentType = "Content-Type" + headerUserAgent = "User-Agent" ) +// used as a key type in context.WithValue() +type ctxPrepareDecorators struct{} + +// WithPrepareDecorators adds the specified PrepareDecorators to the provided context. +// If no PrepareDecorators are provided the context is unchanged. +func WithPrepareDecorators(ctx context.Context, prepareDecorator []PrepareDecorator) context.Context { + if len(prepareDecorator) == 0 { + return ctx + } + return context.WithValue(ctx, ctxPrepareDecorators{}, prepareDecorator) +} + +// GetPrepareDecorators returns the PrepareDecorators in the provided context or the provided default PrepareDecorators. 
+func GetPrepareDecorators(ctx context.Context, defaultPrepareDecorators ...PrepareDecorator) []PrepareDecorator { + inCtx := ctx.Value(ctxPrepareDecorators{}) + if pd, ok := inCtx.([]PrepareDecorator); ok { + return pd + } + return defaultPrepareDecorators +} + // Preparer is the interface that wraps the Prepare method. // // Prepare accepts and possibly modifies an http.Request (e.g., adding Headers). Implementations @@ -103,10 +127,7 @@ func WithHeader(header string, value string) PrepareDecorator { return PreparerFunc(func(r *http.Request) (*http.Request, error) { r, err := p.Prepare(r) if err == nil { - if r.Header == nil { - r.Header = make(http.Header) - } - r.Header.Set(http.CanonicalHeaderKey(header), value) + setHeader(r, http.CanonicalHeaderKey(header), value) } return r, err }) @@ -190,6 +211,9 @@ func AsGet() PrepareDecorator { return WithMethod("GET") } // AsHead returns a PrepareDecorator that sets the HTTP method to HEAD. func AsHead() PrepareDecorator { return WithMethod("HEAD") } +// AsMerge returns a PrepareDecorator that sets the HTTP method to MERGE. +func AsMerge() PrepareDecorator { return WithMethod("MERGE") } + // AsOptions returns a PrepareDecorator that sets the HTTP method to OPTIONS. func AsOptions() PrepareDecorator { return WithMethod("OPTIONS") } @@ -203,7 +227,7 @@ func AsPost() PrepareDecorator { return WithMethod("POST") } func AsPut() PrepareDecorator { return WithMethod("PUT") } // WithBaseURL returns a PrepareDecorator that populates the http.Request with a url.URL constructed -// from the supplied baseUrl. +// from the supplied baseUrl. Query parameters will be encoded as required. func WithBaseURL(baseURL string) PrepareDecorator { return func(p Preparer) Preparer { return PreparerFunc(func(r *http.Request) (*http.Request, error) { @@ -214,11 +238,35 @@ func WithBaseURL(baseURL string) PrepareDecorator { return r, err } if u.Scheme == "" { - err = fmt.Errorf("autorest: No scheme detected in URL %s", baseURL) + return r, fmt.Errorf("autorest: No scheme detected in URL %s", baseURL) } - if err == nil { - r.URL = u + if u.RawQuery != "" { + q, err := url.ParseQuery(u.RawQuery) + if err != nil { + return r, err + } + u.RawQuery = q.Encode() + } + r.URL = u + } + return r, err + }) + } +} + +// WithBytes returns a PrepareDecorator that takes a list of bytes +// which passes the bytes directly to the body +func WithBytes(input *[]byte) PrepareDecorator { + return func(p Preparer) Preparer { + return PreparerFunc(func(r *http.Request) (*http.Request, error) { + r, err := p.Prepare(r) + if err == nil { + if input == nil { + return r, fmt.Errorf("Input Bytes was nil") } + + r.ContentLength = int64(len(*input)) + r.Body = ioutil.NopCloser(bytes.NewReader(*input)) } return r, err }) @@ -244,10 +292,7 @@ func WithFormData(v url.Values) PrepareDecorator { if err == nil { s := v.Encode() - if r.Header == nil { - r.Header = make(http.Header) - } - r.Header.Set(http.CanonicalHeaderKey(headerContentType), mimeTypeFormPost) + setHeader(r, http.CanonicalHeaderKey(headerContentType), mimeTypeFormPost) r.ContentLength = int64(len(s)) r.Body = ioutil.NopCloser(strings.NewReader(s)) } @@ -283,10 +328,7 @@ func WithMultiPartFormData(formDataParameters map[string]interface{}) PrepareDec if err = writer.Close(); err != nil { return r, err } - if r.Header == nil { - r.Header = make(http.Header) - } - r.Header.Set(http.CanonicalHeaderKey(headerContentType), writer.FormDataContentType()) + setHeader(r, http.CanonicalHeaderKey(headerContentType), writer.FormDataContentType()) 
r.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes())) r.ContentLength = int64(body.Len()) return r, err @@ -377,6 +419,29 @@ func WithJSON(v interface{}) PrepareDecorator { } } +// WithXML returns a PrepareDecorator that encodes the data passed as XML into the body of the +// request and sets the Content-Length header. +func WithXML(v interface{}) PrepareDecorator { + return func(p Preparer) Preparer { + return PreparerFunc(func(r *http.Request) (*http.Request, error) { + r, err := p.Prepare(r) + if err == nil { + b, err := xml.Marshal(v) + if err == nil { + // we have to tack on an XML header + withHeader := xml.Header + string(b) + bytesWithHeader := []byte(withHeader) + + r.ContentLength = int64(len(bytesWithHeader)) + setHeader(r, headerContentLength, fmt.Sprintf("%d", len(bytesWithHeader))) + r.Body = ioutil.NopCloser(bytes.NewReader(bytesWithHeader)) + } + } + return r, err + }) + } +} + // WithPath returns a PrepareDecorator that adds the supplied path to the request URL. If the path // is absolute (that is, it begins with a "/"), it replaces the existing path. func WithPath(path string) PrepareDecorator { @@ -455,7 +520,7 @@ func parseURL(u *url.URL, path string) (*url.URL, error) { // WithQueryParameters returns a PrepareDecorators that encodes and applies the query parameters // given in the supplied map (i.e., key=value). func WithQueryParameters(queryParameters map[string]interface{}) PrepareDecorator { - parameters := ensureValueStrings(queryParameters) + parameters := MapToValues(queryParameters) return func(p Preparer) Preparer { return PreparerFunc(func(r *http.Request) (*http.Request, error) { r, err := p.Prepare(r) @@ -463,14 +528,16 @@ func WithQueryParameters(queryParameters map[string]interface{}) PrepareDecorato if r.URL == nil { return r, NewError("autorest", "WithQueryParameters", "Invoked with a nil URL") } - v := r.URL.Query() for key, value := range parameters { - d, err := url.QueryUnescape(value) - if err != nil { - return r, err + for i := range value { + d, err := url.QueryUnescape(value[i]) + if err != nil { + return r, err + } + value[i] = d } - v.Add(key, d) + v[key] = value } r.URL.RawQuery = v.Encode() } diff --git a/vendor/github.com/Azure/go-autorest/autorest/responder.go b/vendor/github.com/Azure/go-autorest/autorest/responder.go index a908a0adb7..349e1963a2 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/responder.go +++ b/vendor/github.com/Azure/go-autorest/autorest/responder.go @@ -153,6 +153,25 @@ func ByClosingIfError() RespondDecorator { } } +// ByUnmarshallingBytes returns a RespondDecorator that copies the Bytes returned in the +// response Body into the value pointed to by v. +func ByUnmarshallingBytes(v *[]byte) RespondDecorator { + return func(r Responder) Responder { + return ResponderFunc(func(resp *http.Response) error { + err := r.Respond(resp) + if err == nil { + bytes, errInner := ioutil.ReadAll(resp.Body) + if errInner != nil { + err = fmt.Errorf("Error occurred reading http.Response#Body - Error = '%v'", errInner) + } else { + *v = bytes + } + } + return err + }) + } +} + // ByUnmarshallingJSON returns a RespondDecorator that decodes a JSON document returned in the // response Body into the value pointed to by v. 
func ByUnmarshallingJSON(v interface{}) RespondDecorator { diff --git a/vendor/github.com/Azure/go-autorest/autorest/sender.go b/vendor/github.com/Azure/go-autorest/autorest/sender.go index cacbd81571..6818bd1859 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/sender.go +++ b/vendor/github.com/Azure/go-autorest/autorest/sender.go @@ -15,14 +15,59 @@ package autorest // limitations under the License. import ( + "context" + "crypto/tls" "fmt" "log" "math" "net/http" + "net/http/cookiejar" "strconv" + "sync" "time" + + "github.com/Azure/go-autorest/logger" + "github.com/Azure/go-autorest/tracing" ) +// there is one sender per TLS renegotiation type, i.e. count of tls.RenegotiationSupport enums +const defaultSendersCount = 3 + +type defaultSender struct { + sender Sender + init *sync.Once +} + +// each type of sender will be created on demand in sender() +var defaultSenders [defaultSendersCount]defaultSender + +func init() { + for i := 0; i < defaultSendersCount; i++ { + defaultSenders[i].init = &sync.Once{} + } +} + +// used as a key type in context.WithValue() +type ctxSendDecorators struct{} + +// WithSendDecorators adds the specified SendDecorators to the provided context. +// If no SendDecorators are provided the context is unchanged. +func WithSendDecorators(ctx context.Context, sendDecorator []SendDecorator) context.Context { + if len(sendDecorator) == 0 { + return ctx + } + return context.WithValue(ctx, ctxSendDecorators{}, sendDecorator) +} + +// GetSendDecorators returns the SendDecorators in the provided context or the provided default SendDecorators. +func GetSendDecorators(ctx context.Context, defaultSendDecorators ...SendDecorator) []SendDecorator { + inCtx := ctx.Value(ctxSendDecorators{}) + if sd, ok := inCtx.([]SendDecorator); ok { + return sd + } + return defaultSendDecorators +} + // Sender is the interface that wraps the Do method to send HTTP requests. // // The standard http.Client conforms to this interface. @@ -38,14 +83,14 @@ func (sf SenderFunc) Do(r *http.Request) (*http.Response, error) { return sf(r) } -// SendDecorator takes and possibily decorates, by wrapping, a Sender. Decorators may affect the +// SendDecorator takes and possibly decorates, by wrapping, a Sender. Decorators may affect the // http.Request and pass it along or, first, pass the http.Request along then react to the // http.Response result. type SendDecorator func(Sender) Sender // CreateSender creates, decorates, and returns, as a Sender, the default http.Client. func CreateSender(decorators ...SendDecorator) Sender { - return DecorateSender(&http.Client{}, decorators...) + return DecorateSender(sender(tls.RenegotiateNever), decorators...) } // DecorateSender accepts a Sender and a, possibly empty, set of SendDecorators, which is applies to @@ -68,7 +113,7 @@ func DecorateSender(s Sender, decorators ...SendDecorator) Sender { // // Send will not poll or retry requests. func Send(r *http.Request, decorators ...SendDecorator) (*http.Response, error) { - return SendWithSender(&http.Client{}, r, decorators...) + return SendWithSender(sender(tls.RenegotiateNever), r, decorators...) 
} // SendWithSender sends the passed http.Request, through the provided Sender, returning the @@ -80,6 +125,34 @@ func SendWithSender(s Sender, r *http.Request, decorators ...SendDecorator) (*ht return DecorateSender(s, decorators...).Do(r) } +func sender(renengotiation tls.RenegotiationSupport) Sender { + // note that we can't init defaultSenders in init() since it will + // execute before calling code has had a chance to enable tracing + defaultSenders[renengotiation].init.Do(func() { + // Use behaviour compatible with DefaultTransport, but require TLS minimum version. + defaultTransport := http.DefaultTransport.(*http.Transport) + transport := &http.Transport{ + Proxy: defaultTransport.Proxy, + DialContext: defaultTransport.DialContext, + MaxIdleConns: defaultTransport.MaxIdleConns, + IdleConnTimeout: defaultTransport.IdleConnTimeout, + TLSHandshakeTimeout: defaultTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + Renegotiation: renengotiation, + }, + } + var roundTripper http.RoundTripper = transport + if tracing.IsEnabled() { + roundTripper = tracing.NewTransport(transport) + } + j, _ := cookiejar.New(nil) + defaultSenders[renengotiation].sender = &http.Client{Jar: j, Transport: roundTripper} + }) + return defaultSenders[renengotiation].sender +} + // AfterDelay returns a SendDecorator that delays for the passed time.Duration before // invoking the Sender. The delay may be terminated by closing the optional channel on the // http.Request. If canceled, no further Senders are invoked. @@ -194,6 +267,7 @@ func DoRetryForAttempts(attempts int, backoff time.Duration) SendDecorator { if err != nil { return resp, err } + DrainResponseBody(resp) resp, err = s.Do(rr.Request()) if err == nil { return resp, err @@ -207,56 +281,90 @@ func DoRetryForAttempts(attempts int, backoff time.Duration) SendDecorator { } } +// Count429AsRetry indicates that a 429 response should be included as a retry attempt. +var Count429AsRetry = true + +// Max429Delay is the maximum duration to wait between retries on a 429 if no Retry-After header was received. +var Max429Delay time.Duration + // DoRetryForStatusCodes returns a SendDecorator that retries for specified statusCodes for up to the specified // number of attempts, exponentially backing off between requests using the supplied backoff -// time.Duration (which may be zero). Retrying may be canceled by closing the optional channel on -// the http.Request. +// time.Duration (which may be zero). Retrying may be canceled by cancelling the context on the http.Request. +// NOTE: Code http.StatusTooManyRequests (429) will *not* be counted against the number of attempts. func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) SendDecorator { return func(s Sender) Sender { - return SenderFunc(func(r *http.Request) (resp *http.Response, err error) { - rr := NewRetriableRequest(r) - // Increment to add the first call (attempts denotes number of retries) - attempts++ - for attempt := 0; attempt < attempts; { - err = rr.Prepare() - if err != nil { - return resp, err - } - resp, err = s.Do(rr.Request()) - // if the error isn't temporary don't bother retrying - if err != nil && !IsTemporaryNetworkError(err) { - return nil, err - } - // we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication - // resp and err will both have a value, so in this case we don't want to retry as it will never succeed. 
- if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) { - return resp, err - } - delayed := DelayWithRetryAfter(resp, r.Context().Done()) - if !delayed && !DelayForBackoff(backoff, attempt, r.Context().Done()) { - return nil, r.Context().Err() - } - // don't count a 429 against the number of attempts - // so that we continue to retry until it succeeds - if resp == nil || resp.StatusCode != http.StatusTooManyRequests { - attempt++ - } - } - return resp, err + return SenderFunc(func(r *http.Request) (*http.Response, error) { + return doRetryForStatusCodesImpl(s, r, Count429AsRetry, attempts, backoff, 0, codes...) }) } } -// DelayWithRetryAfter invokes time.After for the duration specified in the "Retry-After" header in -// responses with status code 429 +// DoRetryForStatusCodesWithCap returns a SendDecorator that retries for specified statusCodes for up to the +// specified number of attempts, exponentially backing off between requests using the supplied backoff +// time.Duration (which may be zero). To cap the maximum possible delay between iterations specify a value greater +// than zero for cap. Retrying may be canceled by cancelling the context on the http.Request. +func DoRetryForStatusCodesWithCap(attempts int, backoff, cap time.Duration, codes ...int) SendDecorator { + return func(s Sender) Sender { + return SenderFunc(func(r *http.Request) (*http.Response, error) { + return doRetryForStatusCodesImpl(s, r, Count429AsRetry, attempts, backoff, cap, codes...) + }) + } +} + +func doRetryForStatusCodesImpl(s Sender, r *http.Request, count429 bool, attempts int, backoff, cap time.Duration, codes ...int) (resp *http.Response, err error) { + rr := NewRetriableRequest(r) + // Increment to add the first call (attempts denotes number of retries) + for attempt, delayCount := 0, 0; attempt < attempts+1; { + err = rr.Prepare() + if err != nil { + return + } + DrainResponseBody(resp) + resp, err = s.Do(rr.Request()) + // we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication + // resp and err will both have a value, so in this case we don't want to retry as it will never succeed. + if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) { + return resp, err + } + delayed := DelayWithRetryAfter(resp, r.Context().Done()) + // if this was a 429 set the delay cap as specified. + // applicable only in the absence of a retry-after header. + if resp != nil && resp.StatusCode == http.StatusTooManyRequests { + cap = Max429Delay + } + if !delayed && !DelayForBackoffWithCap(backoff, cap, delayCount, r.Context().Done()) { + return resp, r.Context().Err() + } + // when count429 == false don't count a 429 against the number + // of attempts so that we continue to retry until it succeeds + if count429 || (resp == nil || resp.StatusCode != http.StatusTooManyRequests) { + attempt++ + } + // delay count is tracked separately from attempts to + // ensure that 429 participates in exponential back-off + delayCount++ + } + return resp, err +} + +// DelayWithRetryAfter invokes time.After for the duration specified in the "Retry-After" header. +// The value of Retry-After can be either the number of seconds or a date in RFC1123 format. +// The function returns true after successfully waiting for the specified duration. If there is +// no Retry-After header or the wait is cancelled the return value is false. 
func DelayWithRetryAfter(resp *http.Response, cancel <-chan struct{}) bool { if resp == nil { return false } - retryAfter, _ := strconv.Atoi(resp.Header.Get("Retry-After")) - if resp.StatusCode == http.StatusTooManyRequests && retryAfter > 0 { + var dur time.Duration + ra := resp.Header.Get("Retry-After") + if retryAfter, _ := strconv.Atoi(ra); retryAfter > 0 { + dur = time.Duration(retryAfter) * time.Second + } else if t, err := time.Parse(time.RFC1123, ra); err == nil { + dur = t.Sub(time.Now()) + } + if dur > 0 { select { - case <-time.After(time.Duration(retryAfter) * time.Second): + case <-time.After(dur): return true case <-cancel: return false @@ -279,6 +387,7 @@ func DoRetryForDuration(d time.Duration, backoff time.Duration) SendDecorator { if err != nil { return resp, err } + DrainResponseBody(resp) resp, err = s.Do(rr.Request()) if err == nil { return resp, err @@ -316,8 +425,23 @@ func WithLogging(logger *log.Logger) SendDecorator { // Note: Passing attempt 1 will result in doubling "backoff" duration. Treat this as a zero-based attempt // count. func DelayForBackoff(backoff time.Duration, attempt int, cancel <-chan struct{}) bool { + return DelayForBackoffWithCap(backoff, 0, attempt, cancel) +} + +// DelayForBackoffWithCap invokes time.After for the supplied backoff duration raised to the power of +// passed attempt (i.e., an exponential backoff delay). Backoff duration is in seconds and can set +// to zero for no delay. To cap the maximum possible delay specify a value greater than zero for cap. +// The delay may be canceled by closing the passed channel. If terminated early, returns false. +// Note: Passing attempt 1 will result in doubling "backoff" duration. Treat this as a zero-based attempt +// count. +func DelayForBackoffWithCap(backoff, cap time.Duration, attempt int, cancel <-chan struct{}) bool { + d := time.Duration(backoff.Seconds()*math.Pow(2, float64(attempt))) * time.Second + if cap > 0 && d > cap { + d = cap + } + logger.Instance.Writef(logger.LogInfo, "DelayForBackoffWithCap: sleeping for %s\n", d) select { - case <-time.After(time.Duration(backoff.Seconds()*math.Pow(2, float64(attempt))) * time.Second): + case <-time.After(d): return true case <-cancel: return false diff --git a/vendor/github.com/Azure/go-autorest/autorest/utility.go b/vendor/github.com/Azure/go-autorest/autorest/utility.go index f3a42bfc26..3467b8fa60 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/utility.go +++ b/vendor/github.com/Azure/go-autorest/autorest/utility.go @@ -20,13 +20,12 @@ import ( "encoding/xml" "fmt" "io" + "io/ioutil" "net" "net/http" "net/url" "reflect" "strings" - - "github.com/Azure/go-autorest/autorest/adal" ) // EncodedAs is a series of constants specifying various data encodings @@ -140,24 +139,24 @@ func MapToValues(m map[string]interface{}) url.Values { return v } -// AsStringSlice method converts interface{} to []string. This expects a -//that the parameter passed to be a slice or array of a type that has the underlying -//type a string. +// AsStringSlice method converts interface{} to []string. +// s must be of type slice or array or an error is returned. +// Each element of s will be converted to its string representation. 
func AsStringSlice(s interface{}) ([]string, error) { v := reflect.ValueOf(s) if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, NewError("autorest", "AsStringSlice", "the value's type is not an array.") + return nil, NewError("autorest", "AsStringSlice", "the value's type is not a slice or array.") } stringSlice := make([]string, 0, v.Len()) for i := 0; i < v.Len(); i++ { - stringSlice = append(stringSlice, v.Index(i).String()) + stringSlice = append(stringSlice, fmt.Sprintf("%v", v.Index(i))) } return stringSlice, nil } // String method converts interface v to string. If interface is a list, it -// joins list elements using the seperator. Note that only sep[0] will be used for +// joins list elements using the separator. Note that only sep[0] will be used for // joining if any separator is specified. func String(v interface{}, sep ...string) string { if len(sep) == 0 { @@ -206,22 +205,28 @@ func ChangeToGet(req *http.Request) *http.Request { return req } -// IsTokenRefreshError returns true if the specified error implements the TokenRefreshError -// interface. If err is a DetailedError it will walk the chain of Original errors. -func IsTokenRefreshError(err error) bool { - if _, ok := err.(adal.TokenRefreshError); ok { +// IsTemporaryNetworkError returns true if the specified error is a temporary network error or false +// if it's not. If the error doesn't implement the net.Error interface the return value is true. +func IsTemporaryNetworkError(err error) bool { + if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) { return true } - if de, ok := err.(DetailedError); ok { - return IsTokenRefreshError(de.Original) - } return false } -// IsTemporaryNetworkError returns true if the specified error is a temporary network error. -func IsTemporaryNetworkError(err error) bool { - if netErr, ok := err.(net.Error); ok && netErr.Temporary() { - return true +// DrainResponseBody reads the response body then closes it. +func DrainResponseBody(resp *http.Response) error { + if resp != nil && resp.Body != nil { + _, err := io.Copy(ioutil.Discard, resp.Body) + resp.Body.Close() + return err } - return false + return nil +} + +func setHeader(r *http.Request, key, value string) { + if r.Header == nil { + r.Header = make(http.Header) + } + r.Header.Set(key, value) } diff --git a/vendor/github.com/Azure/go-autorest/autorest/utility_1.13.go b/vendor/github.com/Azure/go-autorest/autorest/utility_1.13.go new file mode 100644 index 0000000000..4cb5e6849f --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/utility_1.13.go @@ -0,0 +1,29 @@ +// +build go1.13 + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package autorest + +import ( + "errors" + + "github.com/Azure/go-autorest/autorest/adal" +) + +// IsTokenRefreshError returns true if the specified error implements the TokenRefreshError interface. 
+func IsTokenRefreshError(err error) bool { + var tre adal.TokenRefreshError + return errors.As(err, &tre) +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/utility_legacy.go b/vendor/github.com/Azure/go-autorest/autorest/utility_legacy.go new file mode 100644 index 0000000000..ebb51b4f53 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/autorest/utility_legacy.go @@ -0,0 +1,31 @@ +// +build !go1.13 + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package autorest + +import "github.com/Azure/go-autorest/autorest/adal" + +// IsTokenRefreshError returns true if the specified error implements the TokenRefreshError +// interface. If err is a DetailedError it will walk the chain of Original errors. +func IsTokenRefreshError(err error) bool { + if _, ok := err.(adal.TokenRefreshError); ok { + return true + } + if de, ok := err.(DetailedError); ok { + return IsTokenRefreshError(de.Original) + } + return false +} diff --git a/vendor/github.com/Azure/go-autorest/autorest/version.go b/vendor/github.com/Azure/go-autorest/autorest/version.go index d265055f51..713e23581d 100644 --- a/vendor/github.com/Azure/go-autorest/autorest/version.go +++ b/vendor/github.com/Azure/go-autorest/autorest/version.go @@ -14,7 +14,28 @@ package autorest // See the License for the specific language governing permissions and // limitations under the License. +import ( + "fmt" + "runtime" +) + +const number = "v14.2.1" + +var ( + userAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s", + runtime.Version(), + runtime.GOARCH, + runtime.GOOS, + number, + ) +) + +// UserAgent returns a string containing the Go version, system architecture and OS, and the go-autorest version. +func UserAgent() string { + return userAgent +} + // Version returns the semantic version (see http://semver.org). func Version() string { - return "v10.8.1" + return number } diff --git a/vendor/github.com/Azure/go-autorest/doc.go b/vendor/github.com/Azure/go-autorest/doc.go new file mode 100644 index 0000000000..99ae6ca988 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/doc.go @@ -0,0 +1,18 @@ +/* +Package go-autorest provides an HTTP request client for use with Autorest-generated API client packages. +*/ +package go_autorest + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/vendor/github.com/Azure/go-autorest/logger/go.mod b/vendor/github.com/Azure/go-autorest/logger/go.mod new file mode 100644 index 0000000000..bedeaee039 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/logger/go.mod @@ -0,0 +1,5 @@ +module github.com/Azure/go-autorest/logger + +go 1.12 + +require github.com/Azure/go-autorest v14.2.0+incompatible diff --git a/vendor/github.com/Azure/go-autorest/logger/go_mod_tidy_hack.go b/vendor/github.com/Azure/go-autorest/logger/go_mod_tidy_hack.go new file mode 100644 index 0000000000..0aa27680db --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/logger/go_mod_tidy_hack.go @@ -0,0 +1,24 @@ +// +build modhack + +package logger + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file, and the github.com/Azure/go-autorest import, won't actually become part of +// the resultant binary. + +// Necessary for safely adding multi-module repo. +// See: https://github.com/golang/go/wiki/Modules#is-it-possible-to-add-a-module-to-a-multi-module-repository +import _ "github.com/Azure/go-autorest" diff --git a/vendor/github.com/Azure/go-autorest/logger/logger.go b/vendor/github.com/Azure/go-autorest/logger/logger.go new file mode 100644 index 0000000000..2f5d8cc1a1 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/logger/logger.go @@ -0,0 +1,337 @@ +package logger + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" +) + +// LevelType tells a logger the minimum level to log. When code reports a log entry, +// the LogLevel indicates the level of the log entry. The logger only records entries +// whose level is at least the level it was told to log. See the Log* constants. +// For example, if a logger is configured with LogError, then LogError, LogPanic, +// and LogFatal entries will be logged; lower level entries are ignored. +type LevelType uint32 + +const ( + // LogNone tells a logger not to log any entries passed to it. + LogNone LevelType = iota + + // LogFatal tells a logger to log all LogFatal entries passed to it. + LogFatal + + // LogPanic tells a logger to log all LogPanic and LogFatal entries passed to it. + LogPanic + + // LogError tells a logger to log all LogError, LogPanic and LogFatal entries passed to it. 
+ LogError + + // LogWarning tells a logger to log all LogWarning, LogError, LogPanic and LogFatal entries passed to it. + LogWarning + + // LogInfo tells a logger to log all LogInfo, LogWarning, LogError, LogPanic and LogFatal entries passed to it. + LogInfo + + // LogDebug tells a logger to log all LogDebug, LogInfo, LogWarning, LogError, LogPanic and LogFatal entries passed to it. + LogDebug + + // LogAuth is a special case of LogDebug, it tells a logger to also log the body of an authentication request and response. + // NOTE: this can disclose sensitive information, use with care. + LogAuth +) + +const ( + logNone = "NONE" + logFatal = "FATAL" + logPanic = "PANIC" + logError = "ERROR" + logWarning = "WARNING" + logInfo = "INFO" + logDebug = "DEBUG" + logAuth = "AUTH" + logUnknown = "UNKNOWN" +) + +// ParseLevel converts the specified string into the corresponding LevelType. +func ParseLevel(s string) (lt LevelType, err error) { + switch strings.ToUpper(s) { + case logFatal: + lt = LogFatal + case logPanic: + lt = LogPanic + case logError: + lt = LogError + case logWarning: + lt = LogWarning + case logInfo: + lt = LogInfo + case logDebug: + lt = LogDebug + case logAuth: + lt = LogAuth + default: + err = fmt.Errorf("bad log level '%s'", s) + } + return +} + +// String implements the stringer interface for LevelType. +func (lt LevelType) String() string { + switch lt { + case LogNone: + return logNone + case LogFatal: + return logFatal + case LogPanic: + return logPanic + case LogError: + return logError + case LogWarning: + return logWarning + case LogInfo: + return logInfo + case LogDebug: + return logDebug + case LogAuth: + return logAuth + default: + return logUnknown + } +} + +// Filter defines functions for filtering HTTP request/response content. +type Filter struct { + // URL returns a potentially modified string representation of a request URL. + URL func(u *url.URL) string + + // Header returns a potentially modified set of values for the specified key. + // To completely exclude the header key/values return false. + Header func(key string, val []string) (bool, []string) + + // Body returns a potentially modified request/response body. + Body func(b []byte) []byte +} + +func (f Filter) processURL(u *url.URL) string { + if f.URL == nil { + return u.String() + } + return f.URL(u) +} + +func (f Filter) processHeader(k string, val []string) (bool, []string) { + if f.Header == nil { + return true, val + } + return f.Header(k, val) +} + +func (f Filter) processBody(b []byte) []byte { + if f.Body == nil { + return b + } + return f.Body(b) +} + +// Writer defines methods for writing to a logging facility. +type Writer interface { + // Writeln writes the specified message with the standard log entry header and new-line character. + Writeln(level LevelType, message string) + + // Writef writes the specified format specifier with the standard log entry header and no new-line character. + Writef(level LevelType, format string, a ...interface{}) + + // WriteRequest writes the specified HTTP request to the logger if the log level is greater than + // or equal to LogInfo. The request body, if set, is logged at level LogDebug or higher. + // Custom filters can be specified to exclude URL, header, and/or body content from the log. + // By default no request content is excluded. + WriteRequest(req *http.Request, filter Filter) + + // WriteResponse writes the specified HTTP response to the logger if the log level is greater than + // or equal to LogInfo. 
The response body, if set, is logged at level LogDebug or higher. + // Custom filters can be specified to exclude URL, header, and/or body content from the log. + // By default no response content is excluded. + WriteResponse(resp *http.Response, filter Filter) +} + +// Instance is the default log writer initialized during package init. +// This can be replaced with a custom implementation as required. +var Instance Writer + +// default log level +var logLevel = LogNone + +// Level returns the value specified in AZURE_GO_AUTOREST_LOG_LEVEL. +// If no value was specified the default value is LogNone. +// Custom loggers can call this to retrieve the configured log level. +func Level() LevelType { + return logLevel +} + +func init() { + // separated for testing purposes + initDefaultLogger() +} + +func initDefaultLogger() { + // init with nilLogger so callers don't have to do a nil check on Default + Instance = nilLogger{} + llStr := strings.ToLower(os.Getenv("AZURE_GO_SDK_LOG_LEVEL")) + if llStr == "" { + return + } + var err error + logLevel, err = ParseLevel(llStr) + if err != nil { + fmt.Fprintf(os.Stderr, "go-autorest: failed to parse log level: %s\n", err.Error()) + return + } + if logLevel == LogNone { + return + } + // default to stderr + dest := os.Stderr + lfStr := os.Getenv("AZURE_GO_SDK_LOG_FILE") + if strings.EqualFold(lfStr, "stdout") { + dest = os.Stdout + } else if lfStr != "" { + lf, err := os.Create(lfStr) + if err == nil { + dest = lf + } else { + fmt.Fprintf(os.Stderr, "go-autorest: failed to create log file, using stderr: %s\n", err.Error()) + } + } + Instance = fileLogger{ + logLevel: logLevel, + mu: &sync.Mutex{}, + logFile: dest, + } +} + +// the nil logger does nothing +type nilLogger struct{} + +func (nilLogger) Writeln(LevelType, string) {} + +func (nilLogger) Writef(LevelType, string, ...interface{}) {} + +func (nilLogger) WriteRequest(*http.Request, Filter) {} + +func (nilLogger) WriteResponse(*http.Response, Filter) {} + +// A File is used instead of a Logger so the stream can be flushed after every write. 
+type fileLogger struct { + logLevel LevelType + mu *sync.Mutex // for synchronizing writes to logFile + logFile *os.File +} + +func (fl fileLogger) Writeln(level LevelType, message string) { + fl.Writef(level, "%s\n", message) +} + +func (fl fileLogger) Writef(level LevelType, format string, a ...interface{}) { + if fl.logLevel >= level { + fl.mu.Lock() + defer fl.mu.Unlock() + fmt.Fprintf(fl.logFile, "%s %s", entryHeader(level), fmt.Sprintf(format, a...)) + fl.logFile.Sync() + } +} + +func (fl fileLogger) WriteRequest(req *http.Request, filter Filter) { + if req == nil || fl.logLevel < LogInfo { + return + } + b := &bytes.Buffer{} + fmt.Fprintf(b, "%s REQUEST: %s %s\n", entryHeader(LogInfo), req.Method, filter.processURL(req.URL)) + // dump headers + for k, v := range req.Header { + if ok, mv := filter.processHeader(k, v); ok { + fmt.Fprintf(b, "%s: %s\n", k, strings.Join(mv, ",")) + } + } + if fl.shouldLogBody(req.Header, req.Body) { + // dump body + body, err := ioutil.ReadAll(req.Body) + if err == nil { + fmt.Fprintln(b, string(filter.processBody(body))) + if nc, ok := req.Body.(io.Seeker); ok { + // rewind to the beginning + nc.Seek(0, io.SeekStart) + } else { + // recreate the body + req.Body = ioutil.NopCloser(bytes.NewReader(body)) + } + } else { + fmt.Fprintf(b, "failed to read body: %v\n", err) + } + } + fl.mu.Lock() + defer fl.mu.Unlock() + fmt.Fprint(fl.logFile, b.String()) + fl.logFile.Sync() +} + +func (fl fileLogger) WriteResponse(resp *http.Response, filter Filter) { + if resp == nil || fl.logLevel < LogInfo { + return + } + b := &bytes.Buffer{} + fmt.Fprintf(b, "%s RESPONSE: %d %s\n", entryHeader(LogInfo), resp.StatusCode, filter.processURL(resp.Request.URL)) + // dump headers + for k, v := range resp.Header { + if ok, mv := filter.processHeader(k, v); ok { + fmt.Fprintf(b, "%s: %s\n", k, strings.Join(mv, ",")) + } + } + if fl.shouldLogBody(resp.Header, resp.Body) { + // dump body + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err == nil { + fmt.Fprintln(b, string(filter.processBody(body))) + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + } else { + fmt.Fprintf(b, "failed to read body: %v\n", err) + } + } + fl.mu.Lock() + defer fl.mu.Unlock() + fmt.Fprint(fl.logFile, b.String()) + fl.logFile.Sync() +} + +// returns true if the provided body should be included in the log +func (fl fileLogger) shouldLogBody(header http.Header, body io.ReadCloser) bool { + ct := header.Get("Content-Type") + return fl.logLevel >= LogDebug && body != nil && !strings.Contains(ct, "application/octet-stream") +} + +// creates standard header for log entries, it contains a timestamp and the log level +func entryHeader(level LevelType) string { + // this format provides a fixed number of digits so the size of the timestamp is constant + return fmt.Sprintf("(%s) %s:", time.Now().Format("2006-01-02T15:04:05.0000000Z07:00"), level.String()) +} diff --git a/vendor/github.com/Azure/go-autorest/tracing/go.mod b/vendor/github.com/Azure/go-autorest/tracing/go.mod new file mode 100644 index 0000000000..a2cdec78c8 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/tracing/go.mod @@ -0,0 +1,5 @@ +module github.com/Azure/go-autorest/tracing + +go 1.12 + +require github.com/Azure/go-autorest v14.2.0+incompatible diff --git a/vendor/github.com/Azure/go-autorest/tracing/go_mod_tidy_hack.go b/vendor/github.com/Azure/go-autorest/tracing/go_mod_tidy_hack.go new file mode 100644 index 0000000000..e163975cd4 --- /dev/null +++ 
b/vendor/github.com/Azure/go-autorest/tracing/go_mod_tidy_hack.go @@ -0,0 +1,24 @@ +// +build modhack + +package tracing + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file, and the github.com/Azure/go-autorest import, won't actually become part of +// the resultant binary. + +// Necessary for safely adding multi-module repo. +// See: https://github.com/golang/go/wiki/Modules#is-it-possible-to-add-a-module-to-a-multi-module-repository +import _ "github.com/Azure/go-autorest" diff --git a/vendor/github.com/Azure/go-autorest/tracing/tracing.go b/vendor/github.com/Azure/go-autorest/tracing/tracing.go new file mode 100644 index 0000000000..0e7a6e9625 --- /dev/null +++ b/vendor/github.com/Azure/go-autorest/tracing/tracing.go @@ -0,0 +1,67 @@ +package tracing + +// Copyright 2018 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import ( + "context" + "net/http" +) + +// Tracer represents an HTTP tracing facility. +type Tracer interface { + NewTransport(base *http.Transport) http.RoundTripper + StartSpan(ctx context.Context, name string) context.Context + EndSpan(ctx context.Context, httpStatusCode int, err error) +} + +var ( + tracer Tracer +) + +// Register will register the provided Tracer. Pass nil to unregister a Tracer. +func Register(t Tracer) { + tracer = t +} + +// IsEnabled returns true if a Tracer has been registered. +func IsEnabled() bool { + return tracer != nil +} + +// NewTransport creates a new instrumenting http.RoundTripper for the +// registered Tracer. If no Tracer has been registered it returns nil. +func NewTransport(base *http.Transport) http.RoundTripper { + if tracer != nil { + return tracer.NewTransport(base) + } + return nil +} + +// StartSpan starts a trace span with the specified name, associating it with the +// provided context. Has no effect if a Tracer has not been registered. +func StartSpan(ctx context.Context, name string) context.Context { + if tracer != nil { + return tracer.StartSpan(ctx, name) + } + return ctx +} + +// EndSpan ends a previously started span stored in the context. +// Has no effect if a Tracer has not been registered. 
+func EndSpan(ctx context.Context, httpStatusCode int, err error) { + if tracer != nil { + tracer.EndSpan(ctx, httpStatusCode, err) + } +} diff --git a/vendor/github.com/dgrijalva/jwt-go/LICENSE b/vendor/github.com/form3tech-oss/jwt-go/LICENSE similarity index 100% rename from vendor/github.com/dgrijalva/jwt-go/LICENSE rename to vendor/github.com/form3tech-oss/jwt-go/LICENSE diff --git a/vendor/github.com/dgrijalva/jwt-go/README.md b/vendor/github.com/form3tech-oss/jwt-go/README.md similarity index 63% rename from vendor/github.com/dgrijalva/jwt-go/README.md rename to vendor/github.com/form3tech-oss/jwt-go/README.md index f48365fafb..d7749077fd 100644 --- a/vendor/github.com/dgrijalva/jwt-go/README.md +++ b/vendor/github.com/form3tech-oss/jwt-go/README.md @@ -1,11 +1,15 @@ -A [go](http://www.golang.org) (or 'golang' for search engine friendliness) implementation of [JSON Web Tokens](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html) +# jwt-go [![Build Status](https://travis-ci.org/dgrijalva/jwt-go.svg?branch=master)](https://travis-ci.org/dgrijalva/jwt-go) +[![GoDoc](https://godoc.org/github.com/dgrijalva/jwt-go?status.svg)](https://godoc.org/github.com/dgrijalva/jwt-go) + +A [go](http://www.golang.org) (or 'golang' for search engine friendliness) implementation of [JSON Web Tokens](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html) -**BREAKING CHANGES:*** Version 3.0.0 is here. It includes _a lot_ of changes including a few that break the API. We've tried to break as few things as possible, so there should just be a few type signature changes. A full list of breaking changes is available in `VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating your code. +**NEW VERSION COMING:** There have been a lot of improvements suggested since the version 3.0.0 released in 2016. I'm working now on cutting two different releases: 3.2.0 will contain any non-breaking changes or enhancements. 4.0.0 will follow shortly which will include breaking changes. See the 4.0.0 milestone to get an idea of what's coming. If you have other ideas, or would like to participate in 4.0.0, now's the time. If you depend on this library and don't want to be interrupted, I recommend you use your dependency mangement tool to pin to version 3. -**NOTICE:** A vulnerability in JWT was [recently published](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/). As this library doesn't force users to validate the `alg` is what they expected, it's possible your usage is effected. There will be an update soon to remedy this, and it will likey require backwards-incompatible changes to the API. In the short term, please make sure your implementation verifies the `alg` is what you expect. +**SECURITY NOTICE:** Some older versions of Go have a security issue in the cryotp/elliptic. Recommendation is to upgrade to at least 1.8.3. See issue #216 for more detail. +**SECURITY NOTICE:** It's important that you [validate the `alg` presented is what you expect](https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/). This library attempts to make it easy to do the right thing by requiring key types match the expected alg, but you should take the extra step to verify it in your usage. See the examples provided. ## What the heck is a JWT? @@ -15,7 +19,7 @@ In short, it's a signed JSON object that does something useful (for example, aut The first part is called the header. 
It contains the necessary information for verifying the last part, the signature. For example, which encryption method was used for signing and what key was used. -The part in the middle is the interesting bit. It's called the Claims and contains the actual stuff you care about. Refer to [the RFC](http://self-issued.info/docs/draft-jones-json-web-token.html) for information about reserved keys and the proper way to add your own. +The part in the middle is the interesting bit. It's called the Claims and contains the actual stuff you care about. Refer to [the RFC](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html) for information about reserved keys and the proper way to add your own. ## What's in the box? @@ -33,11 +37,11 @@ See [the project documentation](https://godoc.org/github.com/dgrijalva/jwt-go) f This library publishes all the necessary components for adding your own signing methods. Simply implement the `SigningMethod` interface and register a factory method using `RegisterSigningMethod`. -Here's an example of an extension that integrates with the Google App Engine signing tools: https://github.com/someone1/gcp-jwt-go +Here's an example of an extension that integrates with multiple Google Cloud Platform signing tools (AppEngine, IAM API, Cloud KMS): https://github.com/someone1/gcp-jwt-go ## Compliance -This library was last reviewed to comply with [RTF 7519](http://www.rfc-editor.org/info/rfc7519) dated May 2015 with a few notable differences: +This library was last reviewed to comply with [RTF 7519](http://www.rfc-editor.org/info/rfc7519) dated May 2015 with a few notable differences: * In order to protect against accidental use of [Unsecured JWTs](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#UnsecuredJWT), tokens using `alg=none` will only be accepted if the constant `jwt.UnsafeAllowNoneSignatureType` is provided as the key. @@ -47,7 +51,10 @@ This library is considered production ready. Feedback and feature requests are This project uses [Semantic Versioning 2.0.0](http://semver.org). Accepted pull requests will land on `master`. Periodically, versions will be tagged from `master`. You can find all the releases on [the project releases page](https://github.com/dgrijalva/jwt-go/releases). -While we try to make it obvious when we make breaking changes, there isn't a great mechanism for pushing announcements out to users. You may want to use this alternative package include: `gopkg.in/dgrijalva/jwt-go.v2`. It will do the right thing WRT semantic versioning. +While we try to make it obvious when we make breaking changes, there isn't a great mechanism for pushing announcements out to users. You may want to use this alternative package include: `gopkg.in/dgrijalva/jwt-go.v3`. It will do the right thing WRT semantic versioning. + +**BREAKING CHANGES:*** +* Version 3.0.0 includes _a lot_ of changes from the 2.x line, including a few that break the API. We've tried to break as few things as possible, so there should just be a few type signature changes. A full list of breaking changes is available in `VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating your code. ## Usage Tips @@ -68,18 +75,30 @@ Symmetric signing methods, such as HSA, use only a single secret. This is probab Asymmetric signing methods, such as RSA, use different keys for signing and verifying tokens. This makes it possible to produce tokens with a private key, and allow any consumer to access the public key for verification. 
+### Signing Methods and Key Types + +Each signing method expects a different object type for its signing keys. See the package documentation for details. Here are the most common ones: + +* The [HMAC signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodHMAC) (`HS256`,`HS384`,`HS512`) expect `[]byte` values for signing and validation +* The [RSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodRSA) (`RS256`,`RS384`,`RS512`) expect `*rsa.PrivateKey` for signing and `*rsa.PublicKey` for validation +* The [ECDSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodECDSA) (`ES256`,`ES384`,`ES512`) expect `*ecdsa.PrivateKey` for signing and `*ecdsa.PublicKey` for validation + ### JWT and OAuth It's worth mentioning that OAuth and JWT are not the same thing. A JWT token is simply a signed JSON object. It can be used anywhere such a thing is useful. There is some confusion, though, as JWT is the most common type of bearer token used in OAuth2 authentication. Without going too far down the rabbit hole, here's a description of the interaction of these technologies: -* OAuth is a protocol for allowing an identity provider to be separate from the service a user is logging in to. For example, whenever you use Facebook to log into a different service (Yelp, Spotify, etc), you are using OAuth. +* OAuth is a protocol for allowing an identity provider to be separate from the service a user is logging in to. For example, whenever you use Facebook to log into a different service (Yelp, Spotify, etc), you are using OAuth. * OAuth defines several options for passing around authentication data. One popular method is called a "bearer token". A bearer token is simply a string that _should_ only be held by an authenticated user. Thus, simply presenting this token proves your identity. You can probably derive from here why a JWT might make a good bearer token. * Because bearer tokens are used for authentication, it's important they're kept secret. This is why transactions that use bearer tokens typically happen over SSL. - + +### Troubleshooting + +This library uses descriptive error messages whenever possible. If you are not getting the expected result, have a look at the errors. The most common place people get stuck is providing the correct type of key to the parser. See the above section on signing methods and key types. + ## More Documentation can be found [on godoc.org](http://godoc.org/github.com/dgrijalva/jwt-go). -The command line utility included in this project (cmd/jwt) provides a straightforward example of token creation and parsing as well as a useful tool for debugging your own integration. You'll also find several implementation examples in to documentation. +The command line utility included in this project (cmd/jwt) provides a straightforward example of token creation and parsing as well as a useful tool for debugging your own integration. You'll also find several implementation examples in the documentation. 
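The README text above, from the form3tech-oss/jwt-go fork this patch vendors, stresses two points: each signing method expects a specific key type, and callers should verify the `alg` presented in the token before handing a key back to the parser. As a minimal, illustrative sketch of that advice (not taken from the vendored sources; the shared secret and claims below are made-up placeholders), the following signs a token with HS256 and checks the signing method inside the key callback during verification:

```go
package main

import (
	"fmt"

	jwt "github.com/form3tech-oss/jwt-go"
)

func main() {
	// Hypothetical shared secret for the example; HMAC methods expect a []byte key.
	key := []byte("example-secret")

	// Create and sign a token with HS256.
	tok := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"sub": "1234567890"})
	signed, err := tok.SignedString(key)
	if err != nil {
		panic(err)
	}

	// Parse it back, confirming the alg is an HMAC method before returning the key,
	// as the security notice above recommends.
	parsed, err := jwt.Parse(signed, func(t *jwt.Token) (interface{}, error) {
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return key, nil
	})
	if err != nil {
		panic(err)
	}
	if claims, ok := parsed.Claims.(jwt.MapClaims); ok && parsed.Valid {
		fmt.Println("sub:", claims["sub"])
	}
}
```

With HMAC the same secret both signs and verifies; as the usage tips above note, an asymmetric method (RSA or ECDSA, with `*rsa.PrivateKey`/`*ecdsa.PrivateKey` for signing and the corresponding public key for verification) is the better fit when verifiers should not be able to mint tokens.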
diff --git a/vendor/github.com/dgrijalva/jwt-go/claims.go b/vendor/github.com/form3tech-oss/jwt-go/claims.go similarity index 93% rename from vendor/github.com/dgrijalva/jwt-go/claims.go rename to vendor/github.com/form3tech-oss/jwt-go/claims.go index f0228f02e0..624890666c 100644 --- a/vendor/github.com/dgrijalva/jwt-go/claims.go +++ b/vendor/github.com/form3tech-oss/jwt-go/claims.go @@ -16,7 +16,7 @@ type Claims interface { // https://tools.ietf.org/html/rfc7519#section-4.1 // See examples for how to use this with your own claim types type StandardClaims struct { - Audience string `json:"aud,omitempty"` + Audience []string `json:"aud,omitempty"` ExpiresAt int64 `json:"exp,omitempty"` Id string `json:"jti,omitempty"` IssuedAt int64 `json:"iat,omitempty"` @@ -90,15 +90,17 @@ func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool { // ----- helpers -func verifyAud(aud string, cmp string, required bool) bool { - if aud == "" { +func verifyAud(aud []string, cmp string, required bool) bool { + if len(aud) == 0 { return !required } - if subtle.ConstantTimeCompare([]byte(aud), []byte(cmp)) != 0 { - return true - } else { - return false + + for _, a := range aud { + if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { + return true + } } + return false } func verifyExp(exp int64, now int64, required bool) bool { diff --git a/vendor/github.com/dgrijalva/jwt-go/doc.go b/vendor/github.com/form3tech-oss/jwt-go/doc.go similarity index 100% rename from vendor/github.com/dgrijalva/jwt-go/doc.go rename to vendor/github.com/form3tech-oss/jwt-go/doc.go diff --git a/vendor/github.com/dgrijalva/jwt-go/ecdsa.go b/vendor/github.com/form3tech-oss/jwt-go/ecdsa.go similarity index 97% rename from vendor/github.com/dgrijalva/jwt-go/ecdsa.go rename to vendor/github.com/form3tech-oss/jwt-go/ecdsa.go index 2f59a22236..f977381240 100644 --- a/vendor/github.com/dgrijalva/jwt-go/ecdsa.go +++ b/vendor/github.com/form3tech-oss/jwt-go/ecdsa.go @@ -14,6 +14,7 @@ var ( ) // Implements the ECDSA family of signing methods signing methods +// Expects *ecdsa.PrivateKey for signing and *ecdsa.PublicKey for verification type SigningMethodECDSA struct { Name string Hash crypto.Hash diff --git a/vendor/github.com/dgrijalva/jwt-go/ecdsa_utils.go b/vendor/github.com/form3tech-oss/jwt-go/ecdsa_utils.go similarity index 93% rename from vendor/github.com/dgrijalva/jwt-go/ecdsa_utils.go rename to vendor/github.com/form3tech-oss/jwt-go/ecdsa_utils.go index d19624b726..db9f4be7d8 100644 --- a/vendor/github.com/dgrijalva/jwt-go/ecdsa_utils.go +++ b/vendor/github.com/form3tech-oss/jwt-go/ecdsa_utils.go @@ -25,7 +25,9 @@ func ParseECPrivateKeyFromPEM(key []byte) (*ecdsa.PrivateKey, error) { // Parse the key var parsedKey interface{} if parsedKey, err = x509.ParseECPrivateKey(block.Bytes); err != nil { - return nil, err + if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil { + return nil, err + } } var pkey *ecdsa.PrivateKey diff --git a/vendor/github.com/dgrijalva/jwt-go/errors.go b/vendor/github.com/form3tech-oss/jwt-go/errors.go similarity index 100% rename from vendor/github.com/dgrijalva/jwt-go/errors.go rename to vendor/github.com/form3tech-oss/jwt-go/errors.go diff --git a/vendor/github.com/dgrijalva/jwt-go/hmac.go b/vendor/github.com/form3tech-oss/jwt-go/hmac.go similarity index 96% rename from vendor/github.com/dgrijalva/jwt-go/hmac.go rename to vendor/github.com/form3tech-oss/jwt-go/hmac.go index c229919254..addbe5d401 100644 --- a/vendor/github.com/dgrijalva/jwt-go/hmac.go +++ 
b/vendor/github.com/form3tech-oss/jwt-go/hmac.go @@ -7,6 +7,7 @@ import ( ) // Implements the HMAC-SHA family of signing methods signing methods +// Expects key type of []byte for both signing and validation type SigningMethodHMAC struct { Name string Hash crypto.Hash @@ -90,5 +91,5 @@ func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, return EncodeSegment(hasher.Sum(nil)), nil } - return "", ErrInvalidKey + return "", ErrInvalidKeyType } diff --git a/vendor/github.com/dgrijalva/jwt-go/map_claims.go b/vendor/github.com/form3tech-oss/jwt-go/map_claims.go similarity index 94% rename from vendor/github.com/dgrijalva/jwt-go/map_claims.go rename to vendor/github.com/form3tech-oss/jwt-go/map_claims.go index 291213c460..90ab6bea35 100644 --- a/vendor/github.com/dgrijalva/jwt-go/map_claims.go +++ b/vendor/github.com/form3tech-oss/jwt-go/map_claims.go @@ -13,7 +13,15 @@ type MapClaims map[string]interface{} // Compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyAudience(cmp string, req bool) bool { - aud, _ := m["aud"].(string) + aud, ok := m["aud"].([]string) + if !ok { + strAud, ok := m["aud"].(string) + if !ok { + return false + } + aud = append(aud, strAud) + } + return verifyAud(aud, cmp, req) } diff --git a/vendor/github.com/dgrijalva/jwt-go/none.go b/vendor/github.com/form3tech-oss/jwt-go/none.go similarity index 100% rename from vendor/github.com/dgrijalva/jwt-go/none.go rename to vendor/github.com/form3tech-oss/jwt-go/none.go diff --git a/vendor/github.com/dgrijalva/jwt-go/parser.go b/vendor/github.com/form3tech-oss/jwt-go/parser.go similarity index 68% rename from vendor/github.com/dgrijalva/jwt-go/parser.go rename to vendor/github.com/form3tech-oss/jwt-go/parser.go index 7bf1c4ea08..d6901d9adb 100644 --- a/vendor/github.com/dgrijalva/jwt-go/parser.go +++ b/vendor/github.com/form3tech-oss/jwt-go/parser.go @@ -21,55 +21,9 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { } func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) { - parts := strings.Split(tokenString, ".") - if len(parts) != 3 { - return nil, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed) - } - - var err error - token := &Token{Raw: tokenString} - - // parse Header - var headerBytes []byte - if headerBytes, err = DecodeSegment(parts[0]); err != nil { - if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") { - return token, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed) - } - return token, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} - } - if err = json.Unmarshal(headerBytes, &token.Header); err != nil { - return token, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} - } - - // parse Claims - var claimBytes []byte - token.Claims = claims - - if claimBytes, err = DecodeSegment(parts[1]); err != nil { - return token, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} - } - dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) - if p.UseJSONNumber { - dec.UseNumber() - } - // JSON Decode. 
Special case for map type to avoid weird pointer behavior - if c, ok := token.Claims.(MapClaims); ok { - err = dec.Decode(&c) - } else { - err = dec.Decode(&claims) - } - // Handle decode error + token, parts, err := p.ParseUnverified(tokenString, claims) if err != nil { - return token, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} - } - - // Lookup signature method - if method, ok := token.Header["alg"].(string); ok { - if token.Method = GetSigningMethod(method); token.Method == nil { - return token, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable) - } - } else { - return token, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable) + return token, err } // Verify signing method is in the required set @@ -96,6 +50,9 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf } if key, err = keyFunc(token); err != nil { // keyFunc returned an error + if ve, ok := err.(*ValidationError); ok { + return token, ve + } return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable} } @@ -129,3 +86,63 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf return token, vErr } + +// WARNING: Don't use this method unless you know what you're doing +// +// This method parses the token but doesn't validate the signature. It's only +// ever useful in cases where you know the signature is valid (because it has +// been checked previously in the stack) and you want to extract values from +// it. +func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) { + parts = strings.Split(tokenString, ".") + if len(parts) != 3 { + return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed) + } + + token = &Token{Raw: tokenString} + + // parse Header + var headerBytes []byte + if headerBytes, err = DecodeSegment(parts[0]); err != nil { + if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") { + return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed) + } + return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + } + if err = json.Unmarshal(headerBytes, &token.Header); err != nil { + return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + } + + // parse Claims + var claimBytes []byte + token.Claims = claims + + if claimBytes, err = DecodeSegment(parts[1]); err != nil { + return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + } + dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) + if p.UseJSONNumber { + dec.UseNumber() + } + // JSON Decode. 
Special case for map type to avoid weird pointer behavior + if c, ok := token.Claims.(MapClaims); ok { + err = dec.Decode(&c) + } else { + err = dec.Decode(&claims) + } + // Handle decode error + if err != nil { + return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + } + + // Lookup signature method + if method, ok := token.Header["alg"].(string); ok { + if token.Method = GetSigningMethod(method); token.Method == nil { + return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable) + } + } else { + return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable) + } + + return token, parts, nil +} diff --git a/vendor/github.com/dgrijalva/jwt-go/rsa.go b/vendor/github.com/form3tech-oss/jwt-go/rsa.go similarity index 91% rename from vendor/github.com/dgrijalva/jwt-go/rsa.go rename to vendor/github.com/form3tech-oss/jwt-go/rsa.go index 0ae0b1984e..e4caf1ca4a 100644 --- a/vendor/github.com/dgrijalva/jwt-go/rsa.go +++ b/vendor/github.com/form3tech-oss/jwt-go/rsa.go @@ -7,6 +7,7 @@ import ( ) // Implements the RSA family of signing methods signing methods +// Expects *rsa.PrivateKey for signing and *rsa.PublicKey for validation type SigningMethodRSA struct { Name string Hash crypto.Hash @@ -44,7 +45,7 @@ func (m *SigningMethodRSA) Alg() string { } // Implements the Verify method from SigningMethod -// For this signing method, must be an rsa.PublicKey structure. +// For this signing method, must be an *rsa.PublicKey structure. func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error { var err error @@ -73,7 +74,7 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface } // Implements the Sign method from SigningMethod -// For this signing method, must be an rsa.PrivateKey structure. +// For this signing method, must be an *rsa.PrivateKey structure. func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) { var rsaKey *rsa.PrivateKey var ok bool diff --git a/vendor/github.com/dgrijalva/jwt-go/rsa_pss.go b/vendor/github.com/form3tech-oss/jwt-go/rsa_pss.go similarity index 71% rename from vendor/github.com/dgrijalva/jwt-go/rsa_pss.go rename to vendor/github.com/form3tech-oss/jwt-go/rsa_pss.go index 10ee9db8a4..c014708648 100644 --- a/vendor/github.com/dgrijalva/jwt-go/rsa_pss.go +++ b/vendor/github.com/form3tech-oss/jwt-go/rsa_pss.go @@ -12,9 +12,14 @@ import ( type SigningMethodRSAPSS struct { *SigningMethodRSA Options *rsa.PSSOptions + // VerifyOptions is optional. If set overrides Options for rsa.VerifyPPS. + // Used to accept tokens signed with rsa.PSSSaltLengthAuto, what doesn't follow + // https://tools.ietf.org/html/rfc7518#section-3.5 but was used previously. + // See https://github.com/dgrijalva/jwt-go/issues/285#issuecomment-437451244 for details. + VerifyOptions *rsa.PSSOptions } -// Specific instances for RS/PS and company +// Specific instances for RS/PS and company. 
var ( SigningMethodPS256 *SigningMethodRSAPSS SigningMethodPS384 *SigningMethodRSAPSS @@ -24,13 +29,15 @@ var ( func init() { // PS256 SigningMethodPS256 = &SigningMethodRSAPSS{ - &SigningMethodRSA{ + SigningMethodRSA: &SigningMethodRSA{ Name: "PS256", Hash: crypto.SHA256, }, - &rsa.PSSOptions{ + Options: &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + }, + VerifyOptions: &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthAuto, - Hash: crypto.SHA256, }, } RegisterSigningMethod(SigningMethodPS256.Alg(), func() SigningMethod { @@ -39,13 +46,15 @@ func init() { // PS384 SigningMethodPS384 = &SigningMethodRSAPSS{ - &SigningMethodRSA{ + SigningMethodRSA: &SigningMethodRSA{ Name: "PS384", Hash: crypto.SHA384, }, - &rsa.PSSOptions{ + Options: &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + }, + VerifyOptions: &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthAuto, - Hash: crypto.SHA384, }, } RegisterSigningMethod(SigningMethodPS384.Alg(), func() SigningMethod { @@ -54,13 +63,15 @@ func init() { // PS512 SigningMethodPS512 = &SigningMethodRSAPSS{ - &SigningMethodRSA{ + SigningMethodRSA: &SigningMethodRSA{ Name: "PS512", Hash: crypto.SHA512, }, - &rsa.PSSOptions{ + Options: &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + }, + VerifyOptions: &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthAuto, - Hash: crypto.SHA512, }, } RegisterSigningMethod(SigningMethodPS512.Alg(), func() SigningMethod { @@ -94,7 +105,12 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf hasher := m.Hash.New() hasher.Write([]byte(signingString)) - return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, m.Options) + opts := m.Options + if m.VerifyOptions != nil { + opts = m.VerifyOptions + } + + return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts) } // Implements the Sign method from SigningMethod diff --git a/vendor/github.com/dgrijalva/jwt-go/rsa_utils.go b/vendor/github.com/form3tech-oss/jwt-go/rsa_utils.go similarity index 62% rename from vendor/github.com/dgrijalva/jwt-go/rsa_utils.go rename to vendor/github.com/form3tech-oss/jwt-go/rsa_utils.go index 213a90dbbf..14c78c292a 100644 --- a/vendor/github.com/dgrijalva/jwt-go/rsa_utils.go +++ b/vendor/github.com/form3tech-oss/jwt-go/rsa_utils.go @@ -8,7 +8,7 @@ import ( ) var ( - ErrKeyMustBePEMEncoded = errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key") + ErrKeyMustBePEMEncoded = errors.New("Invalid Key: Key must be a PEM encoded PKCS1 or PKCS8 key") ErrNotRSAPrivateKey = errors.New("Key is not a valid RSA private key") ErrNotRSAPublicKey = errors.New("Key is not a valid RSA public key") ) @@ -39,6 +39,38 @@ func ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error) { return pkey, nil } +// Parse PEM encoded PKCS1 or PKCS8 private key protected with password +func ParseRSAPrivateKeyFromPEMWithPassword(key []byte, password string) (*rsa.PrivateKey, error) { + var err error + + // Parse PEM block + var block *pem.Block + if block, _ = pem.Decode(key); block == nil { + return nil, ErrKeyMustBePEMEncoded + } + + var parsedKey interface{} + + var blockDecrypted []byte + if blockDecrypted, err = x509.DecryptPEMBlock(block, []byte(password)); err != nil { + return nil, err + } + + if parsedKey, err = x509.ParsePKCS1PrivateKey(blockDecrypted); err != nil { + if parsedKey, err = x509.ParsePKCS8PrivateKey(blockDecrypted); err != nil { + return nil, err + } + } + + var pkey *rsa.PrivateKey + var ok bool + if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok { + return nil, 
ErrNotRSAPrivateKey + } + + return pkey, nil +} + // Parse PEM encoded PKCS1 or PKCS8 public key func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) { var err error diff --git a/vendor/github.com/dgrijalva/jwt-go/signing_method.go b/vendor/github.com/form3tech-oss/jwt-go/signing_method.go similarity index 100% rename from vendor/github.com/dgrijalva/jwt-go/signing_method.go rename to vendor/github.com/form3tech-oss/jwt-go/signing_method.go diff --git a/vendor/github.com/dgrijalva/jwt-go/token.go b/vendor/github.com/form3tech-oss/jwt-go/token.go similarity index 100% rename from vendor/github.com/dgrijalva/jwt-go/token.go rename to vendor/github.com/form3tech-oss/jwt-go/token.go diff --git a/vendor/golang.org/x/crypto/README b/vendor/golang.org/x/crypto/README deleted file mode 100644 index f1e0cbf94e..0000000000 --- a/vendor/golang.org/x/crypto/README +++ /dev/null @@ -1,3 +0,0 @@ -This repository holds supplementary Go cryptography libraries. - -To submit changes to this repository, see http://golang.org/doc/contribute.html. diff --git a/vendor/golang.org/x/crypto/README.md b/vendor/golang.org/x/crypto/README.md new file mode 100644 index 0000000000..c9d6fecd1e --- /dev/null +++ b/vendor/golang.org/x/crypto/README.md @@ -0,0 +1,21 @@ +# Go Cryptography + +This repository holds supplementary Go cryptography libraries. + +## Download/Install + +The easiest way to install is to run `go get -u golang.org/x/crypto/...`. You +can also manually git clone the repository to `$GOPATH/src/golang.org/x/crypto`. + +## Report Issues / Send Patches + +This repository uses Gerrit for code changes. To learn how to submit changes to +this repository, see https://golang.org/doc/contribute.html. + +The main issue tracker for the crypto repository is located at +https://github.com/golang/go/issues. Prefix your issue with "x/crypto:" in the +subject line, so it is easy to find. + +Note that contributions to the cryptography package receive additional scrutiny +due to their sensitive nature. Patches may take longer than normal to receive +feedback. diff --git a/vendor/golang.org/x/crypto/bcrypt/bcrypt.go b/vendor/golang.org/x/crypto/bcrypt/bcrypt.go index f8b807f9c3..aeb73f81a1 100644 --- a/vendor/golang.org/x/crypto/bcrypt/bcrypt.go +++ b/vendor/golang.org/x/crypto/bcrypt/bcrypt.go @@ -12,9 +12,10 @@ import ( "crypto/subtle" "errors" "fmt" - "golang.org/x/crypto/blowfish" "io" "strconv" + + "golang.org/x/crypto/blowfish" ) const ( @@ -205,7 +206,6 @@ func bcrypt(password []byte, cost int, salt []byte) ([]byte, error) { } func expensiveBlowfishSetup(key []byte, cost uint32, salt []byte) (*blowfish.Cipher, error) { - csalt, err := base64Decode(salt) if err != nil { return nil, err @@ -213,7 +213,8 @@ func expensiveBlowfishSetup(key []byte, cost uint32, salt []byte) (*blowfish.Cip // Bug compatibility with C bcrypt implementations. They use the trailing // NULL in the key string during expansion. - ckey := append(key, 0) + // We copy the key to prevent changing the underlying array. 
+ ckey := append(key[:len(key):len(key)], 0) c, err := blowfish.NewSaltedCipher(ckey, csalt) if err != nil { @@ -240,11 +241,11 @@ func (p *hashed) Hash() []byte { n = 3 } arr[n] = '$' - n += 1 + n++ copy(arr[n:], []byte(fmt.Sprintf("%02d", p.cost))) n += 2 arr[n] = '$' - n += 1 + n++ copy(arr[n:], p.salt) n += encodedSaltSize copy(arr[n:], p.hash) diff --git a/vendor/golang.org/x/crypto/blowfish/cipher.go b/vendor/golang.org/x/crypto/blowfish/cipher.go index 542984aa8d..213bf204af 100644 --- a/vendor/golang.org/x/crypto/blowfish/cipher.go +++ b/vendor/golang.org/x/crypto/blowfish/cipher.go @@ -3,10 +3,18 @@ // license that can be found in the LICENSE file. // Package blowfish implements Bruce Schneier's Blowfish encryption algorithm. +// +// Blowfish is a legacy cipher and its short block size makes it vulnerable to +// birthday bound attacks (see https://sweet32.info). It should only be used +// where compatibility with legacy systems, not security, is the goal. +// +// Deprecated: any new system should use AES (from crypto/aes, if necessary in +// an AEAD mode like crypto/cipher.NewGCM) or XChaCha20-Poly1305 (from +// golang.org/x/crypto/chacha20poly1305). package blowfish // import "golang.org/x/crypto/blowfish" // The code is a port of Bruce Schneier's C implementation. -// See http://www.schneier.com/blowfish.html. +// See https://www.schneier.com/blowfish.html. import "strconv" @@ -39,7 +47,7 @@ func NewCipher(key []byte) (*Cipher, error) { // NewSaltedCipher creates a returns a Cipher that folds a salt into its key // schedule. For most purposes, NewCipher, instead of NewSaltedCipher, is -// sufficient and desirable. For bcrypt compatiblity, the key can be over 56 +// sufficient and desirable. For bcrypt compatibility, the key can be over 56 // bytes. func NewSaltedCipher(key, salt []byte) (*Cipher, error) { if len(salt) == 0 { diff --git a/vendor/golang.org/x/crypto/blowfish/const.go b/vendor/golang.org/x/crypto/blowfish/const.go index 8c5ee4cb08..d04077595a 100644 --- a/vendor/golang.org/x/crypto/blowfish/const.go +++ b/vendor/golang.org/x/crypto/blowfish/const.go @@ -4,7 +4,7 @@ // The startup permutation array and substitution boxes. // They are the hexadecimal digits of PI; see: -// http://www.schneier.com/code/constants.txt. +// https://www.schneier.com/code/constants.txt. package blowfish diff --git a/vendor/golang.org/x/crypto/go.mod b/vendor/golang.org/x/crypto/go.mod new file mode 100644 index 0000000000..6a004e45c6 --- /dev/null +++ b/vendor/golang.org/x/crypto/go.mod @@ -0,0 +1,8 @@ +module golang.org/x/crypto + +go 1.11 + +require ( + golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 + golang.org/x/sys v0.0.0-20190412213103-97732733099d +) diff --git a/vendor/golang.org/x/crypto/ocsp/ocsp.go b/vendor/golang.org/x/crypto/ocsp/ocsp.go index 602fefa67e..9d3fffa8fe 100644 --- a/vendor/golang.org/x/crypto/ocsp/ocsp.go +++ b/vendor/golang.org/x/crypto/ocsp/ocsp.go @@ -13,29 +13,69 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/rsa" - "crypto/sha1" + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "errors" + "fmt" "math/big" + "strconv" "time" ) var idPKIXOCSPBasic = asn1.ObjectIdentifier([]int{1, 3, 6, 1, 5, 5, 7, 48, 1, 1}) -// These are internal structures that reflect the ASN.1 structure of an OCSP -// response. See RFC 2560, section 4.2. +// ResponseStatus contains the result of an OCSP request. 
See +// https://tools.ietf.org/html/rfc6960#section-2.3 +type ResponseStatus int const ( - ocspSuccess = 0 - ocspMalformed = 1 - ocspInternalError = 2 - ocspTryLater = 3 - ocspSigRequired = 4 - ocspUnauthorized = 5 + Success ResponseStatus = 0 + Malformed ResponseStatus = 1 + InternalError ResponseStatus = 2 + TryLater ResponseStatus = 3 + // Status code four is unused in OCSP. See + // https://tools.ietf.org/html/rfc6960#section-4.2.1 + SignatureRequired ResponseStatus = 5 + Unauthorized ResponseStatus = 6 ) +func (r ResponseStatus) String() string { + switch r { + case Success: + return "success" + case Malformed: + return "malformed" + case InternalError: + return "internal error" + case TryLater: + return "try later" + case SignatureRequired: + return "signature required" + case Unauthorized: + return "unauthorized" + default: + return "unknown OCSP status: " + strconv.Itoa(int(r)) + } +} + +// ResponseError is an error that may be returned by ParseResponse to indicate +// that the response itself is an error, not just that it's indicating that a +// certificate is revoked, unknown, etc. +type ResponseError struct { + Status ResponseStatus +} + +func (r ResponseError) Error() string { + return "ocsp: error from server: " + r.Status.String() +} + +// These are internal structures that reflect the ASN.1 structure of an OCSP +// response. See RFC 2560, section 4.2. + type certID struct { HashAlgorithm pkix.AlgorithmIdentifier NameHash []byte @@ -60,7 +100,7 @@ type request struct { type responseASN1 struct { Status asn1.Enumerated - Response responseBytes `asn1:"explicit,tag:0"` + Response responseBytes `asn1:"explicit,tag:0,optional"` } type responseBytes struct { @@ -76,26 +116,26 @@ type basicResponse struct { } type responseData struct { - Raw asn1.RawContent - Version int `asn1:"optional,default:1,explicit,tag:0"` - RawResponderName asn1.RawValue `asn1:"optional,explicit,tag:1"` - KeyHash []byte `asn1:"optional,explicit,tag:2"` - ProducedAt time.Time `asn1:"generalized"` - Responses []singleResponse + Raw asn1.RawContent + Version int `asn1:"optional,default:0,explicit,tag:0"` + RawResponderID asn1.RawValue + ProducedAt time.Time `asn1:"generalized"` + Responses []singleResponse } type singleResponse struct { - CertID certID - Good asn1.Flag `asn1:"tag:0,optional"` - Revoked revokedInfo `asn1:"explicit,tag:1,optional"` - Unknown asn1.Flag `asn1:"tag:2,optional"` - ThisUpdate time.Time `asn1:"generalized"` - NextUpdate time.Time `asn1:"generalized,explicit,tag:0,optional"` + CertID certID + Good asn1.Flag `asn1:"tag:0,optional"` + Revoked revokedInfo `asn1:"tag:1,optional"` + Unknown asn1.Flag `asn1:"tag:2,optional"` + ThisUpdate time.Time `asn1:"generalized"` + NextUpdate time.Time `asn1:"generalized,explicit,tag:0,optional"` + SingleExtensions []pkix.Extension `asn1:"explicit,tag:1,optional"` } type revokedInfo struct { - RevocationTime time.Time `asn1:"generalized"` - Reason int `asn1:"explicit,tag:0,optional"` + RevocationTime time.Time `asn1:"generalized"` + Reason asn1.Enumerated `asn1:"explicit,tag:0,optional"` } var ( @@ -106,7 +146,7 @@ var ( oidSignatureSHA384WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 12} oidSignatureSHA512WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 13} oidSignatureDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 3} - oidSignatureDSAWithSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 4, 3, 2} + oidSignatureDSAWithSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 3, 2} oidSignatureECDSAWithSHA1 = 
asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 1} oidSignatureECDSAWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 2} oidSignatureECDSAWithSHA384 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 3} @@ -228,20 +268,47 @@ func getHashAlgorithmFromOID(target asn1.ObjectIdentifier) crypto.Hash { return crypto.Hash(0) } +func getOIDFromHashAlgorithm(target crypto.Hash) asn1.ObjectIdentifier { + for hash, oid := range hashOIDs { + if hash == target { + return oid + } + } + return nil +} + // This is the exposed reflection of the internal OCSP structures. +// The status values that can be expressed in OCSP. See RFC 6960. const ( // Good means that the certificate is valid. Good = iota // Revoked means that the certificate has been deliberately revoked. - Revoked = iota + Revoked // Unknown means that the OCSP responder doesn't know about the certificate. - Unknown = iota - // ServerFailed means that the OCSP responder failed to process the request. - ServerFailed = iota + Unknown + // ServerFailed is unused and was never used (see + // https://go-review.googlesource.com/#/c/18944). ParseResponse will + // return a ResponseError when an error response is parsed. + ServerFailed +) + +// The enumerated reasons for revoking a certificate. See RFC 5280. +const ( + Unspecified = 0 + KeyCompromise = 1 + CACompromise = 2 + AffiliationChanged = 3 + Superseded = 4 + CessationOfOperation = 5 + CertificateHold = 6 + + RemoveFromCRL = 8 + PrivilegeWithdrawn = 9 + AACompromise = 10 ) -// Request represents an OCSP request. See RFC 2560. +// Request represents an OCSP request. See RFC 6960. type Request struct { HashAlgorithm crypto.Hash IssuerNameHash []byte @@ -249,9 +316,36 @@ type Request struct { SerialNumber *big.Int } -// Response represents an OCSP response. See RFC 2560. +// Marshal marshals the OCSP request to ASN.1 DER encoded form. +func (req *Request) Marshal() ([]byte, error) { + hashAlg := getOIDFromHashAlgorithm(req.HashAlgorithm) + if hashAlg == nil { + return nil, errors.New("Unknown hash algorithm") + } + return asn1.Marshal(ocspRequest{ + tbsRequest{ + Version: 0, + RequestList: []request{ + { + Cert: certID{ + pkix.AlgorithmIdentifier{ + Algorithm: hashAlg, + Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */}, + }, + req.IssuerNameHash, + req.IssuerKeyHash, + req.SerialNumber, + }, + }, + }, + }, + }) +} + +// Response represents an OCSP response containing a single SingleResponse. See +// RFC 6960. type Response struct { - // Status is one of {Good, Revoked, Unknown, ServerFailed} + // Status is one of {Good, Revoked, Unknown} Status int SerialNumber *big.Int ProducedAt, ThisUpdate, NextUpdate, RevokedAt time.Time @@ -262,6 +356,34 @@ type Response struct { TBSResponseData []byte Signature []byte SignatureAlgorithm x509.SignatureAlgorithm + + // IssuerHash is the hash used to compute the IssuerNameHash and IssuerKeyHash. + // Valid values are crypto.SHA1, crypto.SHA256, crypto.SHA384, and crypto.SHA512. + // If zero, the default is crypto.SHA1. + IssuerHash crypto.Hash + + // RawResponderName optionally contains the DER-encoded subject of the + // responder certificate. Exactly one of RawResponderName and + // ResponderKeyHash is set. + RawResponderName []byte + // ResponderKeyHash optionally contains the SHA-1 hash of the + // responder's public key. Exactly one of RawResponderName and + // ResponderKeyHash is set. + ResponderKeyHash []byte + + // Extensions contains raw X.509 extensions from the singleExtensions field + // of the OCSP response. 
When parsing certificates, this can be used to + // extract non-critical extensions that are not parsed by this package. When + // marshaling OCSP responses, the Extensions field is ignored, see + // ExtraExtensions. + Extensions []pkix.Extension + + // ExtraExtensions contains extensions to be copied, raw, into any marshaled + // OCSP response (in the singleExtensions field). Values override any + // extensions that would otherwise be produced based on the other fields. The + // ExtraExtensions field is not populated when parsing certificates, see + // Extensions. + ExtraExtensions []pkix.Extension } // These are pre-serialized error responses for the various non-success codes @@ -323,12 +445,31 @@ func ParseRequest(bytes []byte) (*Request, error) { }, nil } -// ParseResponse parses an OCSP response in DER form. It only supports -// responses for a single certificate. If the response contains a certificate -// then the signature over the response is checked. If issuer is not nil then -// it will be used to validate the signature or embedded certificate. Invalid -// signatures or parse failures will result in a ParseError. +// ParseResponse parses an OCSP response in DER form. The response must contain +// only one certificate status. To parse the status of a specific certificate +// from a response which may contain multiple statuses, use ParseResponseForCert +// instead. +// +// If the response contains an embedded certificate, then that certificate will +// be used to verify the response signature. If the response contains an +// embedded certificate and issuer is not nil, then issuer will be used to verify +// the signature on the embedded certificate. +// +// If the response does not contain an embedded certificate and issuer is not +// nil, then issuer will be used to verify the response signature. +// +// Invalid responses and parse failures will result in a ParseError. +// Error responses will result in a ResponseError. func ParseResponse(bytes []byte, issuer *x509.Certificate) (*Response, error) { + return ParseResponseForCert(bytes, nil, issuer) +} + +// ParseResponseForCert acts identically to ParseResponse, except it supports +// parsing responses that contain multiple statuses. If the response contains +// multiple statuses and cert is not nil, then ParseResponseForCert will return +// the first status which contains a matching serial, otherwise it will return an +// error. If cert is nil, then the first status in the response will be returned. 
+func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Response, error) { var resp responseASN1 rest, err := asn1.Unmarshal(bytes, &resp) if err != nil { @@ -338,10 +479,8 @@ func ParseResponse(bytes []byte, issuer *x509.Certificate) (*Response, error) { return nil, ParseError("trailing data in OCSP response") } - ret := new(Response) - if resp.Status != ocspSuccess { - ret.Status = ServerFailed - return ret, nil + if status := ResponseStatus(resp.Status); status != Success { + return nil, ResponseError{status} } if !resp.Response.ResponseType.Equal(idPKIXOCSPBasic) { @@ -353,59 +492,116 @@ func ParseResponse(bytes []byte, issuer *x509.Certificate) (*Response, error) { if err != nil { return nil, err } - - if len(basicResp.Certificates) > 1 { - return nil, ParseError("OCSP response contains bad number of certificates") + if len(rest) > 0 { + return nil, ParseError("trailing data in OCSP response") } - if len(basicResp.TBSResponseData.Responses) != 1 { + if n := len(basicResp.TBSResponseData.Responses); n == 0 || cert == nil && n > 1 { return nil, ParseError("OCSP response contains bad number of responses") } - ret.TBSResponseData = basicResp.TBSResponseData.Raw - ret.Signature = basicResp.Signature.RightAlign() - ret.SignatureAlgorithm = getSignatureAlgorithmFromOID(basicResp.SignatureAlgorithm.Algorithm) + var singleResp singleResponse + if cert == nil { + singleResp = basicResp.TBSResponseData.Responses[0] + } else { + match := false + for _, resp := range basicResp.TBSResponseData.Responses { + if cert.SerialNumber.Cmp(resp.CertID.SerialNumber) == 0 { + singleResp = resp + match = true + break + } + } + if !match { + return nil, ParseError("no response matching the supplied certificate") + } + } + + ret := &Response{ + TBSResponseData: basicResp.TBSResponseData.Raw, + Signature: basicResp.Signature.RightAlign(), + SignatureAlgorithm: getSignatureAlgorithmFromOID(basicResp.SignatureAlgorithm.Algorithm), + Extensions: singleResp.SingleExtensions, + SerialNumber: singleResp.CertID.SerialNumber, + ProducedAt: basicResp.TBSResponseData.ProducedAt, + ThisUpdate: singleResp.ThisUpdate, + NextUpdate: singleResp.NextUpdate, + } + + // Handle the ResponderID CHOICE tag. ResponderID can be flattened into + // TBSResponseData once https://go-review.googlesource.com/34503 has been + // released. + rawResponderID := basicResp.TBSResponseData.RawResponderID + switch rawResponderID.Tag { + case 1: // Name + var rdn pkix.RDNSequence + if rest, err := asn1.Unmarshal(rawResponderID.Bytes, &rdn); err != nil || len(rest) != 0 { + return nil, ParseError("invalid responder name") + } + ret.RawResponderName = rawResponderID.Bytes + case 2: // KeyHash + if rest, err := asn1.Unmarshal(rawResponderID.Bytes, &ret.ResponderKeyHash); err != nil || len(rest) != 0 { + return nil, ParseError("invalid responder key hash") + } + default: + return nil, ParseError("invalid responder id tag") + } if len(basicResp.Certificates) > 0 { + // Responders should only send a single certificate (if they + // send any) that connects the responder's certificate to the + // original issuer. We accept responses with multiple + // certificates due to a number responders sending them[1], but + // ignore all but the first. 
+ // + // [1] https://github.com/golang/go/issues/21527 ret.Certificate, err = x509.ParseCertificate(basicResp.Certificates[0].FullBytes) if err != nil { return nil, err } if err := ret.CheckSignatureFrom(ret.Certificate); err != nil { - return nil, ParseError("bad OCSP signature") + return nil, ParseError("bad signature on embedded certificate: " + err.Error()) } if issuer != nil { if err := issuer.CheckSignature(ret.Certificate.SignatureAlgorithm, ret.Certificate.RawTBSCertificate, ret.Certificate.Signature); err != nil { - return nil, ParseError("bad signature on embedded certificate") + return nil, ParseError("bad OCSP signature: " + err.Error()) } } } else if issuer != nil { if err := ret.CheckSignatureFrom(issuer); err != nil { - return nil, ParseError("bad OCSP signature") + return nil, ParseError("bad OCSP signature: " + err.Error()) } } - r := basicResp.TBSResponseData.Responses[0] + for _, ext := range singleResp.SingleExtensions { + if ext.Critical { + return nil, ParseError("unsupported critical extension") + } + } - ret.SerialNumber = r.CertID.SerialNumber + for h, oid := range hashOIDs { + if singleResp.CertID.HashAlgorithm.Algorithm.Equal(oid) { + ret.IssuerHash = h + break + } + } + if ret.IssuerHash == 0 { + return nil, ParseError("unsupported issuer hash algorithm") + } switch { - case bool(r.Good): + case bool(singleResp.Good): ret.Status = Good - case bool(r.Unknown): + case bool(singleResp.Unknown): ret.Status = Unknown default: ret.Status = Revoked - ret.RevokedAt = r.Revoked.RevocationTime - ret.RevocationReason = r.Revoked.Reason + ret.RevokedAt = singleResp.Revoked.RevocationTime + ret.RevocationReason = int(singleResp.Revoked.Reason) } - ret.ProducedAt = basicResp.TBSResponseData.ProducedAt - ret.ThisUpdate = r.ThisUpdate - ret.NextUpdate = r.NextUpdate - return ret, nil } @@ -432,8 +628,7 @@ func CreateRequest(cert, issuer *x509.Certificate, opts *RequestOptions) ([]byte // OCSP seems to be the only place where these raw hash identifiers are // used. I took the following from // http://msdn.microsoft.com/en-us/library/ff635603.aspx - var hashOID asn1.ObjectIdentifier - hashOID, ok := hashOIDs[hashFunc] + _, ok := hashOIDs[hashFunc] if !ok { return nil, x509.ErrUnsupportedAlgorithm } @@ -458,38 +653,28 @@ func CreateRequest(cert, issuer *x509.Certificate, opts *RequestOptions) ([]byte h.Write(issuer.RawSubject) issuerNameHash := h.Sum(nil) - return asn1.Marshal(ocspRequest{ - tbsRequest{ - Version: 0, - RequestList: []request{ - { - Cert: certID{ - pkix.AlgorithmIdentifier{ - Algorithm: hashOID, - Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */}, - }, - issuerNameHash, - issuerKeyHash, - cert.SerialNumber, - }, - }, - }, - }, - }) + req := &Request{ + HashAlgorithm: hashFunc, + IssuerNameHash: issuerNameHash, + IssuerKeyHash: issuerKeyHash, + SerialNumber: cert.SerialNumber, + } + return req.Marshal() } // CreateResponse returns a DER-encoded OCSP response with the specified contents. // The fields in the response are populated as follows: // -// The responder cert is used to populate the ResponderName field, and the certificate -// itself is provided alongside the OCSP response signature. +// The responder cert is used to populate the responder's name field, and the +// certificate itself is provided alongside the OCSP response signature. // // The issuer cert is used to puplate the IssuerNameHash and IssuerKeyHash fields. -// (SHA-1 is used for the hash function; this is not configurable.) 
// -// The template is used to populate the SerialNumber, RevocationStatus, RevokedAt, +// The template is used to populate the SerialNumber, Status, RevokedAt, // RevocationReason, ThisUpdate, and NextUpdate fields. // +// If template.IssuerHash is not set, SHA1 will be used. +// // The ProducedAt date is automatically set to the current date, to the nearest minute. func CreateResponse(issuer, responderCert *x509.Certificate, template Response, priv crypto.Signer) ([]byte, error) { var publicKeyInfo struct { @@ -500,7 +685,18 @@ func CreateResponse(issuer, responderCert *x509.Certificate, template Response, return nil, err } - h := sha1.New() + if template.IssuerHash == 0 { + template.IssuerHash = crypto.SHA1 + } + hashOID := getOIDFromHashAlgorithm(template.IssuerHash) + if hashOID == nil { + return nil, errors.New("unsupported issuer hash algorithm") + } + + if !template.IssuerHash.Available() { + return nil, fmt.Errorf("issuer hash algorithm %v not linked into binary", template.IssuerHash) + } + h := template.IssuerHash.New() h.Write(publicKeyInfo.PublicKey.RightAlign()) issuerKeyHash := h.Sum(nil) @@ -511,15 +707,16 @@ func CreateResponse(issuer, responderCert *x509.Certificate, template Response, innerResponse := singleResponse{ CertID: certID{ HashAlgorithm: pkix.AlgorithmIdentifier{ - Algorithm: hashOIDs[crypto.SHA1], + Algorithm: hashOID, Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */}, }, NameHash: issuerNameHash, IssuerKeyHash: issuerKeyHash, SerialNumber: template.SerialNumber, }, - ThisUpdate: template.ThisUpdate.UTC(), - NextUpdate: template.NextUpdate.UTC(), + ThisUpdate: template.ThisUpdate.UTC(), + NextUpdate: template.NextUpdate.UTC(), + SingleExtensions: template.ExtraExtensions, } switch template.Status { @@ -530,21 +727,21 @@ func CreateResponse(issuer, responderCert *x509.Certificate, template Response, case Revoked: innerResponse.Revoked = revokedInfo{ RevocationTime: template.RevokedAt.UTC(), - Reason: template.RevocationReason, + Reason: asn1.Enumerated(template.RevocationReason), } } - responderName := asn1.RawValue{ + rawResponderID := asn1.RawValue{ Class: 2, // context-specific - Tag: 1, // explicit tag + Tag: 1, // Name (explicit tag) IsCompound: true, Bytes: responderCert.RawSubject, } tbsResponseData := responseData{ - Version: 0, - RawResponderName: responderName, - ProducedAt: time.Now().Truncate(time.Minute).UTC(), - Responses: []singleResponse{innerResponse}, + Version: 0, + RawResponderID: rawResponderID, + ProducedAt: time.Now().Truncate(time.Minute).UTC(), + Responses: []singleResponse{innerResponse}, } tbsResponseDataDER, err := asn1.Marshal(tbsResponseData) @@ -574,7 +771,7 @@ func CreateResponse(issuer, responderCert *x509.Certificate, template Response, } if template.Certificate != nil { response.Certificates = []asn1.RawValue{ - asn1.RawValue{FullBytes: template.Certificate.Raw}, + {FullBytes: template.Certificate.Raw}, } } responseDER, err := asn1.Marshal(response) @@ -583,7 +780,7 @@ func CreateResponse(issuer, responderCert *x509.Certificate, template Response, } return asn1.Marshal(responseASN1{ - Status: ocspSuccess, + Status: asn1.Enumerated(Success), Response: responseBytes{ ResponseType: idPKIXOCSPBasic, Response: responseDER, diff --git a/vendor/golang.org/x/crypto/otr/libotr_test_helper.c b/vendor/golang.org/x/crypto/otr/libotr_test_helper.c index 6423eeda6d..b3ca072d48 100644 --- a/vendor/golang.org/x/crypto/otr/libotr_test_helper.c +++ b/vendor/golang.org/x/crypto/otr/libotr_test_helper.c @@ -13,6 +13,7 @@ #include #include 
+#include static int g_session_established = 0; @@ -20,92 +21,109 @@ OtrlPolicy policy(void *opdata, ConnContext *context) { return OTRL_POLICY_ALWAYS; } -int is_logged_in(void *opdata, const char *accountname, const char *protocol, const char *recipient) { +int is_logged_in(void *opdata, const char *accountname, const char *protocol, + const char *recipient) { return 1; } -void inject_message(void *opdata, const char *accountname, const char *protocol, const char *recipient, const char *message) { +void inject_message(void *opdata, const char *accountname, const char *protocol, + const char *recipient, const char *message) { printf("%s\n", message); fflush(stdout); fprintf(stderr, "libotr helper sent: %s\n", message); } -void notify(void *opdata, OtrlNotifyLevel level, const char *accountname, const char *protocol, const char *username, const char *title, const char *primary, const char *secondary) { - fprintf(stderr, "NOTIFY: %s %s %s %s\n", username, title, primary, secondary); -} +void update_context_list(void *opdata) {} -int display_otr_message(void *opdata, const char *accountname, const char *protocol, const char *username, const char *msg) { - fprintf(stderr, "MESSAGE: %s %s\n", username, msg); - return 1; +void new_fingerprint(void *opdata, OtrlUserState us, const char *accountname, + const char *protocol, const char *username, + unsigned char fingerprint[20]) { + fprintf(stderr, "NEW FINGERPRINT\n"); + g_session_established = 1; } -void update_context_list(void *opdata) { -} +void write_fingerprints(void *opdata) {} -const char *protocol_name(void *opdata, const char *protocol) { - return "PROTOCOL"; -} +void gone_secure(void *opdata, ConnContext *context) {} -void protocol_name_free(void *opdata, const char *protocol_name) { -} +void gone_insecure(void *opdata, ConnContext *context) {} -void new_fingerprint(void *opdata, OtrlUserState us, const char *accountname, const char *protocol, const char *username, unsigned char fingerprint[20]) { - fprintf(stderr, "NEW FINGERPRINT\n"); - g_session_established = 1; -} +void still_secure(void *opdata, ConnContext *context, int is_reply) {} -void write_fingerprints(void *opdata) { -} +int max_message_size(void *opdata, ConnContext *context) { return 99999; } -void gone_secure(void *opdata, ConnContext *context) { +const char *account_name(void *opdata, const char *account, + const char *protocol) { + return "ACCOUNT"; } -void gone_insecure(void *opdata, ConnContext *context) { -} +void account_name_free(void *opdata, const char *account_name) {} -void still_secure(void *opdata, ConnContext *context, int is_reply) { +const char *error_message(void *opdata, ConnContext *context, + OtrlErrorCode err_code) { + return "ERR"; } -void log_message(void *opdata, const char *message) { - fprintf(stderr, "MESSAGE: %s\n", message); -} +void error_message_free(void *opdata, const char *msg) {} -int max_message_size(void *opdata, ConnContext *context) { - return 99999; -} +void resent_msg_prefix_free(void *opdata, const char *prefix) {} -const char *account_name(void *opdata, const char *account, const char *protocol) { - return "ACCOUNT"; -} +void handle_smp_event(void *opdata, OtrlSMPEvent smp_event, + ConnContext *context, unsigned short progress_event, + char *question) {} -void account_name_free(void *opdata, const char *account_name) { +void handle_msg_event(void *opdata, OtrlMessageEvent msg_event, + ConnContext *context, const char *message, + gcry_error_t err) { + fprintf(stderr, "msg event: %d %s\n", msg_event, message); } OtrlMessageAppOps 
uiops = { - policy, - NULL, - is_logged_in, - inject_message, - notify, - display_otr_message, - update_context_list, - protocol_name, - protocol_name_free, - new_fingerprint, - write_fingerprints, - gone_secure, - gone_insecure, - still_secure, - log_message, - max_message_size, - account_name, - account_name_free, + policy, + NULL, + is_logged_in, + inject_message, + update_context_list, + new_fingerprint, + write_fingerprints, + gone_secure, + gone_insecure, + still_secure, + max_message_size, + account_name, + account_name_free, + NULL, /* received_symkey */ + error_message, + error_message_free, + NULL, /* resent_msg_prefix */ + resent_msg_prefix_free, + handle_smp_event, + handle_msg_event, + NULL /* create_instag */, + NULL /* convert_msg */, + NULL /* convert_free */, + NULL /* timer_control */, }; -static const char kPrivateKeyData[] = "(privkeys (account (name \"account\") (protocol proto) (private-key (dsa (p #00FC07ABCF0DC916AFF6E9AE47BEF60C7AB9B4D6B2469E436630E36F8A489BE812486A09F30B71224508654940A835301ACC525A4FF133FC152CC53DCC59D65C30A54F1993FE13FE63E5823D4C746DB21B90F9B9C00B49EC7404AB1D929BA7FBA12F2E45C6E0A651689750E8528AB8C031D3561FECEE72EBB4A090D450A9B7A857#) (q #00997BD266EF7B1F60A5C23F3A741F2AEFD07A2081#) (g #535E360E8A95EBA46A4F7DE50AD6E9B2A6DB785A66B64EB9F20338D2A3E8FB0E94725848F1AA6CC567CB83A1CC517EC806F2E92EAE71457E80B2210A189B91250779434B41FC8A8873F6DB94BEA7D177F5D59E7E114EE10A49CFD9CEF88AE43387023B672927BA74B04EB6BBB5E57597766A2F9CE3857D7ACE3E1E3BC1FC6F26#) (y #0AC8670AD767D7A8D9D14CC1AC6744CD7D76F993B77FFD9E39DF01E5A6536EF65E775FCEF2A983E2A19BD6415500F6979715D9FD1257E1FE2B6F5E1E74B333079E7C880D39868462A93454B41877BE62E5EF0A041C2EE9C9E76BD1E12AE25D9628DECB097025DD625EF49C3258A1A3C0FF501E3DC673B76D7BABF349009B6ECF#) (x #14D0345A3562C480A039E3C72764F72D79043216#)))))\n"; - -int -main() { +static const char kPrivateKeyData[] = + "(privkeys (account (name \"account\") (protocol proto) (private-key (dsa " + "(p " + "#00FC07ABCF0DC916AFF6E9AE47BEF60C7AB9B4D6B2469E436630E36F8A489BE812486A09F" + "30B71224508654940A835301ACC525A4FF133FC152CC53DCC59D65C30A54F1993FE13FE63E" + "5823D4C746DB21B90F9B9C00B49EC7404AB1D929BA7FBA12F2E45C6E0A651689750E8528AB" + "8C031D3561FECEE72EBB4A090D450A9B7A857#) (q " + "#00997BD266EF7B1F60A5C23F3A741F2AEFD07A2081#) (g " + "#535E360E8A95EBA46A4F7DE50AD6E9B2A6DB785A66B64EB9F20338D2A3E8FB0E94725848F" + "1AA6CC567CB83A1CC517EC806F2E92EAE71457E80B2210A189B91250779434B41FC8A8873F" + "6DB94BEA7D177F5D59E7E114EE10A49CFD9CEF88AE43387023B672927BA74B04EB6BBB5E57" + "597766A2F9CE3857D7ACE3E1E3BC1FC6F26#) (y " + "#0AC8670AD767D7A8D9D14CC1AC6744CD7D76F993B77FFD9E39DF01E5A6536EF65E775FCEF" + "2A983E2A19BD6415500F6979715D9FD1257E1FE2B6F5E1E74B333079E7C880D39868462A93" + "454B41877BE62E5EF0A041C2EE9C9E76BD1E12AE25D9628DECB097025DD625EF49C3258A1A" + "3C0FF501E3DC673B76D7BABF349009B6ECF#) (x " + "#14D0345A3562C480A039E3C72764F72D79043216#)))))\n"; + +int main() { OTRL_INIT; // We have to write the private key information to a file because the libotr @@ -116,12 +134,13 @@ main() { } char private_key_file[256]; - snprintf(private_key_file, sizeof(private_key_file), "%s/libotr_test_helper_privatekeys-XXXXXX", tmpdir); + snprintf(private_key_file, sizeof(private_key_file), + "%s/libotr_test_helper_privatekeys-XXXXXX", tmpdir); int fd = mkstemp(private_key_file); if (fd == -1) { perror("creating temp file"); } - write(fd, kPrivateKeyData, sizeof(kPrivateKeyData)-1); + write(fd, kPrivateKeyData, sizeof(kPrivateKeyData) - 1); close(fd); OtrlUserState userstate = 
otrl_userstate_create(); @@ -133,7 +152,7 @@ main() { char buf[4096]; for (;;) { - char* message = fgets(buf, sizeof(buf), stdin); + char *message = fgets(buf, sizeof(buf), stdin); if (strlen(message) == 0) { break; } @@ -142,9 +161,11 @@ main() { char *newmessage = NULL; OtrlTLV *tlvs; - int ignore_message = otrl_message_receiving(userstate, &uiops, NULL, "account", "proto", "peer", message, &newmessage, &tlvs, NULL, NULL); + int ignore_message = otrl_message_receiving( + userstate, &uiops, NULL, "account", "proto", "peer", message, + &newmessage, &tlvs, NULL, NULL, NULL); if (tlvs) { - otrl_tlv_free(tlvs); + otrl_tlv_free(tlvs); } if (newmessage != NULL) { @@ -154,16 +175,21 @@ main() { gcry_error_t err; char *newmessage = NULL; - err = otrl_message_sending(userstate, &uiops, NULL, "account", "proto", "peer", "test message", NULL, &newmessage, NULL, NULL); + err = otrl_message_sending(userstate, &uiops, NULL, "account", "proto", + "peer", 0, "test message", NULL, &newmessage, + OTRL_FRAGMENT_SEND_SKIP, NULL, NULL, NULL); if (newmessage == NULL) { fprintf(stderr, "libotr didn't encrypt message\n"); return 1; } write(1, newmessage, strlen(newmessage)); write(1, "\n", 1); - g_session_established = 0; + fprintf(stderr, "libotr sent: %s\n", newmessage); otrl_message_free(newmessage); + + g_session_established = 0; write(1, "?OTRv2?\n", 8); + fprintf(stderr, "libotr sent: ?OTRv2\n"); } } diff --git a/vendor/golang.org/x/crypto/otr/otr.go b/vendor/golang.org/x/crypto/otr/otr.go index 07ac080e59..29121e9bb6 100644 --- a/vendor/golang.org/x/crypto/otr/otr.go +++ b/vendor/golang.org/x/crypto/otr/otr.go @@ -4,6 +4,10 @@ // Package otr implements the Off The Record protocol as specified in // http://www.cypherpunks.ca/otr/Protocol-v2-3.1.0.html +// +// The version of OTR implemented by this package has been deprecated +// (https://bugs.otr.im/lib/libotr/issues/140). An implementation of OTRv3 is +// available at https://github.com/coyim/otr3. 
package otr // import "golang.org/x/crypto/otr" import ( @@ -277,7 +281,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se in = in[len(msgPrefix) : len(in)-1] } else if version := isQuery(in); version > 0 { c.authState = authStateAwaitingDHKey - c.myKeyId = 0 + c.reset() toSend = c.encode(c.generateDHCommit()) return } else { @@ -311,7 +315,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se if err = c.processDHCommit(msg); err != nil { return } - c.myKeyId = 0 + c.reset() toSend = c.encode(c.generateDHKey()) return case authStateAwaitingDHKey: @@ -330,7 +334,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se if err = c.processDHCommit(msg); err != nil { return } - c.myKeyId = 0 + c.reset() toSend = c.encode(c.generateDHKey()) return } @@ -343,7 +347,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se if err = c.processDHCommit(msg); err != nil { return } - c.myKeyId = 0 + c.reset() toSend = c.encode(c.generateDHKey()) c.authState = authStateAwaitingRevealSig default: @@ -417,12 +421,11 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se change = SMPSecretNeeded c.smp.saved = &inTLV return - } else if err == smpFailureError { + } + if err == smpFailureError { err = nil change = SMPFailed - return - } - if complete { + } else if complete { change = SMPComplete } if reply.typ != 0 { @@ -638,7 +641,7 @@ func (c *Conversation) serializeDHKey() []byte { } func (c *Conversation) processDHKey(in []byte) (isSame bool, err error) { - gy, in, ok := getMPI(in) + gy, _, ok := getMPI(in) if !ok { err = errors.New("otr: corrupt DH key message") return @@ -848,7 +851,6 @@ func (c *Conversation) rotateDHKeys() { slot := &c.keySlots[i] if slot.used && slot.myKeyId == c.myKeyId-1 { slot.used = false - c.oldMACs = append(c.oldMACs, slot.sendMACKey...) c.oldMACs = append(c.oldMACs, slot.recvMACKey...) } } @@ -924,7 +926,6 @@ func (c *Conversation) processData(in []byte) (out []byte, tlvs []tlv, err error slot := &c.keySlots[i] if slot.used && slot.theirKeyId == theirKeyId-1 { slot.used = false - c.oldMACs = append(c.oldMACs, slot.sendMACKey...) c.oldMACs = append(c.oldMACs, slot.recvMACKey...) 
} } @@ -946,6 +947,7 @@ func (c *Conversation) processData(in []byte) (out []byte, tlvs []tlv, err error t.data, tlvData, ok3 = getNBytes(tlvData, int(t.length)) if !ok1 || !ok2 || !ok3 { err = errors.New("otr: corrupt tlv data") + return } tlvs = append(tlvs, t) } @@ -1039,8 +1041,7 @@ func (c *Conversation) calcDataKeys(myKeyId, theirKeyId uint32) (slot *keySlot, } } if slot == nil { - err = errors.New("otr: internal error: no key slots") - return + return nil, errors.New("otr: internal error: no more key slots") } var myPriv, myPub, theirPub *big.Int @@ -1096,6 +1097,10 @@ func (c *Conversation) calcDataKeys(myKeyId, theirKeyId uint32) (slot *keySlot, h.Write(slot.recvAESKey) slot.recvMACKey = h.Sum(slot.recvMACKey[:0]) + slot.theirKeyId = theirKeyId + slot.myKeyId = myKeyId + slot.used = true + zero(slot.theirLastCtr[:]) return } @@ -1162,6 +1167,14 @@ func (c *Conversation) encode(msg []byte) [][]byte { return ret } +func (c *Conversation) reset() { + c.myKeyId = 0 + + for i := range c.keySlots { + c.keySlots[i].used = false + } +} + type PublicKey struct { dsa.PublicKey } @@ -1305,6 +1318,12 @@ func (priv *PrivateKey) Import(in []byte) bool { mpis[i] = new(big.Int).SetBytes(mpiBytes) } + for _, mpi := range mpis { + if mpi.Sign() <= 0 { + return false + } + } + priv.PrivateKey.P = mpis[0] priv.PrivateKey.Q = mpis[1] priv.PrivateKey.G = mpis[2] diff --git a/vendor/golang.org/x/crypto/pkcs12/bmp-string.go b/vendor/golang.org/x/crypto/pkcs12/bmp-string.go new file mode 100644 index 0000000000..233b8b62cc --- /dev/null +++ b/vendor/golang.org/x/crypto/pkcs12/bmp-string.go @@ -0,0 +1,50 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pkcs12 + +import ( + "errors" + "unicode/utf16" +) + +// bmpString returns s encoded in UCS-2 with a zero terminator. +func bmpString(s string) ([]byte, error) { + // References: + // https://tools.ietf.org/html/rfc7292#appendix-B.1 + // https://en.wikipedia.org/wiki/Plane_(Unicode)#Basic_Multilingual_Plane + // - non-BMP characters are encoded in UTF 16 by using a surrogate pair of 16-bit codes + // EncodeRune returns 0xfffd if the rune does not need special encoding + // - the above RFC provides the info that BMPStrings are NULL terminated. + + ret := make([]byte, 0, 2*len(s)+2) + + for _, r := range s { + if t, _ := utf16.EncodeRune(r); t != 0xfffd { + return nil, errors.New("pkcs12: string contains characters that cannot be encoded in UCS-2") + } + ret = append(ret, byte(r/256), byte(r%256)) + } + + return append(ret, 0, 0), nil +} + +func decodeBMPString(bmpString []byte) (string, error) { + if len(bmpString)%2 != 0 { + return "", errors.New("pkcs12: odd-length BMP string") + } + + // strip terminator if present + if l := len(bmpString); l >= 2 && bmpString[l-1] == 0 && bmpString[l-2] == 0 { + bmpString = bmpString[:l-2] + } + + s := make([]uint16, 0, len(bmpString)/2) + for len(bmpString) > 0 { + s = append(s, uint16(bmpString[0])<<8+uint16(bmpString[1])) + bmpString = bmpString[2:] + } + + return string(utf16.Decode(s)), nil +} diff --git a/vendor/golang.org/x/crypto/pkcs12/crypto.go b/vendor/golang.org/x/crypto/pkcs12/crypto.go new file mode 100644 index 0000000000..484ca51b71 --- /dev/null +++ b/vendor/golang.org/x/crypto/pkcs12/crypto.go @@ -0,0 +1,131 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package pkcs12 + +import ( + "bytes" + "crypto/cipher" + "crypto/des" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + + "golang.org/x/crypto/pkcs12/internal/rc2" +) + +var ( + oidPBEWithSHAAnd3KeyTripleDESCBC = asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 12, 1, 3}) + oidPBEWithSHAAnd40BitRC2CBC = asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 12, 1, 6}) +) + +// pbeCipher is an abstraction of a PKCS#12 cipher. +type pbeCipher interface { + // create returns a cipher.Block given a key. + create(key []byte) (cipher.Block, error) + // deriveKey returns a key derived from the given password and salt. + deriveKey(salt, password []byte, iterations int) []byte + // deriveKey returns an IV derived from the given password and salt. + deriveIV(salt, password []byte, iterations int) []byte +} + +type shaWithTripleDESCBC struct{} + +func (shaWithTripleDESCBC) create(key []byte) (cipher.Block, error) { + return des.NewTripleDESCipher(key) +} + +func (shaWithTripleDESCBC) deriveKey(salt, password []byte, iterations int) []byte { + return pbkdf(sha1Sum, 20, 64, salt, password, iterations, 1, 24) +} + +func (shaWithTripleDESCBC) deriveIV(salt, password []byte, iterations int) []byte { + return pbkdf(sha1Sum, 20, 64, salt, password, iterations, 2, 8) +} + +type shaWith40BitRC2CBC struct{} + +func (shaWith40BitRC2CBC) create(key []byte) (cipher.Block, error) { + return rc2.New(key, len(key)*8) +} + +func (shaWith40BitRC2CBC) deriveKey(salt, password []byte, iterations int) []byte { + return pbkdf(sha1Sum, 20, 64, salt, password, iterations, 1, 5) +} + +func (shaWith40BitRC2CBC) deriveIV(salt, password []byte, iterations int) []byte { + return pbkdf(sha1Sum, 20, 64, salt, password, iterations, 2, 8) +} + +type pbeParams struct { + Salt []byte + Iterations int +} + +func pbDecrypterFor(algorithm pkix.AlgorithmIdentifier, password []byte) (cipher.BlockMode, int, error) { + var cipherType pbeCipher + + switch { + case algorithm.Algorithm.Equal(oidPBEWithSHAAnd3KeyTripleDESCBC): + cipherType = shaWithTripleDESCBC{} + case algorithm.Algorithm.Equal(oidPBEWithSHAAnd40BitRC2CBC): + cipherType = shaWith40BitRC2CBC{} + default: + return nil, 0, NotImplementedError("algorithm " + algorithm.Algorithm.String() + " is not supported") + } + + var params pbeParams + if err := unmarshal(algorithm.Parameters.FullBytes, ¶ms); err != nil { + return nil, 0, err + } + + key := cipherType.deriveKey(params.Salt, password, params.Iterations) + iv := cipherType.deriveIV(params.Salt, password, params.Iterations) + + block, err := cipherType.create(key) + if err != nil { + return nil, 0, err + } + + return cipher.NewCBCDecrypter(block, iv), block.BlockSize(), nil +} + +func pbDecrypt(info decryptable, password []byte) (decrypted []byte, err error) { + cbc, blockSize, err := pbDecrypterFor(info.Algorithm(), password) + if err != nil { + return nil, err + } + + encrypted := info.Data() + if len(encrypted) == 0 { + return nil, errors.New("pkcs12: empty encrypted data") + } + if len(encrypted)%blockSize != 0 { + return nil, errors.New("pkcs12: input is not a multiple of the block size") + } + decrypted = make([]byte, len(encrypted)) + cbc.CryptBlocks(decrypted, encrypted) + + psLen := int(decrypted[len(decrypted)-1]) + if psLen == 0 || psLen > blockSize { + return nil, ErrDecryption + } + + if len(decrypted) < psLen { + return nil, ErrDecryption + } + ps := decrypted[len(decrypted)-psLen:] + decrypted = decrypted[:len(decrypted)-psLen] + if bytes.Compare(ps, bytes.Repeat([]byte{byte(psLen)}, psLen)) != 0 { + return nil, 
ErrDecryption + } + + return +} + +// decryptable abstracts an object that contains ciphertext. +type decryptable interface { + Algorithm() pkix.AlgorithmIdentifier + Data() []byte +} diff --git a/vendor/golang.org/x/crypto/pkcs12/errors.go b/vendor/golang.org/x/crypto/pkcs12/errors.go new file mode 100644 index 0000000000..7377ce6fb2 --- /dev/null +++ b/vendor/golang.org/x/crypto/pkcs12/errors.go @@ -0,0 +1,23 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pkcs12 + +import "errors" + +var ( + // ErrDecryption represents a failure to decrypt the input. + ErrDecryption = errors.New("pkcs12: decryption error, incorrect padding") + + // ErrIncorrectPassword is returned when an incorrect password is detected. + // Usually, P12/PFX data is signed to be able to verify the password. + ErrIncorrectPassword = errors.New("pkcs12: decryption password incorrect") +) + +// NotImplementedError indicates that the input is not currently supported. +type NotImplementedError string + +func (e NotImplementedError) Error() string { + return "pkcs12: " + string(e) +} diff --git a/vendor/golang.org/x/crypto/pkcs12/internal/rc2/rc2.go b/vendor/golang.org/x/crypto/pkcs12/internal/rc2/rc2.go new file mode 100644 index 0000000000..7499e3fb69 --- /dev/null +++ b/vendor/golang.org/x/crypto/pkcs12/internal/rc2/rc2.go @@ -0,0 +1,271 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package rc2 implements the RC2 cipher +/* +https://www.ietf.org/rfc/rfc2268.txt +http://people.csail.mit.edu/rivest/pubs/KRRR98.pdf + +This code is licensed under the MIT license. 
+*/ +package rc2 + +import ( + "crypto/cipher" + "encoding/binary" +) + +// The rc2 block size in bytes +const BlockSize = 8 + +type rc2Cipher struct { + k [64]uint16 +} + +// New returns a new rc2 cipher with the given key and effective key length t1 +func New(key []byte, t1 int) (cipher.Block, error) { + // TODO(dgryski): error checking for key length + return &rc2Cipher{ + k: expandKey(key, t1), + }, nil +} + +func (*rc2Cipher) BlockSize() int { return BlockSize } + +var piTable = [256]byte{ + 0xd9, 0x78, 0xf9, 0xc4, 0x19, 0xdd, 0xb5, 0xed, 0x28, 0xe9, 0xfd, 0x79, 0x4a, 0xa0, 0xd8, 0x9d, + 0xc6, 0x7e, 0x37, 0x83, 0x2b, 0x76, 0x53, 0x8e, 0x62, 0x4c, 0x64, 0x88, 0x44, 0x8b, 0xfb, 0xa2, + 0x17, 0x9a, 0x59, 0xf5, 0x87, 0xb3, 0x4f, 0x13, 0x61, 0x45, 0x6d, 0x8d, 0x09, 0x81, 0x7d, 0x32, + 0xbd, 0x8f, 0x40, 0xeb, 0x86, 0xb7, 0x7b, 0x0b, 0xf0, 0x95, 0x21, 0x22, 0x5c, 0x6b, 0x4e, 0x82, + 0x54, 0xd6, 0x65, 0x93, 0xce, 0x60, 0xb2, 0x1c, 0x73, 0x56, 0xc0, 0x14, 0xa7, 0x8c, 0xf1, 0xdc, + 0x12, 0x75, 0xca, 0x1f, 0x3b, 0xbe, 0xe4, 0xd1, 0x42, 0x3d, 0xd4, 0x30, 0xa3, 0x3c, 0xb6, 0x26, + 0x6f, 0xbf, 0x0e, 0xda, 0x46, 0x69, 0x07, 0x57, 0x27, 0xf2, 0x1d, 0x9b, 0xbc, 0x94, 0x43, 0x03, + 0xf8, 0x11, 0xc7, 0xf6, 0x90, 0xef, 0x3e, 0xe7, 0x06, 0xc3, 0xd5, 0x2f, 0xc8, 0x66, 0x1e, 0xd7, + 0x08, 0xe8, 0xea, 0xde, 0x80, 0x52, 0xee, 0xf7, 0x84, 0xaa, 0x72, 0xac, 0x35, 0x4d, 0x6a, 0x2a, + 0x96, 0x1a, 0xd2, 0x71, 0x5a, 0x15, 0x49, 0x74, 0x4b, 0x9f, 0xd0, 0x5e, 0x04, 0x18, 0xa4, 0xec, + 0xc2, 0xe0, 0x41, 0x6e, 0x0f, 0x51, 0xcb, 0xcc, 0x24, 0x91, 0xaf, 0x50, 0xa1, 0xf4, 0x70, 0x39, + 0x99, 0x7c, 0x3a, 0x85, 0x23, 0xb8, 0xb4, 0x7a, 0xfc, 0x02, 0x36, 0x5b, 0x25, 0x55, 0x97, 0x31, + 0x2d, 0x5d, 0xfa, 0x98, 0xe3, 0x8a, 0x92, 0xae, 0x05, 0xdf, 0x29, 0x10, 0x67, 0x6c, 0xba, 0xc9, + 0xd3, 0x00, 0xe6, 0xcf, 0xe1, 0x9e, 0xa8, 0x2c, 0x63, 0x16, 0x01, 0x3f, 0x58, 0xe2, 0x89, 0xa9, + 0x0d, 0x38, 0x34, 0x1b, 0xab, 0x33, 0xff, 0xb0, 0xbb, 0x48, 0x0c, 0x5f, 0xb9, 0xb1, 0xcd, 0x2e, + 0xc5, 0xf3, 0xdb, 0x47, 0xe5, 0xa5, 0x9c, 0x77, 0x0a, 0xa6, 0x20, 0x68, 0xfe, 0x7f, 0xc1, 0xad, +} + +func expandKey(key []byte, t1 int) [64]uint16 { + + l := make([]byte, 128) + copy(l, key) + + var t = len(key) + var t8 = (t1 + 7) / 8 + var tm = byte(255 % uint(1<<(8+uint(t1)-8*uint(t8)))) + + for i := len(key); i < 128; i++ { + l[i] = piTable[l[i-1]+l[uint8(i-t)]] + } + + l[128-t8] = piTable[l[128-t8]&tm] + + for i := 127 - t8; i >= 0; i-- { + l[i] = piTable[l[i+1]^l[i+t8]] + } + + var k [64]uint16 + + for i := range k { + k[i] = uint16(l[2*i]) + uint16(l[2*i+1])*256 + } + + return k +} + +func rotl16(x uint16, b uint) uint16 { + return (x >> (16 - b)) | (x << b) +} + +func (c *rc2Cipher) Encrypt(dst, src []byte) { + + r0 := binary.LittleEndian.Uint16(src[0:]) + r1 := binary.LittleEndian.Uint16(src[2:]) + r2 := binary.LittleEndian.Uint16(src[4:]) + r3 := binary.LittleEndian.Uint16(src[6:]) + + var j int + + for j <= 16 { + // mix r0 + r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1) + r0 = rotl16(r0, 1) + j++ + + // mix r1 + r1 = r1 + c.k[j] + (r0 & r3) + ((^r0) & r2) + r1 = rotl16(r1, 2) + j++ + + // mix r2 + r2 = r2 + c.k[j] + (r1 & r0) + ((^r1) & r3) + r2 = rotl16(r2, 3) + j++ + + // mix r3 + r3 = r3 + c.k[j] + (r2 & r1) + ((^r2) & r0) + r3 = rotl16(r3, 5) + j++ + + } + + r0 = r0 + c.k[r3&63] + r1 = r1 + c.k[r0&63] + r2 = r2 + c.k[r1&63] + r3 = r3 + c.k[r2&63] + + for j <= 40 { + // mix r0 + r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1) + r0 = rotl16(r0, 1) + j++ + + // mix r1 + r1 = r1 + c.k[j] + (r0 & r3) + ((^r0) & r2) + r1 = rotl16(r1, 2) + j++ + + // 
mix r2 + r2 = r2 + c.k[j] + (r1 & r0) + ((^r1) & r3) + r2 = rotl16(r2, 3) + j++ + + // mix r3 + r3 = r3 + c.k[j] + (r2 & r1) + ((^r2) & r0) + r3 = rotl16(r3, 5) + j++ + + } + + r0 = r0 + c.k[r3&63] + r1 = r1 + c.k[r0&63] + r2 = r2 + c.k[r1&63] + r3 = r3 + c.k[r2&63] + + for j <= 60 { + // mix r0 + r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1) + r0 = rotl16(r0, 1) + j++ + + // mix r1 + r1 = r1 + c.k[j] + (r0 & r3) + ((^r0) & r2) + r1 = rotl16(r1, 2) + j++ + + // mix r2 + r2 = r2 + c.k[j] + (r1 & r0) + ((^r1) & r3) + r2 = rotl16(r2, 3) + j++ + + // mix r3 + r3 = r3 + c.k[j] + (r2 & r1) + ((^r2) & r0) + r3 = rotl16(r3, 5) + j++ + } + + binary.LittleEndian.PutUint16(dst[0:], r0) + binary.LittleEndian.PutUint16(dst[2:], r1) + binary.LittleEndian.PutUint16(dst[4:], r2) + binary.LittleEndian.PutUint16(dst[6:], r3) +} + +func (c *rc2Cipher) Decrypt(dst, src []byte) { + + r0 := binary.LittleEndian.Uint16(src[0:]) + r1 := binary.LittleEndian.Uint16(src[2:]) + r2 := binary.LittleEndian.Uint16(src[4:]) + r3 := binary.LittleEndian.Uint16(src[6:]) + + j := 63 + + for j >= 44 { + // unmix r3 + r3 = rotl16(r3, 16-5) + r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0) + j-- + + // unmix r2 + r2 = rotl16(r2, 16-3) + r2 = r2 - c.k[j] - (r1 & r0) - ((^r1) & r3) + j-- + + // unmix r1 + r1 = rotl16(r1, 16-2) + r1 = r1 - c.k[j] - (r0 & r3) - ((^r0) & r2) + j-- + + // unmix r0 + r0 = rotl16(r0, 16-1) + r0 = r0 - c.k[j] - (r3 & r2) - ((^r3) & r1) + j-- + } + + r3 = r3 - c.k[r2&63] + r2 = r2 - c.k[r1&63] + r1 = r1 - c.k[r0&63] + r0 = r0 - c.k[r3&63] + + for j >= 20 { + // unmix r3 + r3 = rotl16(r3, 16-5) + r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0) + j-- + + // unmix r2 + r2 = rotl16(r2, 16-3) + r2 = r2 - c.k[j] - (r1 & r0) - ((^r1) & r3) + j-- + + // unmix r1 + r1 = rotl16(r1, 16-2) + r1 = r1 - c.k[j] - (r0 & r3) - ((^r0) & r2) + j-- + + // unmix r0 + r0 = rotl16(r0, 16-1) + r0 = r0 - c.k[j] - (r3 & r2) - ((^r3) & r1) + j-- + + } + + r3 = r3 - c.k[r2&63] + r2 = r2 - c.k[r1&63] + r1 = r1 - c.k[r0&63] + r0 = r0 - c.k[r3&63] + + for j >= 0 { + // unmix r3 + r3 = rotl16(r3, 16-5) + r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0) + j-- + + // unmix r2 + r2 = rotl16(r2, 16-3) + r2 = r2 - c.k[j] - (r1 & r0) - ((^r1) & r3) + j-- + + // unmix r1 + r1 = rotl16(r1, 16-2) + r1 = r1 - c.k[j] - (r0 & r3) - ((^r0) & r2) + j-- + + // unmix r0 + r0 = rotl16(r0, 16-1) + r0 = r0 - c.k[j] - (r3 & r2) - ((^r3) & r1) + j-- + + } + + binary.LittleEndian.PutUint16(dst[0:], r0) + binary.LittleEndian.PutUint16(dst[2:], r1) + binary.LittleEndian.PutUint16(dst[4:], r2) + binary.LittleEndian.PutUint16(dst[6:], r3) +} diff --git a/vendor/golang.org/x/crypto/pkcs12/mac.go b/vendor/golang.org/x/crypto/pkcs12/mac.go new file mode 100644 index 0000000000..5f38aa7de8 --- /dev/null +++ b/vendor/golang.org/x/crypto/pkcs12/mac.go @@ -0,0 +1,45 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package pkcs12 + +import ( + "crypto/hmac" + "crypto/sha1" + "crypto/x509/pkix" + "encoding/asn1" +) + +type macData struct { + Mac digestInfo + MacSalt []byte + Iterations int `asn1:"optional,default:1"` +} + +// from PKCS#7: +type digestInfo struct { + Algorithm pkix.AlgorithmIdentifier + Digest []byte +} + +var ( + oidSHA1 = asn1.ObjectIdentifier([]int{1, 3, 14, 3, 2, 26}) +) + +func verifyMac(macData *macData, message, password []byte) error { + if !macData.Mac.Algorithm.Algorithm.Equal(oidSHA1) { + return NotImplementedError("unknown digest algorithm: " + macData.Mac.Algorithm.Algorithm.String()) + } + + key := pbkdf(sha1Sum, 20, 64, macData.MacSalt, password, macData.Iterations, 3, 20) + + mac := hmac.New(sha1.New, key) + mac.Write(message) + expectedMAC := mac.Sum(nil) + + if !hmac.Equal(macData.Mac.Digest, expectedMAC) { + return ErrIncorrectPassword + } + return nil +} diff --git a/vendor/golang.org/x/crypto/pkcs12/pbkdf.go b/vendor/golang.org/x/crypto/pkcs12/pbkdf.go new file mode 100644 index 0000000000..5c419d41e3 --- /dev/null +++ b/vendor/golang.org/x/crypto/pkcs12/pbkdf.go @@ -0,0 +1,170 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pkcs12 + +import ( + "bytes" + "crypto/sha1" + "math/big" +) + +var ( + one = big.NewInt(1) +) + +// sha1Sum returns the SHA-1 hash of in. +func sha1Sum(in []byte) []byte { + sum := sha1.Sum(in) + return sum[:] +} + +// fillWithRepeats returns v*ceiling(len(pattern) / v) bytes consisting of +// repeats of pattern. +func fillWithRepeats(pattern []byte, v int) []byte { + if len(pattern) == 0 { + return nil + } + outputLen := v * ((len(pattern) + v - 1) / v) + return bytes.Repeat(pattern, (outputLen+len(pattern)-1)/len(pattern))[:outputLen] +} + +func pbkdf(hash func([]byte) []byte, u, v int, salt, password []byte, r int, ID byte, size int) (key []byte) { + // implementation of https://tools.ietf.org/html/rfc7292#appendix-B.2 , RFC text verbatim in comments + + // Let H be a hash function built around a compression function f: + + // Z_2^u x Z_2^v -> Z_2^u + + // (that is, H has a chaining variable and output of length u bits, and + // the message input to the compression function of H is v bits). The + // values for u and v are as follows: + + // HASH FUNCTION VALUE u VALUE v + // MD2, MD5 128 512 + // SHA-1 160 512 + // SHA-224 224 512 + // SHA-256 256 512 + // SHA-384 384 1024 + // SHA-512 512 1024 + // SHA-512/224 224 1024 + // SHA-512/256 256 1024 + + // Furthermore, let r be the iteration count. + + // We assume here that u and v are both multiples of 8, as are the + // lengths of the password and salt strings (which we denote by p and s, + // respectively) and the number n of pseudorandom bits required. In + // addition, u and v are of course non-zero. + + // For information on security considerations for MD5 [19], see [25] and + // [1], and on those for MD2, see [18]. + + // The following procedure can be used to produce pseudorandom bits for + // a particular "purpose" that is identified by a byte called "ID". + // This standard specifies 3 different values for the ID byte: + + // 1. If ID=1, then the pseudorandom bits being produced are to be used + // as key material for performing encryption or decryption. + + // 2. If ID=2, then the pseudorandom bits being produced are to be used + // as an IV (Initial Value) for encryption or decryption. + + // 3. 
If ID=3, then the pseudorandom bits being produced are to be used + // as an integrity key for MACing. + + // 1. Construct a string, D (the "diversifier"), by concatenating v/8 + // copies of ID. + var D []byte + for i := 0; i < v; i++ { + D = append(D, ID) + } + + // 2. Concatenate copies of the salt together to create a string S of + // length v(ceiling(s/v)) bits (the final copy of the salt may be + // truncated to create S). Note that if the salt is the empty + // string, then so is S. + + S := fillWithRepeats(salt, v) + + // 3. Concatenate copies of the password together to create a string P + // of length v(ceiling(p/v)) bits (the final copy of the password + // may be truncated to create P). Note that if the password is the + // empty string, then so is P. + + P := fillWithRepeats(password, v) + + // 4. Set I=S||P to be the concatenation of S and P. + I := append(S, P...) + + // 5. Set c=ceiling(n/u). + c := (size + u - 1) / u + + // 6. For i=1, 2, ..., c, do the following: + A := make([]byte, c*20) + var IjBuf []byte + for i := 0; i < c; i++ { + // A. Set A2=H^r(D||I). (i.e., the r-th hash of D||1, + // H(H(H(... H(D||I)))) + Ai := hash(append(D, I...)) + for j := 1; j < r; j++ { + Ai = hash(Ai) + } + copy(A[i*20:], Ai[:]) + + if i < c-1 { // skip on last iteration + // B. Concatenate copies of Ai to create a string B of length v + // bits (the final copy of Ai may be truncated to create B). + var B []byte + for len(B) < v { + B = append(B, Ai[:]...) + } + B = B[:v] + + // C. Treating I as a concatenation I_0, I_1, ..., I_(k-1) of v-bit + // blocks, where k=ceiling(s/v)+ceiling(p/v), modify I by + // setting I_j=(I_j+B+1) mod 2^v for each j. + { + Bbi := new(big.Int).SetBytes(B) + Ij := new(big.Int) + + for j := 0; j < len(I)/v; j++ { + Ij.SetBytes(I[j*v : (j+1)*v]) + Ij.Add(Ij, Bbi) + Ij.Add(Ij, one) + Ijb := Ij.Bytes() + // We expect Ijb to be exactly v bytes, + // if it is longer or shorter we must + // adjust it accordingly. + if len(Ijb) > v { + Ijb = Ijb[len(Ijb)-v:] + } + if len(Ijb) < v { + if IjBuf == nil { + IjBuf = make([]byte, v) + } + bytesShort := v - len(Ijb) + for i := 0; i < bytesShort; i++ { + IjBuf[i] = 0 + } + copy(IjBuf[bytesShort:], Ijb) + Ijb = IjBuf + } + copy(I[j*v:(j+1)*v], Ijb) + } + } + } + } + // 7. Concatenate A_1, A_2, ..., A_c together to form a pseudorandom + // bit string, A. + + // 8. Use the first n bits of A as the output of this entire process. + return A[:size] + + // If the above process is being used to generate a DES key, the process + // should be used to create 64 random bits, and the key's parity bits + // should be set after the 64 bits have been produced. Similar concerns + // hold for 2-key and 3-key triple-DES keys, for CDMF keys, and for any + // similar keys with parity bits "built into them". +} diff --git a/vendor/golang.org/x/crypto/pkcs12/pkcs12.go b/vendor/golang.org/x/crypto/pkcs12/pkcs12.go new file mode 100644 index 0000000000..3a89bdb3e3 --- /dev/null +++ b/vendor/golang.org/x/crypto/pkcs12/pkcs12.go @@ -0,0 +1,360 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package pkcs12 implements some of PKCS#12. +// +// This implementation is distilled from https://tools.ietf.org/html/rfc7292 +// and referenced documents. It is intended for decoding P12/PFX-stored +// certificates and keys for use with the crypto/tls package. +// +// This package is frozen. 
If it's missing functionality you need, consider +// an alternative like software.sslmate.com/src/go-pkcs12. +package pkcs12 + +import ( + "crypto/ecdsa" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/hex" + "encoding/pem" + "errors" +) + +var ( + oidDataContentType = asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 7, 1}) + oidEncryptedDataContentType = asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 7, 6}) + + oidFriendlyName = asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 9, 20}) + oidLocalKeyID = asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 9, 21}) + oidMicrosoftCSPName = asn1.ObjectIdentifier([]int{1, 3, 6, 1, 4, 1, 311, 17, 1}) + + errUnknownAttributeOID = errors.New("pkcs12: unknown attribute OID") +) + +type pfxPdu struct { + Version int + AuthSafe contentInfo + MacData macData `asn1:"optional"` +} + +type contentInfo struct { + ContentType asn1.ObjectIdentifier + Content asn1.RawValue `asn1:"tag:0,explicit,optional"` +} + +type encryptedData struct { + Version int + EncryptedContentInfo encryptedContentInfo +} + +type encryptedContentInfo struct { + ContentType asn1.ObjectIdentifier + ContentEncryptionAlgorithm pkix.AlgorithmIdentifier + EncryptedContent []byte `asn1:"tag:0,optional"` +} + +func (i encryptedContentInfo) Algorithm() pkix.AlgorithmIdentifier { + return i.ContentEncryptionAlgorithm +} + +func (i encryptedContentInfo) Data() []byte { return i.EncryptedContent } + +type safeBag struct { + Id asn1.ObjectIdentifier + Value asn1.RawValue `asn1:"tag:0,explicit"` + Attributes []pkcs12Attribute `asn1:"set,optional"` +} + +type pkcs12Attribute struct { + Id asn1.ObjectIdentifier + Value asn1.RawValue `asn1:"set"` +} + +type encryptedPrivateKeyInfo struct { + AlgorithmIdentifier pkix.AlgorithmIdentifier + EncryptedData []byte +} + +func (i encryptedPrivateKeyInfo) Algorithm() pkix.AlgorithmIdentifier { + return i.AlgorithmIdentifier +} + +func (i encryptedPrivateKeyInfo) Data() []byte { + return i.EncryptedData +} + +// PEM block types +const ( + certificateType = "CERTIFICATE" + privateKeyType = "PRIVATE KEY" +) + +// unmarshal calls asn1.Unmarshal, but also returns an error if there is any +// trailing data after unmarshaling. +func unmarshal(in []byte, out interface{}) error { + trailing, err := asn1.Unmarshal(in, out) + if err != nil { + return err + } + if len(trailing) != 0 { + return errors.New("pkcs12: trailing data found") + } + return nil +} + +// ToPEM converts all "safe bags" contained in pfxData to PEM blocks. +// Unknown attributes are discarded. +// +// Note that although the returned PEM blocks for private keys have type +// "PRIVATE KEY", the bytes are not encoded according to PKCS #8, but according +// to PKCS #1 for RSA keys and SEC 1 for ECDSA keys. 
+func ToPEM(pfxData []byte, password string) ([]*pem.Block, error) { + encodedPassword, err := bmpString(password) + if err != nil { + return nil, ErrIncorrectPassword + } + + bags, encodedPassword, err := getSafeContents(pfxData, encodedPassword) + + if err != nil { + return nil, err + } + + blocks := make([]*pem.Block, 0, len(bags)) + for _, bag := range bags { + block, err := convertBag(&bag, encodedPassword) + if err != nil { + return nil, err + } + blocks = append(blocks, block) + } + + return blocks, nil +} + +func convertBag(bag *safeBag, password []byte) (*pem.Block, error) { + block := &pem.Block{ + Headers: make(map[string]string), + } + + for _, attribute := range bag.Attributes { + k, v, err := convertAttribute(&attribute) + if err == errUnknownAttributeOID { + continue + } + if err != nil { + return nil, err + } + block.Headers[k] = v + } + + switch { + case bag.Id.Equal(oidCertBag): + block.Type = certificateType + certsData, err := decodeCertBag(bag.Value.Bytes) + if err != nil { + return nil, err + } + block.Bytes = certsData + case bag.Id.Equal(oidPKCS8ShroundedKeyBag): + block.Type = privateKeyType + + key, err := decodePkcs8ShroudedKeyBag(bag.Value.Bytes, password) + if err != nil { + return nil, err + } + + switch key := key.(type) { + case *rsa.PrivateKey: + block.Bytes = x509.MarshalPKCS1PrivateKey(key) + case *ecdsa.PrivateKey: + block.Bytes, err = x509.MarshalECPrivateKey(key) + if err != nil { + return nil, err + } + default: + return nil, errors.New("found unknown private key type in PKCS#8 wrapping") + } + default: + return nil, errors.New("don't know how to convert a safe bag of type " + bag.Id.String()) + } + return block, nil +} + +func convertAttribute(attribute *pkcs12Attribute) (key, value string, err error) { + isString := false + + switch { + case attribute.Id.Equal(oidFriendlyName): + key = "friendlyName" + isString = true + case attribute.Id.Equal(oidLocalKeyID): + key = "localKeyId" + case attribute.Id.Equal(oidMicrosoftCSPName): + // This key is chosen to match OpenSSL. + key = "Microsoft CSP Name" + isString = true + default: + return "", "", errUnknownAttributeOID + } + + if isString { + if err := unmarshal(attribute.Value.Bytes, &attribute.Value); err != nil { + return "", "", err + } + if value, err = decodeBMPString(attribute.Value.Bytes); err != nil { + return "", "", err + } + } else { + var id []byte + if err := unmarshal(attribute.Value.Bytes, &id); err != nil { + return "", "", err + } + value = hex.EncodeToString(id) + } + + return key, value, nil +} + +// Decode extracts a certificate and private key from pfxData. This function +// assumes that there is only one certificate and only one private key in the +// pfxData; if there are more use ToPEM instead. 
+func Decode(pfxData []byte, password string) (privateKey interface{}, certificate *x509.Certificate, err error) { + encodedPassword, err := bmpString(password) + if err != nil { + return nil, nil, err + } + + bags, encodedPassword, err := getSafeContents(pfxData, encodedPassword) + if err != nil { + return nil, nil, err + } + + if len(bags) != 2 { + err = errors.New("pkcs12: expected exactly two safe bags in the PFX PDU") + return + } + + for _, bag := range bags { + switch { + case bag.Id.Equal(oidCertBag): + if certificate != nil { + err = errors.New("pkcs12: expected exactly one certificate bag") + } + + certsData, err := decodeCertBag(bag.Value.Bytes) + if err != nil { + return nil, nil, err + } + certs, err := x509.ParseCertificates(certsData) + if err != nil { + return nil, nil, err + } + if len(certs) != 1 { + err = errors.New("pkcs12: expected exactly one certificate in the certBag") + return nil, nil, err + } + certificate = certs[0] + + case bag.Id.Equal(oidPKCS8ShroundedKeyBag): + if privateKey != nil { + err = errors.New("pkcs12: expected exactly one key bag") + return nil, nil, err + } + + if privateKey, err = decodePkcs8ShroudedKeyBag(bag.Value.Bytes, encodedPassword); err != nil { + return nil, nil, err + } + } + } + + if certificate == nil { + return nil, nil, errors.New("pkcs12: certificate missing") + } + if privateKey == nil { + return nil, nil, errors.New("pkcs12: private key missing") + } + + return +} + +func getSafeContents(p12Data, password []byte) (bags []safeBag, updatedPassword []byte, err error) { + pfx := new(pfxPdu) + if err := unmarshal(p12Data, pfx); err != nil { + return nil, nil, errors.New("pkcs12: error reading P12 data: " + err.Error()) + } + + if pfx.Version != 3 { + return nil, nil, NotImplementedError("can only decode v3 PFX PDU's") + } + + if !pfx.AuthSafe.ContentType.Equal(oidDataContentType) { + return nil, nil, NotImplementedError("only password-protected PFX is implemented") + } + + // unmarshal the explicit bytes in the content for type 'data' + if err := unmarshal(pfx.AuthSafe.Content.Bytes, &pfx.AuthSafe.Content); err != nil { + return nil, nil, err + } + + if len(pfx.MacData.Mac.Algorithm.Algorithm) == 0 { + return nil, nil, errors.New("pkcs12: no MAC in data") + } + + if err := verifyMac(&pfx.MacData, pfx.AuthSafe.Content.Bytes, password); err != nil { + if err == ErrIncorrectPassword && len(password) == 2 && password[0] == 0 && password[1] == 0 { + // some implementations use an empty byte array + // for the empty string password try one more + // time with empty-empty password + password = nil + err = verifyMac(&pfx.MacData, pfx.AuthSafe.Content.Bytes, password) + } + if err != nil { + return nil, nil, err + } + } + + var authenticatedSafe []contentInfo + if err := unmarshal(pfx.AuthSafe.Content.Bytes, &authenticatedSafe); err != nil { + return nil, nil, err + } + + if len(authenticatedSafe) != 2 { + return nil, nil, NotImplementedError("expected exactly two items in the authenticated safe") + } + + for _, ci := range authenticatedSafe { + var data []byte + + switch { + case ci.ContentType.Equal(oidDataContentType): + if err := unmarshal(ci.Content.Bytes, &data); err != nil { + return nil, nil, err + } + case ci.ContentType.Equal(oidEncryptedDataContentType): + var encryptedData encryptedData + if err := unmarshal(ci.Content.Bytes, &encryptedData); err != nil { + return nil, nil, err + } + if encryptedData.Version != 0 { + return nil, nil, NotImplementedError("only version 0 of EncryptedData is supported") + } + if data, err = 
pbDecrypt(encryptedData.EncryptedContentInfo, password); err != nil { + return nil, nil, err + } + default: + return nil, nil, NotImplementedError("only data and encryptedData content types are supported in authenticated safe") + } + + var safeContents []safeBag + if err := unmarshal(data, &safeContents); err != nil { + return nil, nil, err + } + bags = append(bags, safeContents...) + } + + return bags, password, nil +} diff --git a/vendor/golang.org/x/crypto/pkcs12/safebags.go b/vendor/golang.org/x/crypto/pkcs12/safebags.go new file mode 100644 index 0000000000..def1f7b98d --- /dev/null +++ b/vendor/golang.org/x/crypto/pkcs12/safebags.go @@ -0,0 +1,57 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pkcs12 + +import ( + "crypto/x509" + "encoding/asn1" + "errors" +) + +var ( + // see https://tools.ietf.org/html/rfc7292#appendix-D + oidCertTypeX509Certificate = asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 9, 22, 1}) + oidPKCS8ShroundedKeyBag = asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 12, 10, 1, 2}) + oidCertBag = asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 12, 10, 1, 3}) +) + +type certBag struct { + Id asn1.ObjectIdentifier + Data []byte `asn1:"tag:0,explicit"` +} + +func decodePkcs8ShroudedKeyBag(asn1Data, password []byte) (privateKey interface{}, err error) { + pkinfo := new(encryptedPrivateKeyInfo) + if err = unmarshal(asn1Data, pkinfo); err != nil { + return nil, errors.New("pkcs12: error decoding PKCS#8 shrouded key bag: " + err.Error()) + } + + pkData, err := pbDecrypt(pkinfo, password) + if err != nil { + return nil, errors.New("pkcs12: error decrypting PKCS#8 shrouded key bag: " + err.Error()) + } + + ret := new(asn1.RawValue) + if err = unmarshal(pkData, ret); err != nil { + return nil, errors.New("pkcs12: error unmarshaling decrypted private key: " + err.Error()) + } + + if privateKey, err = x509.ParsePKCS8PrivateKey(pkData); err != nil { + return nil, errors.New("pkcs12: error parsing PKCS#8 private key: " + err.Error()) + } + + return privateKey, nil +} + +func decodeCertBag(asn1Data []byte) (x509Certificates []byte, err error) { + bag := new(certBag) + if err := unmarshal(asn1Data, bag); err != nil { + return nil, errors.New("pkcs12: error decoding cert bag: " + err.Error()) + } + if !bag.Id.Equal(oidCertTypeX509Certificate) { + return nil, NotImplementedError("only X509 certificates are supported") + } + return bag.Data, nil +} diff --git a/vendor/golang.org/x/crypto/ssh/test/doc.go b/vendor/golang.org/x/crypto/ssh/test/doc.go new file mode 100644 index 0000000000..198f0ca1e2 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/test/doc.go @@ -0,0 +1,7 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package test contains integration tests for the +// golang.org/x/crypto/ssh package. +package test // import "golang.org/x/crypto/ssh/test" diff --git a/vendor/golang.org/x/crypto/ssh/test/sshd_test_pw.c b/vendor/golang.org/x/crypto/ssh/test/sshd_test_pw.c new file mode 100644 index 0000000000..2794a563a4 --- /dev/null +++ b/vendor/golang.org/x/crypto/ssh/test/sshd_test_pw.c @@ -0,0 +1,173 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+
+// sshd_test_pw.c
+// Wrapper to inject test password data for sshd PAM authentication
+//
+// This wrapper implements custom versions of getpwnam, getpwnam_r,
+// getspnam and getspnam_r. These functions first call their real
+// libc versions, then check if the requested user matches test user
+// specified in env variable TEST_USER and if so replace the password
+// with crypted() value of TEST_PASSWD env variable.
+//
+// Compile:
+// gcc -Wall -shared -o sshd_test_pw.so -fPIC sshd_test_pw.c
+//
+// Compile with debug:
+// gcc -DVERBOSE -Wall -shared -o sshd_test_pw.so -fPIC sshd_test_pw.c
+//
+// Run sshd:
+// LD_PRELOAD="sshd_test_pw.so" TEST_USER="..." TEST_PASSWD="..." sshd ...
+
+// +build ignore
+
+#define _GNU_SOURCE
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+#include <pwd.h>
+#include <shadow.h>
+#include <dlfcn.h>
+
+#ifdef VERBOSE
+#define DEBUG(X...) fprintf(stderr, X)
+#else
+#define DEBUG(X...) while (0) { }
+#endif
+
+/* crypt() password */
+static char *
+pwhash(char *passwd) {
+	return strdup(crypt(passwd, "$6$"));
+}
+
+/* Pointers to real functions in libc */
+static struct passwd * (*real_getpwnam)(const char *) = NULL;
+static int (*real_getpwnam_r)(const char *, struct passwd *, char *, size_t, struct passwd **) = NULL;
+static struct spwd * (*real_getspnam)(const char *) = NULL;
+static int (*real_getspnam_r)(const char *, struct spwd *, char *, size_t, struct spwd **) = NULL;
+
+/* Cached test user and test password */
+static char *test_user = NULL;
+static char *test_passwd_hash = NULL;
+
+static void
+init(void) {
+	/* Fetch real libc function pointers */
+	real_getpwnam = dlsym(RTLD_NEXT, "getpwnam");
+	real_getpwnam_r = dlsym(RTLD_NEXT, "getpwnam_r");
+	real_getspnam = dlsym(RTLD_NEXT, "getspnam");
+	real_getspnam_r = dlsym(RTLD_NEXT, "getspnam_r");
+
+	/* abort if env variables are not defined */
+	if (getenv("TEST_USER") == NULL || getenv("TEST_PASSWD") == NULL) {
+		fprintf(stderr, "env variables TEST_USER and TEST_PASSWD are missing\n");
+		abort();
+	}
+
+	/* Fetch test user and test password from env */
+	test_user = strdup(getenv("TEST_USER"));
+	test_passwd_hash = pwhash(getenv("TEST_PASSWD"));
+
+	DEBUG("sshd_test_pw init():\n");
+	DEBUG("\treal_getpwnam: %p\n", real_getpwnam);
+	DEBUG("\treal_getpwnam_r: %p\n", real_getpwnam_r);
+	DEBUG("\treal_getspnam: %p\n", real_getspnam);
+	DEBUG("\treal_getspnam_r: %p\n", real_getspnam_r);
+	DEBUG("\tTEST_USER: '%s'\n", test_user);
+	DEBUG("\tTEST_PASSWD: '%s'\n", getenv("TEST_PASSWD"));
+	DEBUG("\tTEST_PASSWD_HASH: '%s'\n", test_passwd_hash);
+}
+
+static int
+is_test_user(const char *name) {
+	if (test_user != NULL && strcmp(test_user, name) == 0)
+		return 1;
+	return 0;
+}
+
+/* getpwnam */
+
+struct passwd *
+getpwnam(const char *name) {
+	struct passwd *pw;
+
+	DEBUG("sshd_test_pw getpwnam(%s)\n", name);
+
+	if (real_getpwnam == NULL)
+		init();
+	if ((pw = real_getpwnam(name)) == NULL)
+		return NULL;
+
+	if (is_test_user(name))
+		pw->pw_passwd = strdup(test_passwd_hash);
+
+	return pw;
+}
+
+/* getpwnam_r */
+
+int
+getpwnam_r(const char *name,
+	   struct passwd *pwd,
+	   char *buf,
+	   size_t buflen,
+	   struct passwd **result) {
+	int r;
+
+	DEBUG("sshd_test_pw getpwnam_r(%s)\n", name);
+
+	if (real_getpwnam_r == NULL)
+		init();
+	if ((r = real_getpwnam_r(name, pwd, buf, buflen, result)) != 0 || *result == NULL)
+		return r;
+
+	if (is_test_user(name))
+		pwd->pw_passwd = strdup(test_passwd_hash);
+
+	return 0;
+}
+
+/* getspnam */
+
+struct spwd *
+getspnam(const char *name) {
+	struct spwd *sp;
+
+	DEBUG("sshd_test_pw getspnam(%s)\n",
name); + + if (real_getspnam == NULL) + init(); + if ((sp = real_getspnam(name)) == NULL) + return NULL; + + if (is_test_user(name)) + sp->sp_pwdp = strdup(test_passwd_hash); + + return sp; +} + +/* getspnam_r */ + +int +getspnam_r(const char *name, + struct spwd *spbuf, + char *buf, + size_t buflen, + struct spwd **spbufp) { + int r; + + DEBUG("sshd_test_pw getspnam_r(%s)\n", name); + + if (real_getspnam_r == NULL) + init(); + if ((r = real_getspnam_r(name, spbuf, buf, buflen, spbufp)) != 0) + return r; + + if (is_test_user(name)) + spbuf->sp_pwdp = strdup(test_passwd_hash); + + return r; +}
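
Reviewer note (not part of the patch): the newly vendored golang.org/x/crypto/pkcs12 package is consumed elsewhere in the tree only through its exported ToPEM/Decode functions, which are defined in the hunks above. The following minimal sketch shows the typical call pattern; the client.p12 path, the password literal, and the crypto/tls handoff are illustrative assumptions, not code from this change.

package main

import (
	"crypto/tls"
	"encoding/pem"
	"io/ioutil"
	"log"

	"golang.org/x/crypto/pkcs12"
)

func main() {
	// Hypothetical input: a PFX/P12 bundle holding one certificate and one key.
	pfxData, err := ioutil.ReadFile("client.p12")
	if err != nil {
		log.Fatal(err)
	}

	// ToPEM (added above) converts every safe bag into a PEM block; private
	// keys come out PKCS#1/SEC 1 encoded under a "PRIVATE KEY" header.
	blocks, err := pkcs12.ToPEM(pfxData, "password")
	if err != nil {
		log.Fatal(err)
	}

	// Re-encode the blocks so crypto/tls can pick out the cert and key.
	var pemData []byte
	for _, b := range blocks {
		pemData = append(pemData, pem.EncodeToMemory(b)...)
	}
	cert, err := tls.X509KeyPair(pemData, pemData)
	if err != nil {
		log.Fatal(err)
	}
	_ = cert // use in a tls.Config{Certificates: ...}
}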