Skip to content

Commit

Permalink
add fallback code for STS ErrorResponse v/s S3 Error (#1607)
Browse files Browse the repository at this point in the history
  • Loading branch information
harshavardhana committed Jan 11, 2022
1 parent 35844af commit 19a708d
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 18 deletions.
13 changes: 12 additions & 1 deletion pkg/credentials/assume_role.go
Expand Up @@ -18,6 +18,7 @@
package credentials

import (
"bytes"
"encoding/hex"
"encoding/xml"
"errors"
Expand Down Expand Up @@ -185,10 +186,20 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume
defer closeResponse(resp)
if resp.StatusCode != http.StatusOK {
var errResp ErrorResponse
_, err = xmlDecodeAndBody(resp.Body, &errResp)
buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
return AssumeRoleResponse{}, err
}
_, err = xmlDecodeAndBody(bytes.NewReader(buf), &errResp)
if err != nil {
var s3Err Error
if _, err = xmlDecodeAndBody(bytes.NewReader(buf), &s3Err); err != nil {
return AssumeRoleResponse{}, err
}
errResp.RequestID = s3Err.RequestID
errResp.STSError.Code = s3Err.Code
errResp.STSError.Message = s3Err.Message
}
return AssumeRoleResponse{}, errResp
}

Expand Down
30 changes: 30 additions & 0 deletions pkg/credentials/error_response.go
Expand Up @@ -38,6 +38,36 @@ type ErrorResponse struct {
RequestID string `xml:"RequestId"`
}

// Error - Is the typed error returned by all API operations.
type Error struct {
XMLName xml.Name `xml:"Error" json:"-"`
Code string
Message string
BucketName string
Key string
Resource string
RequestID string `xml:"RequestId"`
HostID string `xml:"HostId"`

// Region where the bucket is located. This header is returned
// only in HEAD bucket and ListObjects response.
Region string

// Captures the server string returned in response header.
Server string

// Underlying HTTP status code for the returned error
StatusCode int `xml:"-" json:"-"`
}

// Error - Returns S3 error string.
func (e Error) Error() string {
if e.Message == "" {
return fmt.Sprintf("Error response code %s.", e.Code)
}
return e.Message
}

// Error - Returns STS error string.
func (e ErrorResponse) Error() string {
if e.STSError.Message == "" {
Expand Down
20 changes: 15 additions & 5 deletions pkg/credentials/sts_client_grants.go
Expand Up @@ -18,9 +18,11 @@
package credentials

import (
"bytes"
"encoding/xml"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
Expand Down Expand Up @@ -133,12 +135,20 @@ func getClientGrantsCredentials(clnt *http.Client, endpoint string,
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
var errResp ErrorResponse
_, err = xmlDecodeAndBody(resp.Body, &errResp)
buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
errResp := ErrorResponse{}
errResp.STSError.Code = "InvalidArgument"
errResp.STSError.Message = err.Error()
return AssumeRoleWithClientGrantsResponse{}, errResp
return AssumeRoleWithClientGrantsResponse{}, err

}
_, err = xmlDecodeAndBody(bytes.NewReader(buf), &errResp)
if err != nil {
var s3Err Error
if _, err = xmlDecodeAndBody(bytes.NewReader(buf), &s3Err); err != nil {
return AssumeRoleWithClientGrantsResponse{}, err
}
errResp.RequestID = s3Err.RequestID
errResp.STSError.Code = s3Err.Code
errResp.STSError.Message = s3Err.Message
}
return AssumeRoleWithClientGrantsResponse{}, errResp
}
Expand Down
19 changes: 15 additions & 4 deletions pkg/credentials/sts_ldap_identity.go
Expand Up @@ -18,8 +18,10 @@
package credentials

import (
"bytes"
"encoding/xml"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
Expand Down Expand Up @@ -169,11 +171,20 @@ func (k *LDAPIdentity) Retrieve() (value Value, err error) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
var errResp ErrorResponse
_, err = xmlDecodeAndBody(resp.Body, &errResp)
buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
errResp.STSError.Code = "InvalidArgument"
errResp.STSError.Message = err.Error()
return value, errResp
return value, err

}
_, err = xmlDecodeAndBody(bytes.NewReader(buf), &errResp)
if err != nil {
var s3Err Error
if _, err = xmlDecodeAndBody(bytes.NewReader(buf), &s3Err); err != nil {
return value, err
}
errResp.RequestID = s3Err.RequestID
errResp.STSError.Code = s3Err.Code
errResp.STSError.Message = s3Err.Message
}
return value, errResp
}
Expand Down
19 changes: 15 additions & 4 deletions pkg/credentials/sts_tls_identity.go
Expand Up @@ -16,10 +16,12 @@
package credentials

import (
"bytes"
"crypto/tls"
"encoding/xml"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -150,11 +152,20 @@ func (i *STSCertificateIdentity) Retrieve() (Value, error) {
}
if resp.StatusCode != http.StatusOK {
var errResp ErrorResponse
_, err = xmlDecodeAndBody(resp.Body, &errResp)
buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
errResp.STSError.Code = "InvalidArgument"
errResp.STSError.Message = err.Error()
return Value{}, errResp
return Value{}, err

}
_, err = xmlDecodeAndBody(bytes.NewReader(buf), &errResp)
if err != nil {
var s3Err Error
if _, err = xmlDecodeAndBody(bytes.NewReader(buf), &s3Err); err != nil {
return Value{}, err
}
errResp.RequestID = s3Err.RequestID
errResp.STSError.Code = s3Err.Code
errResp.STSError.Message = s3Err.Message
}
return Value{}, errResp
}
Expand Down
19 changes: 15 additions & 4 deletions pkg/credentials/sts_web_identity.go
Expand Up @@ -18,9 +18,11 @@
package credentials

import (
"bytes"
"encoding/xml"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -151,11 +153,20 @@ func getWebIdentityCredentials(clnt *http.Client, endpoint, roleARN, roleSession
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
var errResp ErrorResponse
_, err = xmlDecodeAndBody(resp.Body, &errResp)
buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
errResp.STSError.Code = "InvalidArgument"
errResp.STSError.Message = err.Error()
return AssumeRoleWithWebIdentityResponse{}, errResp
return AssumeRoleWithWebIdentityResponse{}, err

}
_, err = xmlDecodeAndBody(bytes.NewReader(buf), &errResp)
if err != nil {
var s3Err Error
if _, err = xmlDecodeAndBody(bytes.NewReader(buf), &s3Err); err != nil {
return AssumeRoleWithWebIdentityResponse{}, err
}
errResp.RequestID = s3Err.RequestID
errResp.STSError.Code = s3Err.Code
errResp.STSError.Message = s3Err.Message
}
return AssumeRoleWithWebIdentityResponse{}, errResp
}
Expand Down

0 comments on commit 19a708d

Please sign in to comment.