From 4e179b8d3ec42e48c4802f1938f9d593a5e91408 Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Fri, 30 Oct 2020 15:52:55 -0700 Subject: [PATCH] pemfile: Move file watcher plugin from advancedtls to gRPC (#3981) --- .../tls/certprovider/pemfile/watcher.go | 252 +++++++++++ .../tls/certprovider/pemfile/watcher_test.go | 426 ++++++++++++++++++ .../advancedtls_integration_test.go | 40 +- security/advancedtls/pemfile_provider.go | 197 -------- security/advancedtls/pemfile_provider_test.go | 220 --------- 5 files changed, 699 insertions(+), 436 deletions(-) create mode 100644 credentials/tls/certprovider/pemfile/watcher.go create mode 100644 credentials/tls/certprovider/pemfile/watcher_test.go delete mode 100644 security/advancedtls/pemfile_provider.go delete mode 100644 security/advancedtls/pemfile_provider_test.go diff --git a/credentials/tls/certprovider/pemfile/watcher.go b/credentials/tls/certprovider/pemfile/watcher.go new file mode 100644 index 00000000000..29ea8b2b065 --- /dev/null +++ b/credentials/tls/certprovider/pemfile/watcher.go @@ -0,0 +1,252 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package pemfile provides a file watching certificate provider plugin +// implementation which works for files with PEM contents. +// +// Experimental +// +// Notice: All APIs in this package are experimental and may be removed in a +// later release. +package pemfile + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "time" + + "google.golang.org/grpc/credentials/tls/certprovider" + "google.golang.org/grpc/grpclog" +) + +const ( + defaultCertRefreshDuration = 1 * time.Hour + defaultRootRefreshDuration = 2 * time.Hour +) + +var ( + // For overriding from unit tests. + newDistributor = func() distributor { return certprovider.NewDistributor() } + + logger = grpclog.Component("pemfile") +) + +// Options configures a certificate provider plugin that watches a specified set +// of files that contain certificates and keys in PEM format. +type Options struct { + // CertFile is the file that holds the identity certificate. + // Optional. If this is set, KeyFile must also be set. + CertFile string + // KeyFile is the file that holds identity private key. + // Optional. If this is set, CertFile must also be set. + KeyFile string + // RootFile is the file that holds trusted root certificate(s). + // Optional. + RootFile string + // CertRefreshDuration is the amount of time the plugin waits before + // checking for updates in the specified identity certificate and key file. + // Optional. If not set, a default value (1 hour) will be used. + CertRefreshDuration time.Duration + // RootRefreshDuration is the amount of time the plugin waits before + // checking for updates in the specified root file. + // Optional. If not set, a default value (2 hour) will be used. + RootRefreshDuration time.Duration +} + +// NewProvider returns a new certificate provider plugin that is configured to +// watch the PEM files specified in the passed in options. +func NewProvider(o Options) (certprovider.Provider, error) { + if o.CertFile == "" && o.KeyFile == "" && o.RootFile == "" { + return nil, fmt.Errorf("pemfile: at least one credential file needs to be specified") + } + if keySpecified, certSpecified := o.KeyFile != "", o.CertFile != ""; keySpecified != certSpecified { + return nil, fmt.Errorf("pemfile: private key file and identity cert file should be both specified or not specified") + } + if o.CertRefreshDuration == 0 { + o.CertRefreshDuration = defaultCertRefreshDuration + } + if o.RootRefreshDuration == 0 { + o.RootRefreshDuration = defaultRootRefreshDuration + } + + provider := &watcher{opts: o} + if o.CertFile != "" && o.KeyFile != "" { + provider.identityDistributor = newDistributor() + } + if o.RootFile != "" { + provider.rootDistributor = newDistributor() + } + + ctx, cancel := context.WithCancel(context.Background()) + provider.cancel = cancel + go provider.run(ctx) + + return provider, nil +} + +// watcher is a certificate provider plugin that implements the +// certprovider.Provider interface. It watches a set of certificate and key +// files and provides the most up-to-date key material for consumption by +// credentials implementation. +type watcher struct { + identityDistributor distributor + rootDistributor distributor + opts Options + certFileContents []byte + keyFileContents []byte + rootFileContents []byte + cancel context.CancelFunc +} + +// distributor wraps the methods on certprovider.Distributor which are used by +// the plugin. This is very useful in tests which need to know exactly when the +// plugin updates its key material. +type distributor interface { + KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) + Set(km *certprovider.KeyMaterial, err error) + Stop() +} + +// updateIdentityDistributor checks if the cert/key files that the plugin is +// watching have changed, and if so, reads the new contents and updates the +// identityDistributor with the new key material. +// +// Skips updates when file reading or parsing fails. +// TODO(easwars): Retry with limit (on the number of retries or the amount of +// time) upon failures. +func (w *watcher) updateIdentityDistributor() { + if w.identityDistributor == nil { + return + } + + certFileContents, err := ioutil.ReadFile(w.opts.CertFile) + if err != nil { + logger.Warningf("certFile (%s) read failed: %v", w.opts.CertFile, err) + return + } + keyFileContents, err := ioutil.ReadFile(w.opts.KeyFile) + if err != nil { + logger.Warningf("keyFile (%s) read failed: %v", w.opts.KeyFile, err) + return + } + // If the file contents have not changed, skip updating the distributor. + if bytes.Equal(w.certFileContents, certFileContents) && bytes.Equal(w.keyFileContents, keyFileContents) { + return + } + + cert, err := tls.X509KeyPair(certFileContents, keyFileContents) + if err != nil { + logger.Warningf("tls.X509KeyPair(%q, %q) failed: %v", certFileContents, keyFileContents, err) + return + } + w.certFileContents = certFileContents + w.keyFileContents = keyFileContents + w.identityDistributor.Set(&certprovider.KeyMaterial{Certs: []tls.Certificate{cert}}, nil) +} + +// updateRootDistributor checks if the root cert file that the plugin is +// watching hs changed, and if so, updates the rootDistributor with the new key +// material. +// +// Skips updates when root cert reading or parsing fails. +// TODO(easwars): Retry with limit (on the number of retries or the amount of +// time) upon failures. +func (w *watcher) updateRootDistributor() { + if w.rootDistributor == nil { + return + } + + rootFileContents, err := ioutil.ReadFile(w.opts.RootFile) + if err != nil { + logger.Warningf("rootFile (%s) read failed: %v", w.opts.RootFile, err) + return + } + trustPool := x509.NewCertPool() + if !trustPool.AppendCertsFromPEM(rootFileContents) { + logger.Warning("failed to parse root certificate") + return + } + // If the file contents have not changed, skip updating the distributor. + if bytes.Equal(w.rootFileContents, rootFileContents) { + return + } + + w.rootFileContents = rootFileContents + w.rootDistributor.Set(&certprovider.KeyMaterial{Roots: trustPool}, nil) +} + +// run is a long running goroutine which watches the configured files for +// changes, and pushes new key material into the appropriate distributors which +// is returned from calls to KeyMaterial(). +func (w *watcher) run(ctx context.Context) { + // Update both root and identity certs at the beginning. Subsequently, + // update only the appropriate file whose ticker has fired. + w.updateIdentityDistributor() + w.updateRootDistributor() + + identityTicker := time.NewTicker(w.opts.CertRefreshDuration) + rootTicker := time.NewTicker(w.opts.RootRefreshDuration) + for { + select { + case <-ctx.Done(): + identityTicker.Stop() + rootTicker.Stop() + if w.identityDistributor != nil { + w.identityDistributor.Stop() + } + if w.rootDistributor != nil { + w.rootDistributor.Stop() + } + return + case <-identityTicker.C: + w.updateIdentityDistributor() + case <-rootTicker.C: + w.updateRootDistributor() + } + } +} + +// KeyMaterial returns the key material sourced by the watcher. +// Callers are expected to use the returned value as read-only. +func (w *watcher) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) { + km := &certprovider.KeyMaterial{} + if w.identityDistributor != nil { + identityKM, err := w.identityDistributor.KeyMaterial(ctx) + if err != nil { + return nil, err + } + km.Certs = identityKM.Certs + } + if w.rootDistributor != nil { + rootKM, err := w.rootDistributor.KeyMaterial(ctx) + if err != nil { + return nil, err + } + km.Roots = rootKM.Roots + } + return km, nil +} + +// Close cleans up resources allocated by the watcher. +func (w *watcher) Close() { + w.cancel() +} diff --git a/credentials/tls/certprovider/pemfile/watcher_test.go b/credentials/tls/certprovider/pemfile/watcher_test.go new file mode 100644 index 00000000000..092bd30ece6 --- /dev/null +++ b/credentials/tls/certprovider/pemfile/watcher_test.go @@ -0,0 +1,426 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package pemfile + +import ( + "context" + "crypto/x509" + "io/ioutil" + "math/big" + "os" + "path" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + + "google.golang.org/grpc/credentials/tls/certprovider" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/testdata" +) + +const ( + // These are the names of files inside temporary directories, which the + // plugin is asked to watch. + certFile = "cert.pem" + keyFile = "key.pem" + rootFile = "ca.pem" + + defaultTestRefreshDuration = 100 * time.Millisecond + defaultTestTimeout = 5 * time.Second +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +// TestNewProvider tests the NewProvider() function with different inputs. +func (s) TestNewProvider(t *testing.T) { + tests := []struct { + desc string + options Options + wantError bool + }{ + { + desc: "No credential files specified", + options: Options{}, + wantError: true, + }, + { + desc: "Only identity cert is specified", + options: Options{ + CertFile: testdata.Path("x509/client1_cert.pem"), + }, + wantError: true, + }, + { + desc: "Only identity key is specified", + options: Options{ + KeyFile: testdata.Path("x509/client1_key.pem"), + }, + wantError: true, + }, + { + desc: "Identity cert/key pair is specified", + options: Options{ + KeyFile: testdata.Path("x509/client1_key.pem"), + CertFile: testdata.Path("x509/client1_cert.pem"), + }, + }, + { + desc: "Only root certs are specified", + options: Options{ + RootFile: testdata.Path("x509/client_ca_cert.pem"), + }, + }, + { + desc: "Everything is specified", + options: Options{ + KeyFile: testdata.Path("x509/client1_key.pem"), + CertFile: testdata.Path("x509/client1_cert.pem"), + RootFile: testdata.Path("x509/client_ca_cert.pem"), + }, + wantError: false, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + provider, err := NewProvider(test.options) + if (err != nil) != test.wantError { + t.Fatalf("NewProvider(%v) = %v, want %v", test.options, err, test.wantError) + } + if err != nil { + return + } + provider.Close() + }) + } +} + +// wrappedDistributor wraps a distributor and pushes on a channel whenever new +// key material is pushed to the distributor. +type wrappedDistributor struct { + *certprovider.Distributor + distCh *testutils.Channel +} + +func newWrappedDistributor(distCh *testutils.Channel) *wrappedDistributor { + return &wrappedDistributor{ + distCh: distCh, + Distributor: certprovider.NewDistributor(), + } +} + +func (wd *wrappedDistributor) Set(km *certprovider.KeyMaterial, err error) { + wd.Distributor.Set(km, err) + wd.distCh.Send(nil) +} + +func createTmpFile(t *testing.T, src, dst string) { + t.Helper() + + data, err := ioutil.ReadFile(src) + if err != nil { + t.Fatalf("ioutil.ReadFile(%q) failed: %v", src, err) + } + if err := ioutil.WriteFile(dst, data, os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", dst, err) + } + t.Logf("Wrote file at: %s", dst) + t.Logf("%s", string(data)) +} + +// createTempDirWithFiles creates a temporary directory under the system default +// tempDir with the given dirSuffix. It also reads from certSrc, keySrc and +// rootSrc files are creates appropriate files under the newly create tempDir. +// Returns the name of the created tempDir. +func createTmpDirWithFiles(t *testing.T, dirSuffix, certSrc, keySrc, rootSrc string) string { + t.Helper() + + // Create a temp directory. Passing an empty string for the first argument + // uses the system temp directory. + dir, err := ioutil.TempDir("", dirSuffix) + if err != nil { + t.Fatalf("ioutil.TempDir() failed: %v", err) + } + t.Logf("Using tmpdir: %s", dir) + + createTmpFile(t, testdata.Path(certSrc), path.Join(dir, certFile)) + createTmpFile(t, testdata.Path(keySrc), path.Join(dir, keyFile)) + createTmpFile(t, testdata.Path(rootSrc), path.Join(dir, rootFile)) + return dir +} + +// initializeProvider performs setup steps common to all tests (except the one +// which uses symlinks). +func initializeProvider(t *testing.T, testName string) (string, certprovider.Provider, *testutils.Channel, func()) { + t.Helper() + + // Override the newDistributor to one which pushes on a channel that we + // can block on. + origDistributorFunc := newDistributor + distCh := testutils.NewChannel() + d := newWrappedDistributor(distCh) + newDistributor = func() distributor { return d } + + // Create a new provider to watch the files in tmpdir. + dir := createTmpDirWithFiles(t, testName+"*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem") + opts := Options{ + CertFile: path.Join(dir, certFile), + KeyFile: path.Join(dir, keyFile), + RootFile: path.Join(dir, rootFile), + CertRefreshDuration: defaultTestRefreshDuration, + RootRefreshDuration: defaultTestRefreshDuration, + } + prov, err := NewProvider(opts) + if err != nil { + t.Fatalf("NewProvider(%+v) failed: %v", opts, err) + } + + // Make sure the provider picks up the files and pushes the key material on + // to the distributors. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + for i := 0; i < 2; i++ { + // Since we have root and identity certs, we need to make sure the + // update is pushed on both of them. + if _, err := distCh.Receive(ctx); err != nil { + t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err) + } + } + + return dir, prov, distCh, func() { + newDistributor = origDistributorFunc + prov.Close() + } +} + +// TestProvider_NoUpdate tests the case where a file watcher plugin is created +// successfully, and the underlying files do not change. Verifies that the +// plugin does not push new updates to the distributor in this case. +func (s) TestProvider_NoUpdate(t *testing.T) { + _, prov, distCh, cancel := initializeProvider(t, "no_update") + defer cancel() + + // Make sure the provider is healthy and returns key material. + ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cc() + if _, err := prov.KeyMaterial(ctx); err != nil { + t.Fatalf("provider.KeyMaterial() failed: %v", err) + } + + // Files haven't change. Make sure no updates are pushed by the provider. + sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration) + defer sc() + if _, err := distCh.Receive(sCtx); err == nil { + t.Fatal("new key material pushed to distributor when underlying files did not change") + } +} + +// TestProvider_UpdateSuccess tests the case where a file watcher plugin is +// created successfully and the underlying files change. Verifies that the +// changes are picked up by the provider. +func (s) TestProvider_UpdateSuccess(t *testing.T) { + dir, prov, distCh, cancel := initializeProvider(t, "update_success") + defer cancel() + + // Make sure the provider is healthy and returns key material. + ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cc() + km1, err := prov.KeyMaterial(ctx) + if err != nil { + t.Fatalf("provider.KeyMaterial() failed: %v", err) + } + + // Change only the root file. + createTmpFile(t, testdata.Path("x509/server_ca_cert.pem"), path.Join(dir, rootFile)) + if _, err := distCh.Receive(ctx); err != nil { + t.Fatal("timeout waiting for new key material to be pushed to the distributor") + } + + // Make sure update is picked up. + km2, err := prov.KeyMaterial(ctx) + if err != nil { + t.Fatalf("provider.KeyMaterial() failed: %v", err) + } + if cmp.Equal(km1, km2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { + t.Fatal("expected provider to return new key material after update to underlying file") + } + + // Change only cert/key files. + createTmpFile(t, testdata.Path("x509/client2_cert.pem"), path.Join(dir, certFile)) + createTmpFile(t, testdata.Path("x509/client2_key.pem"), path.Join(dir, keyFile)) + if _, err := distCh.Receive(ctx); err != nil { + t.Fatal("timeout waiting for new key material to be pushed to the distributor") + } + + // Make sure update is picked up. + km3, err := prov.KeyMaterial(ctx) + if err != nil { + t.Fatalf("provider.KeyMaterial() failed: %v", err) + } + if cmp.Equal(km2, km3, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { + t.Fatal("expected provider to return new key material after update to underlying file") + } +} + +// TestProvider_UpdateSuccessWithSymlink tests the case where a file watcher +// plugin is created successfully to watch files through a symlink and the +// symlink is updates to point to new files. Verifies that the changes are +// picked up by the provider. +func (s) TestProvider_UpdateSuccessWithSymlink(t *testing.T) { + // Override the newDistributor to one which pushes on a channel that we + // can block on. + origDistributorFunc := newDistributor + distCh := testutils.NewChannel() + d := newWrappedDistributor(distCh) + newDistributor = func() distributor { return d } + defer func() { newDistributor = origDistributorFunc }() + + // Create two tempDirs with different files. + dir1 := createTmpDirWithFiles(t, "update_with_symlink1_*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem") + dir2 := createTmpDirWithFiles(t, "update_with_symlink2_*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/server_ca_cert.pem") + + // Create a symlink under a new tempdir, and make it point to dir1. + tmpdir, err := ioutil.TempDir("", "test_symlink_*") + if err != nil { + t.Fatalf("ioutil.TempDir() failed: %v", err) + } + symLinkName := path.Join(tmpdir, "test_symlink") + if err := os.Symlink(dir1, symLinkName); err != nil { + t.Fatalf("failed to create symlink to %q: %v", dir1, err) + } + + // Create a provider which watches the files pointed to by the symlink. + opts := Options{ + CertFile: path.Join(symLinkName, certFile), + KeyFile: path.Join(symLinkName, keyFile), + RootFile: path.Join(symLinkName, rootFile), + CertRefreshDuration: defaultTestRefreshDuration, + RootRefreshDuration: defaultTestRefreshDuration, + } + prov, err := NewProvider(opts) + if err != nil { + t.Fatalf("NewProvider(%+v) failed: %v", opts, err) + } + defer prov.Close() + + // Make sure the provider picks up the files and pushes the key material on + // to the distributors. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + for i := 0; i < 2; i++ { + // Since we have root and identity certs, we need to make sure the + // update is pushed on both of them. + if _, err := distCh.Receive(ctx); err != nil { + t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err) + } + } + km1, err := prov.KeyMaterial(ctx) + if err != nil { + t.Fatalf("provider.KeyMaterial() failed: %v", err) + } + + // Update the symlink to point to dir2. + symLinkTmpName := path.Join(tmpdir, "test_symlink.tmp") + if err := os.Symlink(dir2, symLinkTmpName); err != nil { + t.Fatalf("failed to create symlink to %q: %v", dir2, err) + } + if err := os.Rename(symLinkTmpName, symLinkName); err != nil { + t.Fatalf("failed to update symlink: %v", err) + } + + // Make sure the provider picks up the new files and pushes the key material + // on to the distributors. + for i := 0; i < 2; i++ { + // Since we have root and identity certs, we need to make sure the + // update is pushed on both of them. + if _, err := distCh.Receive(ctx); err != nil { + t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err) + } + } + km2, err := prov.KeyMaterial(ctx) + if err != nil { + t.Fatalf("provider.KeyMaterial() failed: %v", err) + } + + if cmp.Equal(km1, km2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { + t.Fatal("expected provider to return new key material after symlink update") + } +} + +// TestProvider_UpdateFailure_ThenSuccess tests the case where updating cert/key +// files fail. Verifies that the failed update does not push anything on the +// distributor. Then the update succeeds, and the test verifies that the key +// material is updated. +func (s) TestProvider_UpdateFailure_ThenSuccess(t *testing.T) { + dir, prov, distCh, cancel := initializeProvider(t, "update_failure") + defer cancel() + + // Make sure the provider is healthy and returns key material. + ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cc() + km1, err := prov.KeyMaterial(ctx) + if err != nil { + t.Fatalf("provider.KeyMaterial() failed: %v", err) + } + + // Update only the cert file. The key file is left unchanged. This should + // lead to these two files being not compatible with each other. This + // simulates the case where the watching goroutine might catch the files in + // the midst of an update. + createTmpFile(t, testdata.Path("x509/server1_cert.pem"), path.Join(dir, certFile)) + + // Since the last update left the files in an incompatible state, the update + // should not be picked up by our provider. + sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration) + defer sc() + if _, err := distCh.Receive(sCtx); err == nil { + t.Fatal("new key material pushed to distributor when underlying files did not change") + } + + // The provider should return key material corresponding to the old state. + km2, err := prov.KeyMaterial(ctx) + if err != nil { + t.Fatalf("provider.KeyMaterial() failed: %v", err) + } + if !cmp.Equal(km1, km2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { + t.Fatal("expected provider to not update key material") + } + + // Update the key file to match the cert file. + createTmpFile(t, testdata.Path("x509/server1_key.pem"), path.Join(dir, keyFile)) + + // Make sure update is picked up. + if _, err := distCh.Receive(ctx); err != nil { + t.Fatal("timeout waiting for new key material to be pushed to the distributor") + } + km3, err := prov.KeyMaterial(ctx) + if err != nil { + t.Fatalf("provider.KeyMaterial() failed: %v", err) + } + if cmp.Equal(km2, km3, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { + t.Fatal("expected provider to return new key material after update to underlying file") + } +} diff --git a/security/advancedtls/advancedtls_integration_test.go b/security/advancedtls/advancedtls_integration_test.go index 20a9a585796..d554468557a 100644 --- a/security/advancedtls/advancedtls_integration_test.go +++ b/security/advancedtls/advancedtls_integration_test.go @@ -32,6 +32,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/tls/certprovider" + "google.golang.org/grpc/credentials/tls/certprovider/pemfile" pb "google.golang.org/grpc/examples/helloworld/helloworld" "google.golang.org/grpc/security/advancedtls/internal/testutils" "google.golang.org/grpc/security/advancedtls/testdata" @@ -511,38 +513,38 @@ func copyFileContents(sourceFile, destinationFile string) error { // Create PEMFileProvider(s) watching the content changes of temporary // files. -func createProviders(tmpFiles *tmpCredsFiles) (*PEMFileProvider, *PEMFileProvider, *PEMFileProvider, *PEMFileProvider, error) { - clientIdentityOptions := PEMFileProviderOptions{ - CertFile: tmpFiles.clientCertTmp.Name(), - KeyFile: tmpFiles.clientKeyTmp.Name(), - IdentityInterval: credRefreshingInterval, +func createProviders(tmpFiles *tmpCredsFiles) (certprovider.Provider, certprovider.Provider, certprovider.Provider, certprovider.Provider, error) { + clientIdentityOptions := pemfile.Options{ + CertFile: tmpFiles.clientCertTmp.Name(), + KeyFile: tmpFiles.clientKeyTmp.Name(), + CertRefreshDuration: credRefreshingInterval, } - clientIdentityProvider, err := NewPEMFileProvider(clientIdentityOptions) + clientIdentityProvider, err := pemfile.NewProvider(clientIdentityOptions) if err != nil { return nil, nil, nil, nil, err } - clientRootOptions := PEMFileProviderOptions{ - TrustFile: tmpFiles.clientTrustTmp.Name(), - RootInterval: credRefreshingInterval, + clientRootOptions := pemfile.Options{ + RootFile: tmpFiles.clientTrustTmp.Name(), + RootRefreshDuration: credRefreshingInterval, } - clientRootProvider, err := NewPEMFileProvider(clientRootOptions) + clientRootProvider, err := pemfile.NewProvider(clientRootOptions) if err != nil { return nil, nil, nil, nil, err } - serverIdentityOptions := PEMFileProviderOptions{ - CertFile: tmpFiles.serverCertTmp.Name(), - KeyFile: tmpFiles.serverKeyTmp.Name(), - IdentityInterval: credRefreshingInterval, + serverIdentityOptions := pemfile.Options{ + CertFile: tmpFiles.serverCertTmp.Name(), + KeyFile: tmpFiles.serverKeyTmp.Name(), + CertRefreshDuration: credRefreshingInterval, } - serverIdentityProvider, err := NewPEMFileProvider(serverIdentityOptions) + serverIdentityProvider, err := pemfile.NewProvider(serverIdentityOptions) if err != nil { return nil, nil, nil, nil, err } - serverRootOptions := PEMFileProviderOptions{ - TrustFile: tmpFiles.serverTrustTmp.Name(), - RootInterval: credRefreshingInterval, + serverRootOptions := pemfile.Options{ + RootFile: tmpFiles.serverTrustTmp.Name(), + RootRefreshDuration: credRefreshingInterval, } - serverRootProvider, err := NewPEMFileProvider(serverRootOptions) + serverRootProvider, err := pemfile.NewProvider(serverRootOptions) if err != nil { return nil, nil, nil, nil, err } diff --git a/security/advancedtls/pemfile_provider.go b/security/advancedtls/pemfile_provider.go deleted file mode 100644 index 96b3587776e..00000000000 --- a/security/advancedtls/pemfile_provider.go +++ /dev/null @@ -1,197 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package advancedtls - -import ( - "context" - "crypto/tls" - "crypto/x509" - "fmt" - "io/ioutil" - "time" - - "google.golang.org/grpc/credentials/tls/certprovider" - "google.golang.org/grpc/grpclog" -) - -const defaultIdentityInterval = 1 * time.Hour -const defaultRootInterval = 2 * time.Hour - -// readKeyCertPairFunc will be overridden from unit tests. -var readKeyCertPairFunc = tls.LoadX509KeyPair - -// readTrustCertFunc will be overridden from unit tests. -var readTrustCertFunc = func(trustFile string) (*x509.CertPool, error) { - trustData, err := ioutil.ReadFile(trustFile) - if err != nil { - return nil, err - } - trustPool := x509.NewCertPool() - if !trustPool.AppendCertsFromPEM(trustData) { - return nil, fmt.Errorf("AppendCertsFromPEM failed to parse certificates") - } - return trustPool, nil -} - -var logger = grpclog.Component("advancedtls") - -// PEMFileProviderOptions contains options to configure a PEMFileProvider. -// Note that these fields will only take effect during construction. Once the -// PEMFileProvider starts, changing fields in PEMFileProviderOptions will have -// no effect. -type PEMFileProviderOptions struct { - // CertFile is the file path that holds identity certificate whose updates - // will be captured by a watching goroutine. - // Optional. If this is set, KeyFile must also be set. - CertFile string - // KeyFile is the file path that holds identity private key whose updates - // will be captured by a watching goroutine. - // Optional. If this is set, CertFile must also be set. - KeyFile string - // TrustFile is the file path that holds trust certificate whose updates will - // be captured by a watching goroutine. - // Optional. - TrustFile string - // IdentityInterval is the time duration between two credential update checks - // for identity certs. - // Optional. If not set, we will use the default interval(1 hour). - IdentityInterval time.Duration - // RootInterval is the time duration between two credential update checks - // for root certs. - // Optional. If not set, we will use the default interval(2 hours). - RootInterval time.Duration -} - -// PEMFileProvider implements certprovider.Provider. -// It provides the most up-to-date identity private key-cert pairs and/or -// root certificates. -type PEMFileProvider struct { - identityDistributor *certprovider.Distributor - rootDistributor *certprovider.Distributor - cancel context.CancelFunc -} - -func updateIdentityDistributor(distributor *certprovider.Distributor, certFile, keyFile string) { - if distributor == nil { - return - } - // Read identity certs from PEM files. - identityCert, err := readKeyCertPairFunc(certFile, keyFile) - if err != nil { - // If the reading produces an error, we will skip the update for this - // round and log the error. - logger.Warningf("tls.LoadX509KeyPair reads %s and %s failed: %v", certFile, keyFile, err) - return - } - distributor.Set(&certprovider.KeyMaterial{Certs: []tls.Certificate{identityCert}}, nil) -} - -func updateRootDistributor(distributor *certprovider.Distributor, trustFile string) { - if distributor == nil { - return - } - // Read root certs from PEM files. - trustPool, err := readTrustCertFunc(trustFile) - if err != nil { - // If the reading produces an error, we will skip the update for this - // round and log the error. - logger.Warningf("readTrustCertFunc reads %v failed: %v", trustFile, err) - return - } - distributor.Set(&certprovider.KeyMaterial{Roots: trustPool}, nil) -} - -// NewPEMFileProvider returns a new PEMFileProvider constructed using the -// provided options. -func NewPEMFileProvider(o PEMFileProviderOptions) (*PEMFileProvider, error) { - if o.CertFile == "" && o.KeyFile == "" && o.TrustFile == "" { - return nil, fmt.Errorf("at least one credential file needs to be specified") - } - if keySpecified, certSpecified := o.KeyFile != "", o.CertFile != ""; keySpecified != certSpecified { - return nil, fmt.Errorf("private key file and identity cert file should be both specified or not specified") - } - if o.IdentityInterval == 0 { - o.IdentityInterval = defaultIdentityInterval - } - if o.RootInterval == 0 { - o.RootInterval = defaultRootInterval - } - provider := &PEMFileProvider{} - if o.CertFile != "" && o.KeyFile != "" { - provider.identityDistributor = certprovider.NewDistributor() - } - if o.TrustFile != "" { - provider.rootDistributor = certprovider.NewDistributor() - } - // A goroutine to pull file changes. - identityTicker := time.NewTicker(o.IdentityInterval) - rootTicker := time.NewTicker(o.RootInterval) - ctx, cancel := context.WithCancel(context.Background()) - - go func() { - for { - updateIdentityDistributor(provider.identityDistributor, o.CertFile, o.KeyFile) - updateRootDistributor(provider.rootDistributor, o.TrustFile) - select { - case <-ctx.Done(): - identityTicker.Stop() - rootTicker.Stop() - return - case <-identityTicker.C: - break - case <-rootTicker.C: - break - } - } - }() - provider.cancel = cancel - return provider, nil -} - -// KeyMaterial returns the key material sourced by the PEMFileProvider. -// Callers are expected to use the returned value as read-only. -func (p *PEMFileProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) { - km := &certprovider.KeyMaterial{} - if p.identityDistributor != nil { - identityKM, err := p.identityDistributor.KeyMaterial(ctx) - if err != nil { - return nil, err - } - km.Certs = identityKM.Certs - } - if p.rootDistributor != nil { - rootKM, err := p.rootDistributor.KeyMaterial(ctx) - if err != nil { - return nil, err - } - km.Roots = rootKM.Roots - } - return km, nil -} - -// Close cleans up resources allocated by the PEMFileProvider. -func (p *PEMFileProvider) Close() { - p.cancel() - if p.identityDistributor != nil { - p.identityDistributor.Stop() - } - if p.rootDistributor != nil { - p.rootDistributor.Stop() - } -} diff --git a/security/advancedtls/pemfile_provider_test.go b/security/advancedtls/pemfile_provider_test.go deleted file mode 100644 index 48e0bd2f1c3..00000000000 --- a/security/advancedtls/pemfile_provider_test.go +++ /dev/null @@ -1,220 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package advancedtls - -import ( - "context" - "crypto/tls" - "crypto/x509" - "fmt" - "math/big" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/credentials/tls/certprovider" - "google.golang.org/grpc/security/advancedtls/internal/testutils" - "google.golang.org/grpc/security/advancedtls/testdata" -) - -func (s) TestNewPEMFileProvider(t *testing.T) { - tests := []struct { - desc string - options PEMFileProviderOptions - certFile string - keyFile string - trustFile string - wantError bool - }{ - { - desc: "Expect error if no credential files specified", - options: PEMFileProviderOptions{}, - wantError: true, - }, - { - desc: "Expect error if only certFile is specified", - options: PEMFileProviderOptions{ - CertFile: testdata.Path("client_cert_1.pem"), - }, - wantError: true, - }, - { - desc: "Should be good if only identity key cert pairs are specified", - options: PEMFileProviderOptions{ - KeyFile: testdata.Path("client_key_1.pem"), - CertFile: testdata.Path("client_cert_1.pem"), - }, - wantError: false, - }, - { - desc: "Should be good if only root certs are specified", - options: PEMFileProviderOptions{ - TrustFile: testdata.Path("client_trust_cert_1.pem"), - }, - wantError: false, - }, - { - desc: "Should be good if both identity pairs and root certs are specified", - options: PEMFileProviderOptions{ - KeyFile: testdata.Path("client_key_1.pem"), - CertFile: testdata.Path("client_cert_1.pem"), - TrustFile: testdata.Path("client_trust_cert_1.pem"), - }, - wantError: false, - }, - } - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - provider, err := NewPEMFileProvider(test.options) - if (err != nil) != test.wantError { - t.Fatalf("NewPEMFileProvider(%v) = %v, want %v", test.options, err, test.wantError) - } - if err != nil { - return - } - provider.Close() - }) - } - -} - -// This test overwrites the credential reading function used by the watching -// goroutine. It is tested under different stages: -// At stage 0, we force reading function to load ClientCert1 and ServerTrust1, -// and see if the credentials are picked up by the watching go routine. -// At stage 1, we force reading function to cause an error. The watching go -// routine should log the error while leaving the credentials unchanged. -// At stage 2, we force reading function to load ClientCert2 and ServerTrust2, -// and see if the new credentials are picked up. -func (s) TestWatchingRoutineUpdates(t *testing.T) { - // Load certificates. - cs := &testutils.CertStore{} - if err := cs.LoadCerts(); err != nil { - t.Fatalf("cs.LoadCerts() failed, err: %v", err) - } - tests := []struct { - desc string - options PEMFileProviderOptions - wantKmStage0 certprovider.KeyMaterial - wantKmStage1 certprovider.KeyMaterial - wantKmStage2 certprovider.KeyMaterial - }{ - { - desc: "use identity certs and root certs", - options: PEMFileProviderOptions{ - CertFile: "not_empty_cert_file", - KeyFile: "not_empty_key_file", - TrustFile: "not_empty_trust_file", - }, - wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1}, - wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1}, - wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}, Roots: cs.ServerTrust2}, - }, - { - desc: "use identity certs only", - options: PEMFileProviderOptions{ - CertFile: "not_empty_cert_file", - KeyFile: "not_empty_key_file", - }, - wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}}, - wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}}, - wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}}, - }, - { - desc: "use trust certs only", - options: PEMFileProviderOptions{ - TrustFile: "not_empty_trust_file", - }, - wantKmStage0: certprovider.KeyMaterial{Roots: cs.ServerTrust1}, - wantKmStage1: certprovider.KeyMaterial{Roots: cs.ServerTrust1}, - wantKmStage2: certprovider.KeyMaterial{Roots: cs.ServerTrust2}, - }, - } - for _, test := range tests { - testInterval := 200 * time.Millisecond - test.options.IdentityInterval = testInterval - test.options.RootInterval = testInterval - t.Run(test.desc, func(t *testing.T) { - stage := &stageInfo{} - oldReadKeyCertPairFunc := readKeyCertPairFunc - readKeyCertPairFunc = func(certFile, keyFile string) (tls.Certificate, error) { - switch stage.read() { - case 0: - return cs.ClientCert1, nil - case 1: - return tls.Certificate{}, fmt.Errorf("error occurred while reloading") - case 2: - return cs.ClientCert2, nil - default: - return tls.Certificate{}, fmt.Errorf("test stage not supported") - } - } - defer func() { - readKeyCertPairFunc = oldReadKeyCertPairFunc - }() - oldReadTrustCertFunc := readTrustCertFunc - readTrustCertFunc = func(trustFile string) (*x509.CertPool, error) { - switch stage.read() { - case 0: - return cs.ServerTrust1, nil - case 1: - return nil, fmt.Errorf("error occurred while reloading") - case 2: - return cs.ServerTrust2, nil - default: - return nil, fmt.Errorf("test stage not supported") - } - } - defer func() { - readTrustCertFunc = oldReadTrustCertFunc - }() - provider, err := NewPEMFileProvider(test.options) - if err != nil { - t.Fatalf("NewPEMFileProvider failed: %v", err) - } - defer provider.Close() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - //// ------------------------Stage 0------------------------------------ - // Wait for the refreshing go-routine to pick up the changes. - time.Sleep(1 * time.Second) - gotKM, err := provider.KeyMaterial(ctx) - if !cmp.Equal(*gotKM, test.wantKmStage0, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { - t.Fatalf("provider.KeyMaterial() = %+v, want %+v", *gotKM, test.wantKmStage0) - } - // ------------------------Stage 1------------------------------------ - stage.increase() - // Wait for the refreshing go-routine to pick up the changes. - time.Sleep(1 * time.Second) - gotKM, err = provider.KeyMaterial(ctx) - if !cmp.Equal(*gotKM, test.wantKmStage1, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { - t.Fatalf("provider.KeyMaterial() = %+v, want %+v", *gotKM, test.wantKmStage1) - } - //// ------------------------Stage 2------------------------------------ - // Wait for the refreshing go-routine to pick up the changes. - stage.increase() - time.Sleep(1 * time.Second) - gotKM, err = provider.KeyMaterial(ctx) - if !cmp.Equal(*gotKM, test.wantKmStage2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) { - t.Fatalf("provider.KeyMaterial() = %+v, want %+v", *gotKM, test.wantKmStage2) - } - stage.reset() - }) - } -}