Skip to content

Commit

Permalink
fix review comments; move funcs to internal
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhenLian committed Jun 15, 2020
1 parent 3ed9ae7 commit 6625171
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 36 deletions.
2 changes: 2 additions & 0 deletions credentials/credentials.go
Expand Up @@ -27,6 +27,7 @@ import (
"errors"
"fmt"
"net"
"net/url"

"github.com/golang/protobuf/proto"
"google.golang.org/grpc/attributes"
Expand Down Expand Up @@ -87,6 +88,7 @@ func (s SecurityLevel) String() string {
// This API is experimental.
type CommonAuthInfo struct {
SecurityLevel SecurityLevel
SPIFFEID *url.URL
}

// GetCommonAuthInfo returns the pointer to CommonAuthInfo struct.
Expand Down
17 changes: 12 additions & 5 deletions credentials/tls.go
Expand Up @@ -25,16 +25,15 @@ import (
"fmt"
"io/ioutil"
"net"
"net/url"

"google.golang.org/grpc/credentials/internal"
credinternal "google.golang.org/grpc/internal/credentials"
)

// TLSInfo contains the auth information for a TLS authenticated connection.
// It implements the AuthInfo interface.
type TLSInfo struct {
State tls.ConnectionState
SpiffeID *url.URL
State tls.ConnectionState
CommonAuthInfo
}

Expand Down Expand Up @@ -102,9 +101,13 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
SecurityLevel: PrivacyAndIntegrity,
},
}
if err := tlsInfo.ParseSpiffeID(); err != nil {
id, err := credinternal.SPIFFEIDFromState(conn.ConnectionState())
if err != nil {
return nil, nil, err
}
if id != nil {
tlsInfo.SPIFFEID = id
}
return internal.WrapSyscallConn(rawConn, conn), tlsInfo, nil
}

Expand All @@ -120,9 +123,13 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
SecurityLevel: PrivacyAndIntegrity,
},
}
if err := tlsInfo.ParseSpiffeID(); err != nil {
id, err := credinternal.SPIFFEIDFromState(conn.ConnectionState())
if err != nil {
return nil, nil, err
}
if id != nil {
tlsInfo.SPIFFEID = id
}
return internal.WrapSyscallConn(rawConn, conn), tlsInfo, nil
}

Expand Down
41 changes: 22 additions & 19 deletions credentials/go10.go → internal/credentials/go110.go
Expand Up @@ -18,52 +18,55 @@
*
*/

// Package credentials defines APIs for parsing SPIFFE ID.
//
// All APIs in this package are experimental.
package credentials

import (
"crypto/tls"
"fmt"
"net/url"

"google.golang.org/grpc/grpclog"
)

// ParseSpiffeID parses the Spiffe ID from State and fill it into SpiffeID.
// An error is returned only when we are sure Spiffe ID is used but the format
// is wrong.
// This function can only be used with go version 1.10 and onwards. When used
// with a prior version, no error will be returned, but the field
// TLSInfo.SpiffeID wouldn't be plumbed.
func (t *TLSInfo) ParseSpiffeID() error {
if len(t.State.PeerCertificates) == 0 || len(t.State.PeerCertificates[0].URIs) == 0 {
return nil
// SPIFFEIDFromState parses the SPIFFE ID from State. An error is returned only
// when we are sure SPIFFE ID is used but the format is wrong.
func SPIFFEIDFromState(state tls.ConnectionState) (*url.URL, error) {
if len(state.PeerCertificates) == 0 || len(state.PeerCertificates[0].URIs) == 0 {
return nil, nil
}
spiffeIDCnt := 0
var spiffeID url.URL
for _, uri := range t.State.PeerCertificates[0].URIs {
for _, uri := range state.PeerCertificates[0].URIs {
if uri == nil || uri.Scheme != "spiffe" || uri.Opaque != "" || (uri.User != nil && uri.User.Username() != "") {
continue
}
// From this point, we assume the uri is intended for a Spiffe ID.
// From this point, we assume the uri is intended for a SPIFFE ID.
if len(uri.Host)+len(uri.Scheme)+len(uri.RawPath)+4 > 2048 ||
len(uri.Host)+len(uri.Scheme)+len(uri.Path)+4 > 2048 {
return fmt.Errorf("invalid SPIFFE ID: total ID length larger than 2048 bytes")
return nil, fmt.Errorf("invalid SPIFFE ID: total ID length larger than 2048 bytes")
}
if len(uri.Host) == 0 || len(uri.RawPath) == 0 || len(uri.Path) == 0 {
return fmt.Errorf("invalid SPIFFE ID: domain or workload ID is empty")
return nil, fmt.Errorf("invalid SPIFFE ID: domain or workload ID is empty")
}
if len(uri.Host) > 255 {
return fmt.Errorf("invalid SPIFFE ID: domain length larger than 255 characters")
return nil, fmt.Errorf("invalid SPIFFE ID: domain length larger than 255 characters")
}
// We use a default deep copy since we know the User field of a SPIFFE ID is empty.
// We use a default deep copy since we know the User field of a SPIFFE ID
// is empty.
spiffeID = *uri
spiffeIDCnt++
}
if spiffeIDCnt == 1 {
t.SpiffeID = &spiffeID
return &spiffeID, nil
} else if spiffeIDCnt > 1 {
// A standard SPIFFE ID should be unique. If there are more, we log this
// mis-behavior and not plumb any of them.
// A standard SPIFFE ID should be unique. If there are more than one ID, we
// should log this error but shouldn't halt the application.
grpclog.Warning("invalid SPIFFE ID: multiple SPIFFE IDs")
return nil, nil
}
return nil
// SPIFFE ID is not used.
return nil, nil
}
Expand Up @@ -25,15 +25,25 @@ import (
"crypto/x509"
"net/url"
"testing"

"google.golang.org/grpc/internal/grpctest"
)

func (s) TestParseSpiffeID(t *testing.T) {
type s struct {
grpctest.Tester
}

func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}

func (s) TestSPIFFEIDFromState(t *testing.T) {
tests := []struct {
name string
urls []*url.URL
// If we expect ParseSpiffeID to return an error.
// If we expect SPIFFEIDFromState to return an error.
expectError bool
// If we expect TLSInfo.SpiffeID to be plumbed.
// If we expect a SPIFFE ID to be returned.
expectID bool
}{
{
Expand Down Expand Up @@ -135,14 +145,13 @@ func (s) TestParseSpiffeID(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
info := TLSInfo{
State: tls.ConnectionState{PeerCertificates: []*x509.Certificate{{URIs: tt.urls}}}}
err := info.ParseSpiffeID()
state := tls.ConnectionState{PeerCertificates: []*x509.Certificate{{URIs: tt.urls}}}
id, err := SPIFFEIDFromState(state)
if got, want := err != nil, tt.expectError; got != want {
t.Errorf("want expectError = %v, but got expectError = %v, with error %v", want, got, err)
}
if got, want := info.SpiffeID != nil, tt.expectID; got != want {
t.Errorf("want expectID = %v, but spiffe ID is %v", want, info.SpiffeID)
if got, want := id != nil, tt.expectID; got != want {
t.Errorf("want expectID = %v, but SPIFFE ID is %v", want, id)
}
})
}
Expand Down
Expand Up @@ -21,10 +21,11 @@
package credentials

import (
"google.golang.org/grpc/grpclog"
"crypto/tls"
"net/url"
)

func (t *TLSInfo) ParseSpiffeID() error {
grpclog.Info("go version prior to 1.10 doesn't support parsing URIs in certificates. Please consider a newer version")
return nil
//TODO(ZhenLian): delete this file when we remove Go 1.9 tests.
func SPIFFEIDFromState(state tls.ConnectionState) (*url.URL, error) {
return nil, nil
}

0 comments on commit 6625171

Please sign in to comment.