Skip to content

Commit

Permalink
Atomically write and read xDS handshake info client side
Browse files Browse the repository at this point in the history
  • Loading branch information
zasweq committed Nov 15, 2023
1 parent 8645f95 commit 8d82f50
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 106 deletions.
5 changes: 4 additions & 1 deletion credentials/xds/xds.go
Expand Up @@ -27,6 +27,7 @@ import (
"errors"
"fmt"
"net"
"sync/atomic"
"time"

"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -114,7 +115,9 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
if chi.Attributes == nil {
return c.fallback.ClientHandshake(ctx, authority, rawConn)
}
hi := xdsinternal.GetHandshakeInfo(chi.Attributes)

uPtr := xdsinternal.GetHandshakeInfo(chi.Attributes)
hi := (*xdsinternal.HandshakeInfo)(atomic.LoadPointer(uPtr))
if hi.UseFallbackCreds() {
return c.fallback.ClientHandshake(ctx, authority, rawConn)
}
Expand Down
21 changes: 13 additions & 8 deletions credentials/xds/xds_client_test.go
Expand Up @@ -29,6 +29,7 @@ import (
"strings"
"testing"
"time"
"unsafe"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
Expand Down Expand Up @@ -219,11 +220,13 @@ func newTestContextWithHandshakeInfo(parent context.Context, root, identity cert
// Creating the HandshakeInfo and adding it to the attributes is very
// similar to what the CDS balancer would do when it intercepts calls to
// NewSubConn().
info := xdsinternal.NewHandshakeInfo(root, identity)
var sm []matcher.StringMatcher
if sanExactMatch != "" {
info.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)})
sm = []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)}
}
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info)
info := xdsinternal.NewHandshakeInfo(root, identity, sm, false)
uPtr := unsafe.Pointer(info)
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)

// Moving the attributes from the resolver.Address to the context passed to
// the handshaker is done in the transport layer. Since we directly call the
Expand Down Expand Up @@ -533,13 +536,12 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
// Create a root provider which will fail the handshake because it does not
// use the correct trust roots.
root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil)
handshakeInfo.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)})

handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
// We need to repeat most of what newTestContextWithHandshakeInfo() does
// here because we need access to the underlying HandshakeInfo so that we
// can update it before the next call to ClientHandshake().
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo)
uPtr := unsafe.Pointer(handshakeInfo)
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
t.Fatal("ClientHandshake() succeeded when expected to fail")
Expand All @@ -560,7 +562,10 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
// Create a new root provider which uses the correct trust roots. And update
// the HandshakeInfo with the new provider.
root2 := makeRootProvider(t, "x509/server_ca_cert.pem")
handshakeInfo.SetRootCertProvider(root2)
handshakeInfo = xdsinternal.NewHandshakeInfo(root2, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
uPtr = unsafe.Pointer(handshakeInfo)
addr = xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
_, ai, err := creds.ClientHandshake(ctx, authority, conn)
if err != nil {
t.Fatalf("ClientHandshake() returned failed: %q", err)
Expand Down
19 changes: 7 additions & 12 deletions credentials/xds/xds_server_test.go
Expand Up @@ -122,7 +122,7 @@ func (s) TestServerCredsInvalidHandshakeInfo(t *testing.T) {
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
}

info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil)
info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil, nil, false)
conn := newWrappedConn(nil, info, time.Time{})
if _, _, err := creds.ServerHandshake(conn); err == nil {
t.Fatal("ServerHandshake succeeded without identity certificate provider in HandshakeInfo")
Expand Down Expand Up @@ -158,7 +158,7 @@ func (s) TestServerCredsProviderFailure(t *testing.T) {
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider)
info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, false)
conn := newWrappedConn(nil, info, time.Time{})
if _, _, err := creds.ServerHandshake(conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
t.Fatalf("ServerHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
Expand Down Expand Up @@ -232,8 +232,7 @@ func (s) TestServerCredsHandshakeTimeout(t *testing.T) {
// Create a test server which uses the xDS server credentials created above
// to perform TLS handshake on incoming connections.
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"))
hi.SetRequireClientCert(true)
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), nil, true)

// Create a wrapped conn which can return the HandshakeInfo created
// above with a very small deadline.
Expand Down Expand Up @@ -285,8 +284,7 @@ func (s) TestServerCredsHandshakeFailure(t *testing.T) {
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
// Create a HandshakeInfo which has a root provider which does not match
// the certificate sent by the client.
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"))
hi.SetRequireClientCert(true)
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true)

// Create a wrapped conn which can return the HandshakeInfo and
// configured deadline to the xDS credentials' ServerHandshake()
Expand Down Expand Up @@ -367,8 +365,7 @@ func (s) TestServerCredsHandshakeSuccess(t *testing.T) {
// created above to perform TLS handshake on incoming connections.
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
// Create a HandshakeInfo with information from the test table.
hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider)
hi.SetRequireClientCert(test.requireClientCert)
hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, test.requireClientCert)

// Create a wrapped conn which can return the HandshakeInfo and
// configured deadline to the xDS credentials' ServerHandshake()
Expand Down Expand Up @@ -448,8 +445,7 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) {
if cnt == 1 {
// Create a HandshakeInfo which has a root provider which does not match
// the certificate sent by the client.
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"))
hi.SetRequireClientCert(true)
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true)

// Create a wrapped conn which can return the HandshakeInfo and
// configured deadline to the xDS credentials' ServerHandshake()
Expand All @@ -463,8 +459,7 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) {
return handshakeResult{}
}

hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"))
hi.SetRequireClientCert(true)
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), nil, true)

// Create a wrapped conn which can return the HandshakeInfo and
// configured deadline to the xDS credentials' ServerHandshake()
Expand Down
67 changes: 14 additions & 53 deletions internal/credentials/xds/handshake_info.go
Expand Up @@ -26,7 +26,7 @@ import (
"errors"
"fmt"
"strings"
"sync"
"unsafe"

"google.golang.org/grpc/attributes"
"google.golang.org/grpc/credentials/tls/certprovider"
Expand Down Expand Up @@ -65,17 +65,14 @@ func (hi *HandshakeInfo) Equal(other *HandshakeInfo) bool {
return true
}

// SetHandshakeInfo returns a copy of addr in which the Attributes field is
// updated with hInfo.
func SetHandshakeInfo(addr resolver.Address, hInfo *HandshakeInfo) resolver.Address {
addr.Attributes = addr.Attributes.WithValue(handshakeAttrKey{}, hInfo)
func SetHandshakeInfo(addr resolver.Address, hiPtr *unsafe.Pointer) resolver.Address {
addr.Attributes = addr.Attributes.WithValue(handshakeAttrKey{}, hiPtr)
return addr
}

// GetHandshakeInfo returns a pointer to the HandshakeInfo stored in attr.
func GetHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo {
func GetHandshakeInfo(attr *attributes.Attributes) *unsafe.Pointer {
v := attr.Value(handshakeAttrKey{})
hi, _ := v.(*HandshakeInfo)
hi, _ := v.(*unsafe.Pointer)
return hi
}

Expand All @@ -85,40 +82,21 @@ func GetHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo {
//
// Safe for concurrent access.
type HandshakeInfo struct {
mu sync.Mutex
// All fields written at init time and read only after that, so no
// synchronization needed.
rootProvider certprovider.Provider
identityProvider certprovider.Provider
sanMatchers []matcher.StringMatcher // Only on the client side.
requireClientCert bool // Only on server side.
}

// SetRootCertProvider updates the root certificate provider.
func (hi *HandshakeInfo) SetRootCertProvider(root certprovider.Provider) {
hi.mu.Lock()
hi.rootProvider = root
hi.mu.Unlock()
}

// SetIdentityCertProvider updates the identity certificate provider.
func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) {
hi.mu.Lock()
hi.identityProvider = identity
hi.mu.Unlock()
}

// SetSANMatchers updates the list of SAN matchers.
func (hi *HandshakeInfo) SetSANMatchers(sanMatchers []matcher.StringMatcher) {
hi.mu.Lock()
hi.sanMatchers = sanMatchers
hi.mu.Unlock()
}

// SetRequireClientCert updates whether a client cert is required during the
// ServerHandshake(). A value of true indicates that we are performing mTLS.
func (hi *HandshakeInfo) SetRequireClientCert(require bool) {
hi.mu.Lock()
hi.requireClientCert = require
hi.mu.Unlock()
func NewHandshakeInfo(rootProvider certprovider.Provider, identityProvider certprovider.Provider, sanMatchers []matcher.StringMatcher, requireClientCert bool) *HandshakeInfo {
return &HandshakeInfo{
rootProvider: rootProvider,
identityProvider: identityProvider,
sanMatchers: sanMatchers,
requireClientCert: requireClientCert,
}
}

// UseFallbackCreds returns true when fallback credentials are to be used based
Expand All @@ -127,24 +105,18 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool {
if hi == nil {
return true
}

hi.mu.Lock()
defer hi.mu.Unlock()
return hi.identityProvider == nil && hi.rootProvider == nil
}

// GetSANMatchersForTesting returns the SAN matchers stored in HandshakeInfo.
// To be used only for testing purposes.
func (hi *HandshakeInfo) GetSANMatchersForTesting() []matcher.StringMatcher {
hi.mu.Lock()
defer hi.mu.Unlock()
return append([]matcher.StringMatcher{}, hi.sanMatchers...)
}

// ClientSideTLSConfig constructs a tls.Config to be used in a client-side
// handshake based on the contents of the HandshakeInfo.
func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, error) {
hi.mu.Lock()
// On the client side, rootProvider is mandatory. IdentityProvider is
// optional based on whether the client is doing TLS or mTLS.
if hi.rootProvider == nil {
Expand All @@ -153,7 +125,6 @@ func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config,
// Since the call to KeyMaterial() can block, we read the providers under
// the lock but call the actual function after releasing the lock.
rootProv, idProv := hi.rootProvider, hi.identityProvider
hi.mu.Unlock()

// InsecureSkipVerify needs to be set to true because we need to perform
// custom verification to check the SAN on the received certificate.
Expand Down Expand Up @@ -188,7 +159,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
ClientAuth: tls.NoClientCert,
NextProtos: []string{"h2"},
}
hi.mu.Lock()
// On the server side, identityProvider is mandatory. RootProvider is
// optional based on whether the server is doing TLS or mTLS.
if hi.identityProvider == nil {
Expand All @@ -200,7 +170,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
if hi.requireClientCert {
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}
hi.mu.Unlock()

// identityProvider is mandatory on the server side.
km, err := idProv.KeyMaterial(ctx)
Expand All @@ -225,8 +194,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
// If the list of SAN matchers in the HandshakeInfo is empty, this function
// returns true for all input certificates.
func (hi *HandshakeInfo) MatchingSANExists(cert *x509.Certificate) bool {
hi.mu.Lock()
defer hi.mu.Unlock()
if len(hi.sanMatchers) == 0 {
return true
}
Expand Down Expand Up @@ -325,9 +292,3 @@ func dnsMatch(host, san string) bool {
hostPrefix := strings.TrimSuffix(host, san[1:])
return !strings.Contains(hostPrefix, ".")
}

// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root
// and identity certificate providers.
func NewHandshakeInfo(root, identity certprovider.Provider) *HandshakeInfo {
return &HandshakeInfo{rootProvider: root, identityProvider: identity}
}
6 changes: 2 additions & 4 deletions internal/credentials/xds/handshake_info_test.go
Expand Up @@ -188,8 +188,7 @@ func TestMatchingSANExists_FailureCases(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
hi := NewHandshakeInfo(nil, nil)
hi.SetSANMatchers(test.sanMatchers)
hi := NewHandshakeInfo(nil, nil, test.sanMatchers, false)

if hi.MatchingSANExists(inputCert) {
t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v succeeded when expected to fail", inputCert, test.sanMatchers)
Expand Down Expand Up @@ -289,8 +288,7 @@ func TestMatchingSANExists_Success(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
hi := NewHandshakeInfo(nil, nil)
hi.SetSANMatchers(test.sanMatchers)
hi := NewHandshakeInfo(nil, nil, test.sanMatchers, false)

if !hi.MatchingSANExists(inputCert) {
t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v failed when expected to succeed", inputCert, test.sanMatchers)
Expand Down
2 changes: 1 addition & 1 deletion internal/internal.go
Expand Up @@ -57,7 +57,7 @@ var (
// GetXDSHandshakeInfoForTesting returns a pointer to the xds.HandshakeInfo
// stored in the passed in attributes. This is set by
// credentials/xds/xds.go.
GetXDSHandshakeInfoForTesting any // func (*attributes.Attributes) *xds.HandshakeInfo
GetXDSHandshakeInfoForTesting any // func (*attributes.Attributes) *unsafe.Pointer
// GetServerCredentials returns the transport credentials configured on a
// gRPC server. An xDS-enabled server needs to know what type of credentials
// is configured on the underlying gRPC server. This is set by server.go.
Expand Down

0 comments on commit 8d82f50

Please sign in to comment.