Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xds: client-side security: xDS credentials implementation #3888

Merged
merged 8 commits into from Sep 29, 2020
188 changes: 100 additions & 88 deletions credentials/xds/xds.go
Expand Up @@ -19,7 +19,10 @@
// Package xds provides a transport credentials implementation where the
// security configuration is pushed by a management server using xDS APIs.
//
// All APIs in this package are EXPERIMENTAL.
// Experimental
//
// Notice: All APIs in this package are EXPERIMENTAL and may be removed in a
// later release.
package xds

import (
Expand All @@ -31,6 +34,7 @@ import (
"net"
"sync"

"google.golang.org/grpc/attributes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
credinternal "google.golang.org/grpc/internal/credentials"
Expand Down Expand Up @@ -65,8 +69,21 @@ type credsImpl struct {
fallback credentials.TransportCredentials
}

// handshakeCtxKey is the context key used to store HandshakeInfo values.
type handshakeCtxKey struct{}
// handshakeAttrKey is the type used as the key to store HandshakeInfo in
// the Attributes field of resolver.Address.
type handshakeAttrKey struct{}

// SetHandshakeInfo returns a copy of attr which is updated with hInfo.
func SetHandshakeInfo(attr *attributes.Attributes, hInfo *HandshakeInfo) *attributes.Attributes {
return attr.WithValues(handshakeAttrKey{}, hInfo)
}

// GetHandshakeInfo returns a pointer to the HandshakeInfo stored in attr.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be unexported? Does anyone else need access to it? If so that would eliminate the asymmetry of the API (setter takes address, getter takes attributes).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

func GetHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make these work on resolver.Addresses instead of attributes, e.g.:

func Get(addr resolver.Address) []string {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed SetHandshakeInfo to work on resolver.Address since this will be invoked from the CDS balancer and it has a resolver.Address at that point.

GetHandshakeInfo is called from the credentials code, and there is no resolver.Address at that point.

v := attr.Value(handshakeAttrKey{})
hi, _ := v.(*HandshakeInfo)
return hi
}

// HandshakeInfo wraps all the security configuration required by client and
// server handshake methods in credsImpl. The xDS implementation will be
Expand All @@ -81,49 +98,79 @@ type HandshakeInfo struct {
}

// SetRootCertProvider updates the root certificate provider.
func (chi *HandshakeInfo) SetRootCertProvider(root certprovider.Provider) {
chi.mu.Lock()
chi.rootProvider = root
chi.mu.Unlock()
func (hi *HandshakeInfo) SetRootCertProvider(root certprovider.Provider) {
hi.mu.Lock()
hi.rootProvider = root
hi.mu.Unlock()
}

// SetIdentityCertProvider updates the identity certificate provider.
func (chi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) {
chi.mu.Lock()
chi.identityProvider = identity
chi.mu.Unlock()
func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) {
hi.mu.Lock()
hi.identityProvider = identity
hi.mu.Unlock()
}

// SetAcceptedSANs updates the list of accepted SANs.
func (chi *HandshakeInfo) SetAcceptedSANs(sans []string) {
chi.mu.Lock()
chi.acceptedSANs = make(map[string]bool)
func (hi *HandshakeInfo) SetAcceptedSANs(sans []string) {
hi.mu.Lock()
hi.acceptedSANs = make(map[string]bool, len(sans))
for _, san := range sans {
chi.acceptedSANs[san] = true
hi.acceptedSANs[san] = true
}
chi.mu.Unlock()
hi.mu.Unlock()
}

func (chi *HandshakeInfo) validate(isClient bool) error {
chi.mu.Lock()
defer chi.mu.Unlock()
func (hi *HandshakeInfo) validate(isClient bool) error {
hi.mu.Lock()
defer hi.mu.Unlock()

// On the client side, rootProvider is mandatory. IdentityProvider is
// optional based on whether the client is doing TLS or mTLS.
ZhenLian marked this conversation as resolved.
Show resolved Hide resolved
if isClient && chi.rootProvider == nil {
return errors.New("root certificate provider is missing")
if isClient && hi.rootProvider == nil {
return errors.New("xds: CertificateProvider to fetch trusted roots is missing, cannot perform TLS handshake")
}

// On the server side, identityProvider is mandatory. RootProvider is
// optional based on whether the server is doing TLS or mTLS.
if !isClient && chi.identityProvider == nil {
return errors.New("identity certificate provider is missing")
if !isClient && hi.identityProvider == nil {
return errors.New("xds: CertificateProvider to fetch identity certificate is missing, cannot perform TLS handshake")
}

return nil
}

func (chi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool {
func (hi *HandshakeInfo) makeTLSConfig(ctx context.Context) (*tls.Config, error) {
hi.mu.Lock()
// Since the call to KeyMaterial() can block, we read the providers under
// the lock but call the actual function after releasing the lock.
rootProv, idProv := hi.rootProvider, hi.identityProvider
hi.mu.Unlock()

// InsecureSkipVerify needs to be set to true because we need to perform
// custom verification to check the SAN on the received certificate.
// Currently the Go stdlib does complete verification of the cert (which
// includes hostname verification) or none. We are forced to go with the
// latter and perform the normal cert validation ourselves.
cfg := &tls.Config{InsecureSkipVerify: true}
if rootProv != nil {
km, err := rootProv.KeyMaterial(ctx)
if err != nil {
return nil, fmt.Errorf("xds: fetching trusted roots from CertificateProvider failed: %v", err)
}
cfg.RootCAs = km.Roots
}
if idProv != nil {
km, err := idProv.KeyMaterial(ctx)
if err != nil {
return nil, fmt.Errorf("xds: fetching identity certificates from CertificateProvider failed: %v", err)
}
cfg.Certificates = km.Certs
}
return cfg, nil
}

func (hi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool {
var sans []string
// SANs can be specified in any of these four fields on the parsed cert.
sans = append(sans, cert.DNSNames...)
Expand All @@ -135,10 +182,10 @@ func (chi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool {
sans = append(sans, uri.String())
}

chi.mu.Lock()
defer chi.mu.Unlock()
hi.mu.Lock()
defer hi.mu.Unlock()
for _, san := range sans {
if chi.acceptedSANs[san] {
if hi.acceptedSANs[san] {
return true
}
}
Expand All @@ -148,7 +195,7 @@ func (chi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool {
// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root
// and identity certificate providers.
func NewHandshakeInfo(root, identity certprovider.Provider, sans ...string) *HandshakeInfo {
acceptedSANs := make(map[string]bool)
acceptedSANs := make(map[string]bool, len(sans))
for _, san := range sans {
acceptedSANs[san] = true
}
Expand All @@ -159,21 +206,6 @@ func NewHandshakeInfo(root, identity certprovider.Provider, sans ...string) *Han
}
}

// NewContextWithHandshakeInfo returns a copy of the parent context with the
// provided HandshakeInfo stored as a value.
func NewContextWithHandshakeInfo(parent context.Context, info *HandshakeInfo) context.Context {
return context.WithValue(parent, handshakeCtxKey{}, info)
}

// handshakeInfoFromCtx returns a pointer to the HandshakeInfo stored in ctx.
func handshakeInfoFromCtx(ctx context.Context) *HandshakeInfo {
val, ok := ctx.Value(handshakeCtxKey{}).(*HandshakeInfo)
if !ok {
return nil
}
return val
}

// ClientHandshake performs the TLS handshake on the client-side.
//
// It looks for the presence of a HandshakeInfo value in the passed in context
Expand All @@ -187,15 +219,29 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
return nil, nil, errors.New("ClientHandshake() is not supported for server credentials")
}

chi := handshakeInfoFromCtx(ctx)
if chi == nil {
// A missing handshake info in the provided context could mean either
// the user did not specify an `xds` scheme in their dial target or that
// the xDS server did not provide any security configuration. In both of
// these cases, we use the fallback credentials specified by the user.
// The CDS balancer constructs a new HandshakeInfo using a call to
// NewHandshakeInfo(), and then adds it to the attributes field of the
// resolver.Address when handling calls to NewSubConn(). The transport layer
// takes care of shipping these attributes in the context to this handshake
// function. We first read the credentials.ClientHandshakeInfo type from the
// context, which contains the attributes added by the CDS balancer. We then
// read the HandshakeInfo from the attributes to get to the actual data that
// we need here for the handshake.
chi := credentials.ClientHandshakeInfoFromContext(ctx)
// If there are no attributes in the received context or the attributes does
// not contain a HandshakeInfo, it could either mean that the user did not
// specify an `xds` scheme in their dial target or that the xDS server did
// not provide any security configuration. In both of these cases, we use
// the fallback credentials specified by the user.
if chi.Attributes == nil {
return c.fallback.ClientHandshake(ctx, authority, rawConn)
}
hi := GetHandshakeInfo(chi.Attributes)
if hi == nil {
return c.fallback.ClientHandshake(ctx, authority, rawConn)
}
if err := chi.validate(c.isClient); err != nil {

if err := hi.validate(c.isClient); err != nil {
return nil, nil, err
}

Expand All @@ -211,38 +257,10 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
// 4. Key usage to match whether client/server usage.
// 5. A `VerifyPeerCertificate` function which performs normal peer
// cert verification using configured roots, and the custom SAN checks.
var certs []tls.Certificate
var roots *x509.CertPool
err := func() error {
// We use this anonymous function trick to be able to defer the unlock.
chi.mu.Lock()
defer chi.mu.Unlock()

if chi.rootProvider != nil {
km, err := chi.rootProvider.KeyMaterial(ctx)
if err != nil {
return fmt.Errorf("fetching root certificates failed: %v", err)
}
roots = km.Roots
}
if chi.identityProvider != nil {
km, err := chi.identityProvider.KeyMaterial(ctx)
if err != nil {
return fmt.Errorf("fetching identity certificates failed: %v", err)
}
certs = km.Certs
}
return nil
}()
cfg, err := hi.makeTLSConfig(ctx)
if err != nil {
return nil, nil, err
}

cfg := &tls.Config{
Certificates: certs,
InsecureSkipVerify: true,
RootCAs: roots,
}
cfg.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
// Parse all raw certificates presented by the peer.
var certs []*x509.Certificate
Expand All @@ -261,7 +279,7 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
intermediates.AddCert(cert)
}
opts := x509.VerifyOptions{
Roots: roots,
Roots: cfg.RootCAs,
Intermediates: intermediates,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
Expand All @@ -270,7 +288,7 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
}
// The SANs sent by the MeshCA are encoded as SPIFFE IDs. We need to
// only look at the SANs on the leaf cert.
if !chi.matchingSANExists(certs[0]) {
if !hi.matchingSANExists(certs[0]) {
return fmt.Errorf("SANs received in leaf certificate %+v does not match any of the accepted SANs", certs[0])
}
return nil
Expand All @@ -280,13 +298,7 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
// actual Handshake() function in a goroutine because we need to respect the
// deadline specified on the passed in context, and we need a way to cancel
// the handshake if the context is cancelled.
var conn *tls.Conn
if c.isClient {
conn = tls.Client(rawConn, cfg)
} else {
conn = tls.Server(rawConn, cfg)
}

conn := tls.Client(rawConn, cfg)
errCh := make(chan error, 1)
go func() {
errCh <- conn.Handshake()
Expand Down
30 changes: 25 additions & 5 deletions credentials/xds/xds_test.go
Expand Up @@ -32,6 +32,7 @@ import (

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/testdata"
Expand Down Expand Up @@ -219,11 +220,20 @@ func makeRootProvider(t *testing.T, caPath string) *fakeProvider {
return &fakeProvider{km: &certprovider.KeyMaterial{Roots: roots}}
}

// newTestContextWithHandshakeInfo returns a copy of the passed in context with
// HandshakeInfo context value added to it.
func newTestContextWithHandshakeInfo(ctx context.Context, root, identity certprovider.Provider, sans ...string) context.Context {
// newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo
// context value added to it.
func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sans ...string) context.Context {
// Creating the HandshakeInfo and adding it to the attributes is very
// similar to what the CDS balancer would do when it intercepts calls to
// NewSubConn().
info := NewHandshakeInfo(root, identity, sans...)
return NewContextWithHandshakeInfo(ctx, info)
attr := SetHandshakeInfo(nil, info)

// Moving the attributes from the resolver.Address to the context passed to
// the handshaker is done in the transport layer. Since we directly call the
// handshaker in these tests, we need to do the same here.
contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
return contextWithHandshakeInfo(parent, credentials.ClientHandshakeInfo{Attributes: attr})
}

// compareAuthInfo compares the AuthInfo received on the client side after a
Expand Down Expand Up @@ -485,9 +495,17 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Create a root provider which will fail the handshake because it does not
// use the correct trust roots.
root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
handshakeInfo := NewHandshakeInfo(root1, nil, defaultTestCertSAN)
ctx = NewContextWithHandshakeInfo(ctx, handshakeInfo)

// We need to repeat most of what newTestContextWithHandshakeInfo() does
// here because we need access to the underlying HandshakeInfo so that we
// can update it before the next call to ClientHandshake().
attr := SetHandshakeInfo(nil, handshakeInfo)
contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
ctx = contextWithHandshakeInfo(ctx, credentials.ClientHandshakeInfo{Attributes: attr})
if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
t.Fatal("ClientHandshake() succeeded when expected to fail")
}
Expand All @@ -504,6 +522,8 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
}
defer conn.Close()

// Create a new root provider which uses the correct trust roots. And update
// the HandshakeInfo with the new provider.
root2 := makeRootProvider(t, "x509/server_ca_cert.pem")
handshakeInfo.SetRootCertProvider(root2)
_, ai, err := creds.ClientHandshake(ctx, authority, conn)
Expand Down