Skip to content

Commit

Permalink
xds: add support for mTLS Credentials in xDS bootstrap (#6757)
Browse files Browse the repository at this point in the history
  • Loading branch information
atollena committed Jan 4, 2024
1 parent 71cc0f1 commit 6bc1906
Show file tree
Hide file tree
Showing 11 changed files with 571 additions and 31 deletions.
6 changes: 3 additions & 3 deletions credentials/tls/certprovider/pemfile/builder.go
Expand Up @@ -29,7 +29,7 @@ import (
)

const (
pluginName = "file_watcher"
PluginName = "file_watcher"
defaultRefreshInterval = 10 * time.Minute
)

Expand All @@ -48,13 +48,13 @@ func (p *pluginBuilder) ParseConfig(c any) (*certprovider.BuildableConfig, error
if err != nil {
return nil, err
}
return certprovider.NewBuildableConfig(pluginName, opts.canonical(), func(certprovider.BuildOptions) certprovider.Provider {
return certprovider.NewBuildableConfig(PluginName, opts.canonical(), func(certprovider.BuildOptions) certprovider.Provider {
return newProvider(opts)
}), nil
}

func (p *pluginBuilder) Name() string {
return pluginName
return PluginName
}

func pluginConfigFromJSON(jd json.RawMessage) (Options, error) {
Expand Down
4 changes: 2 additions & 2 deletions internal/testutils/xds/e2e/setup_certs.go
Expand Up @@ -98,7 +98,7 @@ func CreateClientTLSCredentials(t *testing.T) credentials.TransportCredentials {

// CreateServerTLSCredentials creates server-side TLS transport credentials
// using certificate and key files from testdata/x509 directory.
func CreateServerTLSCredentials(t *testing.T) credentials.TransportCredentials {
func CreateServerTLSCredentials(t *testing.T, clientAuth tls.ClientAuthType) credentials.TransportCredentials {
t.Helper()

cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
Expand All @@ -114,7 +114,7 @@ func CreateServerTLSCredentials(t *testing.T) credentials.TransportCredentials {
t.Fatal("Failed to append certificates")
}
return credentials.NewTLS(&tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
ClientAuth: clientAuth,
Certificates: []tls.Certificate{cert},
ClientCAs: ca,
})
Expand Down
3 changes: 2 additions & 1 deletion test/xds/xds_client_certificate_providers_test.go
Expand Up @@ -20,6 +20,7 @@ package xds_test

import (
"context"
"crypto/tls"
"fmt"
"strings"
"testing"
Expand Down Expand Up @@ -226,7 +227,7 @@ func (s) TestClientSideXDS_WithValidAndInvalidSecurityConfiguration(t *testing.T
// backend1 configured with TLS creds, represents cluster1
// backend2 configured with insecure creds, represents cluster2
// backend3 configured with insecure creds, represents cluster3
creds := e2e.CreateServerTLSCredentials(t)
creds := e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert)
server1 := stubserver.StartTestService(t, nil, grpc.Creds(creds))
defer server1.Stop()
server2 := stubserver.StartTestService(t, nil)
Expand Down
6 changes: 4 additions & 2 deletions xds/bootstrap/bootstrap.go
Expand Up @@ -37,8 +37,10 @@ var registry = make(map[string]Credentials)
// Credentials interface encapsulates a credentials.Bundle builder
// that can be used for communicating with the xDS Management server.
type Credentials interface {
// Build returns a credential bundle associated with this credential.
Build(config json.RawMessage) (credentials.Bundle, error)
// Build returns a credential bundle associated with this credential, and
// a function to cleans up additional resources associated with this bundle
// when it is no longer needed.
Build(config json.RawMessage) (credentials.Bundle, func(), error)
// Name returns the credential name associated with this credential.
Name() string
}
Expand Down
6 changes: 3 additions & 3 deletions xds/bootstrap/bootstrap_test.go
Expand Up @@ -36,9 +36,9 @@ type testCredsBuilder struct {
config json.RawMessage
}

func (t *testCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, error) {
func (t *testCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, func(), error) {
t.config = config
return nil, nil
return nil, nil, nil
}

func (t *testCredsBuilder) Name() string {
Expand All @@ -53,7 +53,7 @@ func TestRegisterNew(t *testing.T) {

const sampleConfig = "sample_config"
rawMessage := json.RawMessage(sampleConfig)
if _, err := c.Build(rawMessage); err != nil {
if _, _, err := c.Build(rawMessage); err != nil {
t.Errorf("Build(%v) error = %v, want nil", rawMessage, err)
}

Expand Down
29 changes: 24 additions & 5 deletions xds/internal/xdsclient/bootstrap/bootstrap.go
Expand Up @@ -39,6 +39,7 @@ import (
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/xds/bootstrap"
"google.golang.org/grpc/xds/internal/xdsclient/tlscreds"
)

const (
Expand All @@ -60,6 +61,7 @@ const (
func init() {
bootstrap.RegisterCredentials(&insecureCredsBuilder{})
bootstrap.RegisterCredentials(&googleDefaultCredsBuilder{})
bootstrap.RegisterCredentials(&tlsCredsBuilder{})
}

// For overriding in unit tests.
Expand All @@ -69,20 +71,32 @@ var bootstrapFileReadFunc = os.ReadFile
// package `xds/bootstrap` and encapsulates an insecure credential.
type insecureCredsBuilder struct{}

func (i *insecureCredsBuilder) Build(json.RawMessage) (credentials.Bundle, error) {
return insecure.NewBundle(), nil
func (i *insecureCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) {
return insecure.NewBundle(), func() {}, nil
}

func (i *insecureCredsBuilder) Name() string {
return "insecure"
}

// tlsCredsBuilder implements the `Credentials` interface defined in
// package `xds/bootstrap` and encapsulates a TLS credential.
type tlsCredsBuilder struct{}

func (t *tlsCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, func(), error) {
return tlscreds.NewBundle(config)
}

func (t *tlsCredsBuilder) Name() string {
return "tls"
}

// googleDefaultCredsBuilder implements the `Credentials` interface defined in
// package `xds/boostrap` and encapsulates a Google Default credential.
type googleDefaultCredsBuilder struct{}

func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, error) {
return google.NewDefaultCredentials(), nil
func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) {
return google.NewDefaultCredentials(), func() {}, nil
}

func (d *googleDefaultCredsBuilder) Name() string {
Expand Down Expand Up @@ -151,6 +165,10 @@ type ServerConfig struct {
// when a resource is deleted, nor will it remove the existing resource value
// from its cache.
IgnoreResourceDeletion bool

// Cleanups are called when the xDS client for this server is closed. Allows
// cleaning up resources created specifically for this ServerConfig.
Cleanups []func()
}

// CredsDialOption returns the configured credentials as a grpc dial option.
Expand Down Expand Up @@ -206,12 +224,13 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error {
if c == nil {
continue
}
bundle, err := c.Build(cc.Config)
bundle, cancel, err := c.Build(cc.Config)
if err != nil {
return fmt.Errorf("failed to build credentials bundle from bootstrap for %q: %v", cc.Type, err)
}
sc.Creds = ChannelCreds(cc)
sc.credsDialOption = grpc.WithCredentialsBundle(bundle)
sc.Cleanups = append(sc.Cleanups, cancel)
break
}
return nil
Expand Down
53 changes: 38 additions & 15 deletions xds/internal/xdsclient/bootstrap/bootstrap_test.go
Expand Up @@ -1008,30 +1008,53 @@ func TestServerConfigMarshalAndUnmarshal(t *testing.T) {
}

func TestDefaultBundles(t *testing.T) {
if c := bootstrap.GetCredentials("google_default"); c == nil {
t.Errorf(`bootstrap.GetCredentials("google_default") credential is nil, want non-nil`)
}
tests := []string{"google_default", "insecure", "tls"}

if c := bootstrap.GetCredentials("insecure"); c == nil {
t.Errorf(`bootstrap.GetCredentials("insecure") credential is nil, want non-nil`)
for _, typename := range tests {
t.Run(typename, func(t *testing.T) {
if c := bootstrap.GetCredentials(typename); c == nil {
t.Errorf(`bootstrap.GetCredentials(%s) credential is nil, want non-nil`, typename)
}
})
}
}

func TestCredsBuilders(t *testing.T) {
b := &googleDefaultCredsBuilder{}
if _, err := b.Build(nil); err != nil {
t.Errorf("googleDefaultCredsBuilder.Build failed: %v", err)
tests := []struct {
typename string
builder bootstrap.Credentials
}{
{"google_default", &googleDefaultCredsBuilder{}},
{"insecure", &insecureCredsBuilder{}},
{"tls", &tlsCredsBuilder{}},
}
if got, want := b.Name(), "google_default"; got != want {
t.Errorf("googleDefaultCredsBuilder.Name = %v, want %v", got, want)

for _, test := range tests {
t.Run(test.typename, func(t *testing.T) {
if got, want := test.builder.Name(), test.typename; got != want {
t.Errorf("%T.Name = %v, want %v", test.builder, got, want)
}

_, stop, err := test.builder.Build(nil)
if err != nil {
t.Fatalf("%T.Build failed: %v", test.builder, err)
}
stop()
})
}
}

i := &insecureCredsBuilder{}
if _, err := i.Build(nil); err != nil {
t.Errorf("insecureCredsBuilder.Build failed: %v", err)
func TestTlsCredsBuilder(t *testing.T) {
tls := &tlsCredsBuilder{}
_, stop, err := tls.Build(json.RawMessage(`{}`))
if err != nil {
t.Fatalf("tls.Build() failed with error %s when expected to succeed", err)
}
stop()

if got, want := i.Name(), "insecure"; got != want {
t.Errorf("insecureCredsBuilder.Name = %v, want %v", got, want)
if _, stop, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil {
t.Errorf("tls.Build() succeeded with an invalid refresh interval, when expected to fail")
stop()
}
// package internal/xdsclient/tlscreds has tests for config validity.
}
12 changes: 12 additions & 0 deletions xds/internal/xdsclient/clientimpl.go
Expand Up @@ -85,5 +85,17 @@ func (c *clientImpl) close() {
c.authorityMu.Unlock()
c.serializerClose()

for _, f := range c.config.XDSServer.Cleanups {
f()
}
for _, a := range c.config.Authorities {
if a.XDSServer == nil {
// The server for this authority is the top-level one, cleaned up above.
continue
}
for _, f := range a.XDSServer.Cleanups {
f()
}
}
c.logger.Infof("Shutdown")
}
138 changes: 138 additions & 0 deletions xds/internal/xdsclient/tlscreds/bundle.go
@@ -0,0 +1,138 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

// Package tlscreds implements mTLS Credentials in xDS Bootstrap File.
// See gRFC A65: github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md.
package tlscreds

import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"net"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/grpc/credentials/tls/certprovider/pemfile"
"google.golang.org/grpc/internal/grpcsync"
)

// bundle is an implementation of credentials.Bundle which implements mTLS
// Credentials in xDS Bootstrap File.
type bundle struct {
transportCredentials credentials.TransportCredentials
}

// NewBundle returns a credentials.Bundle which implements mTLS Credentials in xDS
// Bootstrap File. It delegates certificate loading to a file_watcher provider
// if either client certificates or server root CA is specified. The second
// return value is a close func that should be called when the caller no longer
// needs this bundle.
// See gRFC A65: github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md
func NewBundle(jd json.RawMessage) (credentials.Bundle, func(), error) {
cfg := &struct {
CertificateFile string `json:"certificate_file"`
CACertificateFile string `json:"ca_certificate_file"`
PrivateKeyFile string `json:"private_key_file"`
}{}

if jd != nil {
if err := json.Unmarshal(jd, cfg); err != nil {
return nil, nil, fmt.Errorf("failed to unmarshal config: %v", err)
}
} // Else the config field is absent. Treat it as an empty config.

if cfg.CACertificateFile == "" && cfg.CertificateFile == "" && cfg.PrivateKeyFile == "" {
// We cannot use (and do not need) a file_watcher provider in this case,
// and can simply directly use the TLS transport credentials.
// Quoting A65:
//
// > The only difference between the file-watcher certificate provider
// > config and this one is that in the file-watcher certificate
// > provider, at least one of the "certificate_file" or
// > "ca_certificate_file" fields must be specified, whereas in this
// > configuration, it is acceptable to specify neither one.
return &bundle{transportCredentials: credentials.NewTLS(&tls.Config{})}, func() {}, nil
}
// Otherwise we need to use a file_watcher provider to watch the CA,
// private and public keys.

// The pemfile plugin (file_watcher) currently ignores BuildOptions.
provider, err := certprovider.GetProvider(pemfile.PluginName, jd, certprovider.BuildOptions{})
if err != nil {
return nil, nil, err
}
return &bundle{
transportCredentials: &reloadingCreds{provider: provider},
}, grpcsync.OnceFunc(func() { provider.Close() }), nil
}

func (t *bundle) TransportCredentials() credentials.TransportCredentials {
return t.transportCredentials
}

func (t *bundle) PerRPCCredentials() credentials.PerRPCCredentials {
// mTLS provides transport credentials only. There are no per-RPC
// credentials.
return nil
}

func (t *bundle) NewWithMode(string) (credentials.Bundle, error) {
// This bundle has a single mode which only uses TLS transport credentials,
// so there is no legitimate case where callers would call NewWithMode.
return nil, fmt.Errorf("xDS TLS credentials only support one mode")
}

// reloadingCreds is a credentials.TransportCredentials for client
// side mTLS that reloads the server root CA certificate and the client
// certificates from the provider on every client handshake. This is necessary
// because the standard TLS credentials do not support reloading CA
// certificates.
type reloadingCreds struct {
provider certprovider.Provider
}

func (c *reloadingCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
km, err := c.provider.KeyMaterial(ctx)
if err != nil {
return nil, nil, err
}
config := &tls.Config{
RootCAs: km.Roots,
Certificates: km.Certs,
}
return credentials.NewTLS(config).ClientHandshake(ctx, authority, rawConn)
}

func (c *reloadingCreds) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{SecurityProtocol: "tls"}
}

func (c *reloadingCreds) Clone() credentials.TransportCredentials {
return &reloadingCreds{provider: c.provider}
}

func (c *reloadingCreds) OverrideServerName(string) error {
return errors.New("overriding server name is not supported by xDS client TLS credentials")
}

func (c *reloadingCreds) ServerHandshake(net.Conn) (net.Conn, credentials.AuthInfo, error) {
return nil, nil, errors.New("server handshake is not supported by xDS client TLS credentials")
}

0 comments on commit 6bc1906

Please sign in to comment.