Skip to content

Commit

Permalink
latest review from easwar
Browse files Browse the repository at this point in the history
- cleanup in clientimpl close rather than individual authority close
- convert tests in bootstrap_test to table driven tests
  • Loading branch information
atollena committed Dec 22, 2023
1 parent e694f4c commit 01f8c82
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 60 deletions.
4 changes: 0 additions & 4 deletions xds/internal/xdsclient/authority.go
Expand Up @@ -448,10 +448,6 @@ func (a *authority) close() {
a.resourcesMu.Lock()
a.closed = true
a.resourcesMu.Unlock()

for _, cleanup := range a.serverCfg.Cleanups {
cleanup()
}
}

func (a *authority) watchResource(rType xdsresource.Type, resourceName string, watcher xdsresource.ResourceWatcher) func() {
Expand Down
2 changes: 1 addition & 1 deletion xds/internal/xdsclient/bootstrap/bootstrap.go
Expand Up @@ -167,7 +167,7 @@ type ServerConfig struct {
IgnoreResourceDeletion bool

// Cleanups are called when the xDS client for this server is closed. Allows
// cleaning up resources created specifically for the xDS client.
// cleaning up resources created specifically for this ServerConfig.
Cleanups []func()
}

Expand Down
63 changes: 27 additions & 36 deletions xds/internal/xdsclient/bootstrap/bootstrap_test.go
Expand Up @@ -1008,49 +1008,39 @@ func TestServerConfigMarshalAndUnmarshal(t *testing.T) {
}

func TestDefaultBundles(t *testing.T) {
if c := bootstrap.GetCredentials("google_default"); c == nil {
t.Errorf(`bootstrap.GetCredentials("google_default") credential is nil, want non-nil`)
}

if c := bootstrap.GetCredentials("insecure"); c == nil {
t.Errorf(`bootstrap.GetCredentials("insecure") credential is nil, want non-nil`)
}
tests := []string{"google_default", "insecure", "tls"}

if c := bootstrap.GetCredentials("tls"); c == nil {
t.Errorf(`bootstrap.GetCredentials("tls") credential is nil, want non-nil`)
for _, typename := range tests {
t.Run(typename, func(t *testing.T) {
if c := bootstrap.GetCredentials(typename); c == nil {
t.Errorf(`bootstrap.GetCredentials(%s) credential is nil, want non-nil`, typename)
}
})
}
}

func TestCredsBuilders(t *testing.T) {
b := &googleDefaultCredsBuilder{}
if _, stop, err := b.Build(nil); err != nil {
t.Errorf("googleDefaultCredsBuilder.Build failed: %v", err)
} else {
stop()
}
if got, want := b.Name(), "google_default"; got != want {
t.Errorf("googleDefaultCredsBuilder.Name = %v, want %v", got, want)
}

i := &insecureCredsBuilder{}
if _, stop, err := i.Build(nil); err != nil {
t.Errorf("insecureCredsBuilder.Build failed: %v", err)
} else {
stop()
tests := []struct {
typename string
builder bootstrap.Credentials
}{
{"google_default", &googleDefaultCredsBuilder{}},
{"insecure", &insecureCredsBuilder{}},
{"tls", &tlsCredsBuilder{}},
}

if got, want := i.Name(), "insecure"; got != want {
t.Errorf("insecureCredsBuilder.Name = %v, want %v", got, want)
}
for _, test := range tests {
t.Run(test.typename, func(t *testing.T) {
if got, want := test.builder.Name(), test.typename; got != want {
t.Errorf("%T.Name = %v, want %v", test.builder, got, want)
}

tcb := &tlsCredsBuilder{}
if _, stop, err := tcb.Build(nil); err != nil {
t.Errorf("tlsCredsBuilder.Build failed: %v", err)
} else {
stop()
}
if got, want := tcb.Name(), "tls"; got != want {
t.Errorf("tlsCredsBuilder.Name = %v, want %v", got, want)
_, stop, err := test.builder.Build(nil)
if err != nil {
t.Fatalf("%T.Build failed: %v", test.builder, err)
}
stop()
})
}
}

Expand All @@ -1061,9 +1051,10 @@ func TestTlsCredsBuilder(t *testing.T) {
t.Fatalf("tls.Build() failed with error %s when expected to succeed", err)
}
stop()

if _, stop, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil {
t.Errorf("tls.Build() succeeded with an invalid refresh interval, when expected to fail")
stop()
}
// more tests for config validity are defined in tlscreds subpackage.
// package internal/xdsclient/tlscreds has tests for config validity.
}
12 changes: 12 additions & 0 deletions xds/internal/xdsclient/clientimpl.go
Expand Up @@ -85,5 +85,17 @@ func (c *clientImpl) close() {
c.authorityMu.Unlock()
c.serializerClose()

for _, f := range c.config.XDSServer.Cleanups {
f()
}
for _, a := range c.config.Authorities {
if a.XDSServer == nil {
// The server for this authority is the top-level one, cleaned up above.
continue
}
for _, f := range a.XDSServer.Cleanups {
f()
}
}
c.logger.Infof("Shutdown")
}
16 changes: 5 additions & 11 deletions xds/internal/xdsclient/tlscreds/bundle.go
Expand Up @@ -31,6 +31,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/grpc/credentials/tls/certprovider/pemfile"
"google.golang.org/grpc/internal/grpcsync"
)

// bundle is an implementation of credentials.Bundle which implements mTLS
Expand All @@ -41,7 +42,9 @@ type bundle struct {

// NewBundle returns a credentials.Bundle which implements mTLS Credentials in xDS
// Bootstrap File. It delegates certificate loading to a file_watcher provider
// if either client certificates or server root CA is specified.
// if either client certificates or server root CA is specified. The second
// return value is a close func that should be called when the caller no longer
// needs this bundle.
// See gRFC A65: github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md
func NewBundle(jd json.RawMessage) (credentials.Bundle, func(), error) {
cfg := &struct {
Expand Down Expand Up @@ -78,7 +81,7 @@ func NewBundle(jd json.RawMessage) (credentials.Bundle, func(), error) {
}
return &bundle{
transportCredentials: &reloadingCreds{provider: provider},
}, func() { provider.Close() }, nil
}, grpcsync.OnceFunc(func() { provider.Close() }), nil
}

func (t *bundle) TransportCredentials() credentials.TransportCredentials {
Expand All @@ -97,15 +100,6 @@ func (t *bundle) NewWithMode(string) (credentials.Bundle, error) {
return nil, fmt.Errorf("xDS TLS credentials only support one mode")

Check warning on line 100 in xds/internal/xdsclient/tlscreds/bundle.go

View check run for this annotation

Codecov / codecov/patch

xds/internal/xdsclient/tlscreds/bundle.go#L97-L100

Added lines #L97 - L100 were not covered by tests
}

// Close releases the underlying provider. Note that credentials.Bundle are
// not closeable, so users of this type must use a type assertion to call Close.
func (t *bundle) Close() {
cred, ok := t.transportCredentials.(*reloadingCreds)
if ok {
cred.provider.Close()
}
}

// reloadingCreds is a credentials.TransportCredentials for client
// side mTLS that reloads the server root CA certificate and the client
// certificates from the provider on every client handshake. This is necessary
Expand Down
15 changes: 8 additions & 7 deletions xds/internal/xdsclient/tlscreds/bundle_ext_test.go
Expand Up @@ -106,11 +106,11 @@ func (s) TestValidTlsBuilder(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
msg := json.RawMessage(test.jd)
if _, stop, err := tlscreds.NewBundle(msg); err != nil {
t.Errorf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err)
} else {
stop()
_, stop, err := tlscreds.NewBundle(msg)
if err != nil {
t.Fatalf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err)
}
stop()
})
}
}
Expand All @@ -133,11 +133,12 @@ func (s) TestInvalidTlsBuilder(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
msg := json.RawMessage(test.jd)
if _, stop, err := tlscreds.NewBundle(msg); err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) {
t.Errorf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix)
if err == nil {
_, stop, err := tlscreds.NewBundle(msg)
if err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) {
if stop != nil {
stop()
}
t.Fatalf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix)
}
})
}
Expand Down
2 changes: 1 addition & 1 deletion xds/internal/xdsclient/tlscreds/bundle_test.go
Expand Up @@ -46,7 +46,7 @@ func Test(t *testing.T) {

type failingProvider struct{}

func (f failingProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
func (f failingProvider) KeyMaterial(context.Context) (*certprovider.KeyMaterial, error) {
return nil, errors.New("test error")
}

Expand Down

0 comments on commit 01f8c82

Please sign in to comment.