diff --git a/credentials/credentials.go b/credentials/credentials.go index 53addd8c71e9..53f8fcdc9a0d 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -27,6 +27,7 @@ import ( "errors" "fmt" "net" + "net/url" "github.com/golang/protobuf/proto" "google.golang.org/grpc/attributes" @@ -87,6 +88,7 @@ func (s SecurityLevel) String() string { // This API is experimental. type CommonAuthInfo struct { SecurityLevel SecurityLevel + SPIFFEID *url.URL } // GetCommonAuthInfo returns the pointer to CommonAuthInfo struct. diff --git a/credentials/tls.go b/credentials/tls.go index 4691ec8eadb3..1db95e537387 100644 --- a/credentials/tls.go +++ b/credentials/tls.go @@ -25,16 +25,15 @@ import ( "fmt" "io/ioutil" "net" - "net/url" "google.golang.org/grpc/credentials/internal" + credinternal "google.golang.org/grpc/internal/credentials" ) // TLSInfo contains the auth information for a TLS authenticated connection. // It implements the AuthInfo interface. type TLSInfo struct { - State tls.ConnectionState - SpiffeID *url.URL + State tls.ConnectionState CommonAuthInfo } @@ -102,9 +101,13 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon SecurityLevel: PrivacyAndIntegrity, }, } - if err := tlsInfo.ParseSpiffeID(); err != nil { + id, err := credinternal.SPIFFEIDFromState(conn.ConnectionState()) + if err != nil { return nil, nil, err } + if id != nil { + tlsInfo.SPIFFEID = id + } return internal.WrapSyscallConn(rawConn, conn), tlsInfo, nil } @@ -120,9 +123,13 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) SecurityLevel: PrivacyAndIntegrity, }, } - if err := tlsInfo.ParseSpiffeID(); err != nil { + id, err := credinternal.SPIFFEIDFromState(conn.ConnectionState()) + if err != nil { return nil, nil, err } + if id != nil { + tlsInfo.SPIFFEID = id + } return internal.WrapSyscallConn(rawConn, conn), tlsInfo, nil } diff --git a/credentials/go10.go b/internal/credentials/go110.go similarity index 56% rename from credentials/go10.go rename to internal/credentials/go110.go index 23886c4372ce..458701b5fe05 100644 --- a/credentials/go10.go +++ b/internal/credentials/go110.go @@ -21,49 +21,49 @@ package credentials import ( + "crypto/tls" "fmt" "net/url" "google.golang.org/grpc/grpclog" ) -// ParseSpiffeID parses the Spiffe ID from State and fill it into SpiffeID. -// An error is returned only when we are sure Spiffe ID is used but the format -// is wrong. -// This function can only be used with go version 1.10 and onwards. When used -// with a prior version, no error will be returned, but the field -// TLSInfo.SpiffeID wouldn't be plumbed. -func (t *TLSInfo) ParseSpiffeID() error { - if len(t.State.PeerCertificates) == 0 || len(t.State.PeerCertificates[0].URIs) == 0 { - return nil +// SPIFFEIDFromState parses the SPIFFE ID from State. An error is returned only +// when we are sure SPIFFE ID is used but the format is wrong. +func SPIFFEIDFromState(state tls.ConnectionState) (*url.URL, error) { + if len(state.PeerCertificates) == 0 || len(state.PeerCertificates[0].URIs) == 0 { + return nil, nil } spiffeIDCnt := 0 var spiffeID url.URL - for _, uri := range t.State.PeerCertificates[0].URIs { + for _, uri := range state.PeerCertificates[0].URIs { if uri == nil || uri.Scheme != "spiffe" || uri.Opaque != "" || (uri.User != nil && uri.User.Username() != "") { continue } - // From this point, we assume the uri is intended for a Spiffe ID. + // From this point, we assume the uri is intended for a SPIFFE ID. if len(uri.Host)+len(uri.Scheme)+len(uri.RawPath)+4 > 2048 || len(uri.Host)+len(uri.Scheme)+len(uri.Path)+4 > 2048 { - return fmt.Errorf("invalid SPIFFE ID: total ID length larger than 2048 bytes") + return nil, fmt.Errorf("invalid SPIFFE ID: total ID length larger than 2048 bytes") } if len(uri.Host) == 0 || len(uri.RawPath) == 0 || len(uri.Path) == 0 { - return fmt.Errorf("invalid SPIFFE ID: domain or workload ID is empty") + return nil, fmt.Errorf("invalid SPIFFE ID: domain or workload ID is empty") } if len(uri.Host) > 255 { - return fmt.Errorf("invalid SPIFFE ID: domain length larger than 255 characters") + return nil, fmt.Errorf("invalid SPIFFE ID: domain length larger than 255 characters") } - // We use a default deep copy since we know the User field of a SPIFFE ID is empty. + // We use a default deep copy since we know the User field of a SPIFFE ID + // is empty. spiffeID = *uri spiffeIDCnt++ } if spiffeIDCnt == 1 { - t.SpiffeID = &spiffeID + return &spiffeID, nil } else if spiffeIDCnt > 1 { - // A standard SPIFFE ID should be unique. If there are more, we log this - // mis-behavior and not plumb any of them. + // A standard SPIFFE ID should be unique. If there are more than one ID, we + // should log this error but shouldn't halt the application. grpclog.Warning("invalid SPIFFE ID: multiple SPIFFE IDs") + return nil, nil } - return nil + // SPIFFE ID is not used. + return nil, nil } diff --git a/credentials/credentials_go10_test.go b/internal/credentials/go110_test.go similarity index 88% rename from credentials/credentials_go10_test.go rename to internal/credentials/go110_test.go index 288692fbb8f9..6bf0cbe07e18 100644 --- a/credentials/credentials_go10_test.go +++ b/internal/credentials/go110_test.go @@ -25,8 +25,18 @@ import ( "crypto/x509" "net/url" "testing" + + "google.golang.org/grpc/internal/grpctest" ) +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + func (s) TestParseSpiffeID(t *testing.T) { tests := []struct { name string @@ -135,14 +145,13 @@ func (s) TestParseSpiffeID(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - info := TLSInfo{ - State: tls.ConnectionState{PeerCertificates: []*x509.Certificate{{URIs: tt.urls}}}} - err := info.ParseSpiffeID() + state := tls.ConnectionState{PeerCertificates: []*x509.Certificate{{URIs: tt.urls}}} + id, err := SPIFFEIDFromState(state) if got, want := err != nil, tt.expectError; got != want { t.Errorf("want expectError = %v, but got expectError = %v, with error %v", want, got, err) } - if got, want := info.SpiffeID != nil, tt.expectID; got != want { - t.Errorf("want expectID = %v, but spiffe ID is %v", want, info.SpiffeID) + if got, want := id != nil, tt.expectID; got != want { + t.Errorf("want expectID = %v, but spiffe ID is %v", want, id) } }) } diff --git a/credentials/gobefore10.go b/internal/credentials/gobefore110.go similarity index 75% rename from credentials/gobefore10.go rename to internal/credentials/gobefore110.go index 69f9b3bf042a..0356fb4f3707 100644 --- a/credentials/gobefore10.go +++ b/internal/credentials/gobefore110.go @@ -21,10 +21,11 @@ package credentials import ( - "google.golang.org/grpc/grpclog" + "crypto/tls" + "net/url" ) -func (t *TLSInfo) ParseSpiffeID() error { - grpclog.Info("go version prior to 1.10 doesn't support parsing URIs in certificates. Please consider a newer version") - return nil +//TODO(ZhenLian): delete this file when we remove Go 1.9 tests. +func SPIFFEIDFromState(state tls.ConnectionState) (*url.URL, error) { + return nil, nil }