diff --git a/credentials/credentials.go b/credentials/credentials.go index 53f8fcdc9a0d..53addd8c71e9 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -27,7 +27,6 @@ import ( "errors" "fmt" "net" - "net/url" "github.com/golang/protobuf/proto" "google.golang.org/grpc/attributes" @@ -88,7 +87,6 @@ func (s SecurityLevel) String() string { // This API is experimental. type CommonAuthInfo struct { SecurityLevel SecurityLevel - SPIFFEID *url.URL } // GetCommonAuthInfo returns the pointer to CommonAuthInfo struct. diff --git a/credentials/tls.go b/credentials/tls.go index 1db95e537387..1ba6f3a6b8f8 100644 --- a/credentials/tls.go +++ b/credentials/tls.go @@ -25,6 +25,7 @@ import ( "fmt" "io/ioutil" "net" + "net/url" "google.golang.org/grpc/credentials/internal" credinternal "google.golang.org/grpc/internal/credentials" @@ -35,6 +36,8 @@ import ( type TLSInfo struct { State tls.ConnectionState CommonAuthInfo + // This API is experimental. + SPIFFEID *url.URL } // AuthType returns the type of TLSInfo as a string. @@ -101,10 +104,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon SecurityLevel: PrivacyAndIntegrity, }, } - id, err := credinternal.SPIFFEIDFromState(conn.ConnectionState()) - if err != nil { - return nil, nil, err - } + id := credinternal.SPIFFEIDFromState(conn.ConnectionState()) if id != nil { tlsInfo.SPIFFEID = id } @@ -123,10 +123,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) SecurityLevel: PrivacyAndIntegrity, }, } - id, err := credinternal.SPIFFEIDFromState(conn.ConnectionState()) - if err != nil { - return nil, nil, err - } + id := credinternal.SPIFFEIDFromState(conn.ConnectionState()) if id != nil { tlsInfo.SPIFFEID = id } diff --git a/internal/credentials/go110.go b/internal/credentials/go110.go index 83a6559f2c95..f8d67399e775 100644 --- a/internal/credentials/go110.go +++ b/internal/credentials/go110.go @@ -25,20 +25,19 @@ package credentials import ( "crypto/tls" - "fmt" "net/url" "google.golang.org/grpc/grpclog" ) -// SPIFFEIDFromState parses the SPIFFE ID from State. An error is returned only -// when we are sure SPIFFE ID is used but the format is wrong. -func SPIFFEIDFromState(state tls.ConnectionState) (*url.URL, error) { +// SPIFFEIDFromState parses the SPIFFE ID from State. If the SPIFFE ID format +// is invalid, return nil with warning. +func SPIFFEIDFromState(state tls.ConnectionState) *url.URL { if len(state.PeerCertificates) == 0 || len(state.PeerCertificates[0].URIs) == 0 { - return nil, nil + return nil } - spiffeIDCnt := 0 - var spiffeID url.URL + spiffeIDFound := false + var spiffeID *url.URL for _, uri := range state.PeerCertificates[0].URIs { if uri == nil || uri.Scheme != "spiffe" || uri.Opaque != "" || (uri.User != nil && uri.User.Username() != "") { continue @@ -46,27 +45,24 @@ func SPIFFEIDFromState(state tls.ConnectionState) (*url.URL, error) { // From this point, we assume the uri is intended for a SPIFFE ID. if len(uri.Host)+len(uri.Scheme)+len(uri.RawPath)+4 > 2048 || len(uri.Host)+len(uri.Scheme)+len(uri.Path)+4 > 2048 { - return nil, fmt.Errorf("invalid SPIFFE ID: total ID length larger than 2048 bytes") + grpclog.Warning("invalid SPIFFE ID: total ID length larger than 2048 bytes") + return nil } if len(uri.Host) == 0 || len(uri.RawPath) == 0 || len(uri.Path) == 0 { - return nil, fmt.Errorf("invalid SPIFFE ID: domain or workload ID is empty") + grpclog.Warning("invalid SPIFFE ID: domain or workload ID is empty") + return nil } if len(uri.Host) > 255 { - return nil, fmt.Errorf("invalid SPIFFE ID: domain length larger than 255 characters") + grpclog.Warning("invalid SPIFFE ID: domain length larger than 255 characters") + return nil } - // We use a default deep copy since we know the User field of a SPIFFE ID - // is empty. - spiffeID = *uri - spiffeIDCnt++ - } - if spiffeIDCnt == 1 { - return &spiffeID, nil - } else if spiffeIDCnt > 1 { - // A standard SPIFFE ID should be unique. If there are more than one ID, we - // should log this error but shouldn't halt the application. - grpclog.Warning("invalid SPIFFE ID: multiple SPIFFE IDs") - return nil, nil + // A valid SPIFFE certificate can only have exactly one URI SAN field. + if spiffeIDFound && len(state.PeerCertificates[0].URIs) > 1 { + grpclog.Warning("invalid SPIFFE ID: multiple URI SANs") + return nil + } + spiffeID = uri + spiffeIDFound = true } - // SPIFFE ID is not used. - return nil, nil + return spiffeID } diff --git a/internal/credentials/go110_test.go b/internal/credentials/go110_test.go index 5fb4b3296103..1e03dcdb5fb8 100644 --- a/internal/credentials/go110_test.go +++ b/internal/credentials/go110_test.go @@ -41,16 +41,13 @@ func (s) TestSPIFFEIDFromState(t *testing.T) { tests := []struct { name string urls []*url.URL - // If we expect SPIFFEIDFromState to return an error. - expectError bool // If we expect a SPIFFE ID to be returned. expectID bool }{ { - name: "empty URIs", - urls: []*url.URL{}, - expectError: false, - expectID: false, + name: "empty URIs", + urls: []*url.URL{}, + expectID: false, }, { name: "good SPIFFE ID", @@ -61,15 +58,8 @@ func (s) TestSPIFFEIDFromState(t *testing.T) { Path: "workload/wl1", RawPath: "workload/wl1", }, - { - Scheme: "https", - Host: "foo.bar.com", - Path: "workload/wl1", - RawPath: "workload/wl1", - }, }, - expectError: false, - expectID: true, + expectID: true, }, { name: "invalid host", @@ -81,8 +71,7 @@ func (s) TestSPIFFEIDFromState(t *testing.T) { RawPath: "workload/wl1", }, }, - expectError: true, - expectID: false, + expectID: false, }, { name: "invalid path", @@ -94,8 +83,7 @@ func (s) TestSPIFFEIDFromState(t *testing.T) { RawPath: "", }, }, - expectError: true, - expectID: false, + expectID: false, }, { name: "large path", @@ -107,8 +95,7 @@ func (s) TestSPIFFEIDFromState(t *testing.T) { RawPath: string(make([]byte, 2050)), }, }, - expectError: true, - expectID: false, + expectID: false, }, { name: "large host", @@ -120,11 +107,10 @@ func (s) TestSPIFFEIDFromState(t *testing.T) { RawPath: "workload/wl1", }, }, - expectError: true, - expectID: false, + expectID: false, }, { - name: "multiple SPIFFE IDs", + name: "multiple URI SANs with SPIFFE ID", urls: []*url.URL{ { Scheme: "spiffe", @@ -138,18 +124,38 @@ func (s) TestSPIFFEIDFromState(t *testing.T) { Path: "workload/wl2", RawPath: "workload/wl2", }, + { + Scheme: "https", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + expectID: false, + }, + { + name: "multiple URI SANs without SPIFFE ID", + urls: []*url.URL{ + { + Scheme: "https", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + { + Scheme: "ssh", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, }, - expectError: false, - expectID: false, + expectID: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { state := tls.ConnectionState{PeerCertificates: []*x509.Certificate{{URIs: tt.urls}}} - id, err := SPIFFEIDFromState(state) - if got, want := err != nil, tt.expectError; got != want { - t.Errorf("want expectError = %v, but got expectError = %v, with error %v", want, got, err) - } + id := SPIFFEIDFromState(state) if got, want := id != nil, tt.expectID; got != want { t.Errorf("want expectID = %v, but SPIFFE ID is %v", want, id) } diff --git a/internal/credentials/gobefore110.go b/internal/credentials/gobefore110.go index 0356fb4f3707..4ca8dcf26b79 100644 --- a/internal/credentials/gobefore110.go +++ b/internal/credentials/gobefore110.go @@ -26,6 +26,6 @@ import ( ) //TODO(ZhenLian): delete this file when we remove Go 1.9 tests. -func SPIFFEIDFromState(state tls.ConnectionState) (*url.URL, error) { +func SPIFFEIDFromState(state tls.ConnectionState) *url.URL { return nil, nil }