diff --git a/pkg/daemon/ceph/osd/kms/kms.go b/pkg/daemon/ceph/osd/kms/kms.go index 3217d98ac98a9..18e18b6d65215 100644 --- a/pkg/daemon/ceph/osd/kms/kms.go +++ b/pkg/daemon/ceph/osd/kms/kms.go @@ -83,14 +83,7 @@ func (c *Config) PutSecret(secretName, secretValue string) error { } if c.IsVault() { // Store the secret in Vault - removeCertFiles := make(chan bool, 1) - // Remove cert files from operator's filesystem - defer func() { - removeCertFiles <- true - close(removeCertFiles) - }() - - v, err := InitVault(c.context, c.clusterInfo.Namespace, c.clusterSpec.Security.KeyManagementService.ConnectionDetails, removeCertFiles) + v, err := InitVault(c.context, c.clusterInfo.Namespace, c.clusterSpec.Security.KeyManagementService.ConnectionDetails) if err != nil { return errors.Wrap(err, "failed to init vault kms") } @@ -109,14 +102,7 @@ func (c *Config) GetSecret(secretName string) (string, error) { var value string if c.IsVault() { // Store the secret in Vault - removeCertFiles := make(chan bool, 1) - // Remove cert files from operator's filesystem - defer func() { - removeCertFiles <- true - close(removeCertFiles) - }() - - v, err := InitVault(c.context, c.clusterInfo.Namespace, c.clusterSpec.Security.KeyManagementService.ConnectionDetails, removeCertFiles) + v, err := InitVault(c.context, c.clusterInfo.Namespace, c.clusterSpec.Security.KeyManagementService.ConnectionDetails) if err != nil { return "", errors.Wrap(err, "failed to get secret in vault") } @@ -135,14 +121,7 @@ func (c *Config) GetSecret(secretName string) (string, error) { func (c *Config) DeleteSecret(secretName string) error { if c.IsVault() { // Store the secret in Vault - removeCertFiles := make(chan bool, 1) - // Remove cert files from operator's filesystem - defer func() { - removeCertFiles <- true - close(removeCertFiles) - }() - - v, err := InitVault(c.context, c.clusterInfo.Namespace, c.clusterSpec.Security.KeyManagementService.ConnectionDetails, removeCertFiles) + v, err := InitVault(c.context, c.clusterInfo.Namespace, c.clusterSpec.Security.KeyManagementService.ConnectionDetails) if err != nil { return errors.Wrap(err, "failed to delete secret in vault") } diff --git a/pkg/daemon/ceph/osd/kms/kms_test.go b/pkg/daemon/ceph/osd/kms/kms_test.go index ee2f31cd9f936..b84a0dbd79d61 100644 --- a/pkg/daemon/ceph/osd/kms/kms_test.go +++ b/pkg/daemon/ceph/osd/kms/kms_test.go @@ -122,7 +122,7 @@ func TestValidateConnectionDetails(t *testing.T) { vault.TestWaitActive(t, core) client := cluster.Cores[0].Client // Mock the client here - vaultClient = func(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string, removeCertFiles chan bool) (*api.Client, error) { + vaultClient = func(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string) (*api.Client, error) { return client, nil } if err := client.Sys().Mount("rook/", &api.MountInput{ diff --git a/pkg/daemon/ceph/osd/kms/vault.go b/pkg/daemon/ceph/osd/kms/vault.go index 801ff0a2d2edb..6c9dce8c4c553 100644 --- a/pkg/daemon/ceph/osd/kms/vault.go +++ b/pkg/daemon/ceph/osd/kms/vault.go @@ -21,7 +21,6 @@ import ( "io/ioutil" "os" "strings" - "time" "github.com/hashicorp/vault/api" "github.com/libopenstorage/secrets" @@ -47,6 +46,8 @@ var ( vaultMandatoryConnectionDetails = []string{api.EnvVaultAddress} ) +type RemoveCertFilesFunction func() + /* VAULT API INTERNAL VALUES // Refer to https://pkg.golangclub.com/github.com/hashicorp/vault/api?tab=doc#pkg-constants const EnvVaultAddress = "VAULT_ADDR" @@ -68,7 +69,7 @@ var ( */ // InitVault inits the secret store -func InitVault(context *clusterd.Context, namespace string, config map[string]string, removeCertFiles chan bool) (secrets.Secrets, error) { +func InitVault(context *clusterd.Context, namespace string, config map[string]string) (secrets.Secrets, error) { c := make(map[string]interface{}) // So that we don't alter the content of c.config for later iterations @@ -79,10 +80,11 @@ func InitVault(context *clusterd.Context, namespace string, config map[string]st } // Populate TLS config - newConfigWithTLS, err := configTLS(context, namespace, oriConfig, removeCertFiles) + newConfigWithTLS, removeCertFiles, err := configTLS(context, namespace, oriConfig) if err != nil { return nil, errors.Wrap(err, "failed to initialize vault tls configuration") } + defer removeCertFiles() // Populate TLS config for key, value := range newConfigWithTLS { @@ -98,9 +100,10 @@ func InitVault(context *clusterd.Context, namespace string, config map[string]st return v, nil } -func configTLS(clusterdContext *clusterd.Context, namespace string, config map[string]string, removeCertFiles chan bool) (map[string]string, error) { +func configTLS(clusterdContext *clusterd.Context, namespace string, config map[string]string) (map[string]string, RemoveCertFilesFunction, error) { ctx := context.TODO() var filesToRemove []*os.File + var removeCertFiles RemoveCertFilesFunction for _, tlsOption := range cephv1.VaultTLSConnectionDetails { tlsSecretName := GetParam(config, tlsOption) if tlsSecretName == "" { @@ -110,19 +113,19 @@ func configTLS(clusterdContext *clusterd.Context, namespace string, config map[s if !strings.Contains(tlsSecretName, EtcVaultDir) { secret, err := clusterdContext.Clientset.CoreV1().Secrets(namespace).Get(ctx, tlsSecretName, v1.GetOptions{}) if err != nil { - return nil, errors.Wrapf(err, "failed to fetch tls k8s secret %q", tlsSecretName) + return nil, removeCertFiles, errors.Wrapf(err, "failed to fetch tls k8s secret %q", tlsSecretName) } // Generate a temp file file, err := ioutil.TempFile("", "") if err != nil { - return nil, errors.Wrapf(err, "failed to generate temp file for k8s secret %q content", tlsSecretName) + return nil, removeCertFiles, errors.Wrapf(err, "failed to generate temp file for k8s secret %q content", tlsSecretName) } // Write into a file err = ioutil.WriteFile(file.Name(), secret.Data[tlsSecretKeyToCheck(tlsOption)], 0444) if err != nil { - return nil, errors.Wrapf(err, "failed to write k8s secret %q content to a file", tlsSecretName) + return nil, removeCertFiles, errors.Wrapf(err, "failed to write k8s secret %q content to a file", tlsSecretName) } logger.Debugf("replacing %q current content %q with %q", tlsOption, config[tlsOption], file.Name()) @@ -137,41 +140,25 @@ func configTLS(clusterdContext *clusterd.Context, namespace string, config map[s } } - // Run a goroutine that waits for a confirmation to remove the certificates from the filesystem - go func(filesToRemove []*os.File) { - logger.Debugf("files to remove: %+v", filesToRemove) - rmCertFiles := func() { - for _, file := range filesToRemove { - logger.Debugf("closing %q", file.Name()) - err := file.Close() - if err != nil { - logger.Errorf("failed to close file %q. %v", file.Name(), err) - } - logger.Debugf("closed %q", file.Name()) - logger.Debugf("removing %q", file.Name()) - err = os.Remove(file.Name()) - if err != nil { - logger.Errorf("failed to remove file %q. %v", file.Name(), err) - } - logger.Debugf("removed %q", file.Name()) + removeCertFiles = RemoveCertFilesFunction(func() { + filesToRemove := filesToRemove + for _, file := range filesToRemove { + logger.Debugf("closing %q", file.Name()) + err := file.Close() + if err != nil { + logger.Errorf("failed to close file %q. %v", file.Name(), err) } - } - for { - select { - case <-removeCertFiles: - logger.Debug("received confirmation from channel to remove vault certificates") - rmCertFiles() - logger.Info("successfully removed secret files") - return - case <-time.After(time.Second * 5): - logger.Info("never received anything from the channel removing vault cert files anyway") - rmCertFiles() - return + logger.Debugf("closed %q", file.Name()) + logger.Debugf("removing %q", file.Name()) + err = os.Remove(file.Name()) + if err != nil { + logger.Errorf("failed to remove file %q. %v", file.Name(), err) } + logger.Debugf("removed %q", file.Name()) } - }(filesToRemove) + }) - return config, nil + return config, removeCertFiles, nil } func put(v secrets.Secrets, secretName, secretValue string, keyContext map[string]string) error { diff --git a/pkg/daemon/ceph/osd/kms/vault_api.go b/pkg/daemon/ceph/osd/kms/vault_api.go index a5ceec99afe25..9361a6c9ca68f 100644 --- a/pkg/daemon/ceph/osd/kms/vault_api.go +++ b/pkg/daemon/ceph/osd/kms/vault_api.go @@ -39,7 +39,7 @@ var vaultClient = newVaultClient // newVaultClient returns a vault client, there is no need for any secretConfig validation // Since this is called after an already validated call InitVault() -func newVaultClient(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string, removeCertFiles chan bool) (*api.Client, error) { +func newVaultClient(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string) (*api.Client, error) { // DefaultConfig uses the environment variables if present. config := api.DefaultConfig() @@ -57,10 +57,11 @@ func newVaultClient(clusterdContext *clusterd.Context, namespace string, secretC } // Populate TLS config - newConfigWithTLS, err := configTLS(clusterdContext, namespace, localSecretConfig, removeCertFiles) + newConfigWithTLS, removeCertFiles, err := configTLS(clusterdContext, namespace, localSecretConfig) if err != nil { return nil, errors.Wrap(err, "failed to initialize vault tls configuration") } + defer removeCertFiles() // Populate TLS config for key, value := range newConfigWithTLS { @@ -109,15 +110,8 @@ func BackendVersion(clusterdContext *clusterd.Context, namespace string, secretC logger.Info("vault kv secret engine version set to v2") return v2, nil default: - removeCertFiles := make(chan bool, 1) - // Remove cert files from operator's filesystem - defer func() { - removeCertFiles <- true - close(removeCertFiles) - }() - // Initialize Vault client - vaultClient, err := vaultClient(clusterdContext, namespace, secretConfig, removeCertFiles) + vaultClient, err := vaultClient(clusterdContext, namespace, secretConfig) if err != nil { return "", errors.Wrap(err, "failed to initialize vault client") } diff --git a/pkg/daemon/ceph/osd/kms/vault_api_test.go b/pkg/daemon/ceph/osd/kms/vault_api_test.go index d388bf2341081..50863abcfa57e 100644 --- a/pkg/daemon/ceph/osd/kms/vault_api_test.go +++ b/pkg/daemon/ceph/osd/kms/vault_api_test.go @@ -42,7 +42,7 @@ func TestBackendVersion(t *testing.T) { client := cluster.Cores[0].Client // Mock the client here - vaultClient = func(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string, removeCertFiles chan bool) (*api.Client, error) { + vaultClient = func(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string) (*api.Client, error) { return client, nil } @@ -173,28 +173,22 @@ cjLxGL8tcZjHKg== } for _, s := range []*v1.Secret{sCa, sClCert, sClKey} { - if _, err := context.Clientset.CoreV1().Secrets(ns).Create(ctx, s, metav1.CreateOptions{}); err != nil { + if secret, err := context.Clientset.CoreV1().Secrets(ns).Create(ctx, s, metav1.CreateOptions{}); err != nil { t.Fatal(err) } else { defer func() { - err := context.Clientset.CoreV1().Secrets(ns).Delete(ctx, s.Name, metav1.DeleteOptions{}) + err := context.Clientset.CoreV1().Secrets(ns).Delete(ctx, secret.Name, metav1.DeleteOptions{}) if err != nil { - logger.Errorf("failed to delete secret %s: %v", s.Name, err) + logger.Errorf("failed to delete secret %s: %v", secret.Name, err) } }() } } - removeCertFiles := make(chan bool, 1) - // Remove cert files from operator's filesystem - defer func() { - removeCertFiles <- true - close(removeCertFiles) - }() - // Populate TLS config - newConfigWithTLS, err := configTLS(context, ns, secretConfig, removeCertFiles) + newConfigWithTLS, removeCertFiles, err := configTLS(context, ns, secretConfig) assert.NoError(t, err) + defer removeCertFiles() // Populate TLS config for key, value := range newConfigWithTLS { diff --git a/pkg/daemon/ceph/osd/kms/vault_test.go b/pkg/daemon/ceph/osd/kms/vault_test.go index e18321df73788..851dc3182298f 100644 --- a/pkg/daemon/ceph/osd/kms/vault_test.go +++ b/pkg/daemon/ceph/osd/kms/vault_test.go @@ -63,14 +63,12 @@ func Test_configTLS(t *testing.T) { "VAULT_BACKEND_PATH": "vault", } // No tls config - removeCertFiles := make(chan bool, 1) - _, err := configTLS(context, ns, config, removeCertFiles) + _, removeCertFiles, err := configTLS(context, ns, config) assert.NoError(t, err) - close(removeCertFiles) + defer removeCertFiles() }) t.Run("TLS config with already populated cert path", func(t *testing.T) { - removeCertFiles := make(chan bool, 1) config := map[string]string{ "foo": "bar", "KMS_PROVIDER": "vault", @@ -79,14 +77,13 @@ func Test_configTLS(t *testing.T) { "VAULT_CACERT": "/etc/vault/cacert", "VAULT_SKIP_VERIFY": "false", } - config, err := configTLS(context, ns, config, removeCertFiles) + config, removeCertFiles, err := configTLS(context, ns, config) assert.NoError(t, err) assert.Equal(t, "/etc/vault/cacert", config["VAULT_CACERT"]) - close(removeCertFiles) + defer removeCertFiles() }) t.Run("TLS config but no secret", func(t *testing.T) { - removeCertFiles := make(chan bool, 1) config := map[string]string{ "foo": "bar", "KMS_PROVIDER": "vault", @@ -95,14 +92,12 @@ func Test_configTLS(t *testing.T) { "VAULT_CACERT": "vault-ca-cert", "VAULT_SKIP_VERIFY": "false", } - _, err := configTLS(context, ns, config, removeCertFiles) + _, _, err := configTLS(context, ns, config) assert.Error(t, err) assert.EqualError(t, err, "failed to fetch tls k8s secret \"vault-ca-cert\": secrets \"vault-ca-cert\" not found") - close(removeCertFiles) }) t.Run("TLS config success!", func(t *testing.T) { - removeCertFiles := make(chan bool, 1) config := map[string]string{ "foo": "bar", "KMS_PROVIDER": "vault", @@ -120,17 +115,15 @@ func Test_configTLS(t *testing.T) { } _, err := context.Clientset.CoreV1().Secrets(ns).Create(ctx, s, metav1.CreateOptions{}) assert.NoError(t, err) - config, err = configTLS(context, ns, config, removeCertFiles) + config, removeCertFiles, err := configTLS(context, ns, config) + defer removeCertFiles() assert.NoError(t, err) assert.NotEqual(t, "vault-ca-cert", config["VAULT_CACERT"]) err = context.Clientset.CoreV1().Secrets(ns).Delete(ctx, s.Name, metav1.DeleteOptions{}) assert.NoError(t, err) - removeCertFiles <- true - close(removeCertFiles) }) t.Run("advanced TLS config success!", func(t *testing.T) { - removeCertFiles := make(chan bool, 1) config := map[string]string{ "foo": "bar", "KMS_PROVIDER": "vault", @@ -167,14 +160,13 @@ func Test_configTLS(t *testing.T) { assert.NoError(t, err) _, err = context.Clientset.CoreV1().Secrets(ns).Create(ctx, sClKey, metav1.CreateOptions{}) assert.NoError(t, err) - config, err = configTLS(context, ns, config, removeCertFiles) + config, removeCertFiles, err := configTLS(context, ns, config) assert.NoError(t, err) assert.NotEqual(t, "vault-ca-cert", config["VAULT_CACERT"]) assert.NotEqual(t, "vault-client-cert", config["VAULT_CLIENT_CERT"]) assert.NotEqual(t, "vault-client-key", config["VAULT_CLIENT_KEY"]) assert.FileExists(t, config["VAULT_CACERT"]) - removeCertFiles <- true - close(removeCertFiles) + removeCertFiles() time.Sleep(time.Second) assert.NoFileExists(t, config["VAULT_CACERT"]) assert.NoFileExists(t, config["VAULT_CLIENT_CERT"]) @@ -182,7 +174,6 @@ func Test_configTLS(t *testing.T) { }) t.Run("advanced TLS config success with timeout!", func(t *testing.T) { - removeCertFiles := make(chan bool, 1) config := map[string]string{ "foo": "bar", "KMS_PROVIDER": "vault", @@ -192,17 +183,17 @@ func Test_configTLS(t *testing.T) { "VAULT_CLIENT_CERT": "vault-client-cert", "VAULT_CLIENT_KEY": "vault-client-key", } - config, err := configTLS(context, ns, config, removeCertFiles) + config, removeCertFiles, err := configTLS(context, ns, config) assert.NoError(t, err) assert.NotEqual(t, "vault-ca-cert", config["VAULT_CACERT"]) assert.NotEqual(t, "vault-client-cert", config["VAULT_CLIENT_CERT"]) assert.NotEqual(t, "vault-client-key", config["VAULT_CLIENT_KEY"]) assert.FileExists(t, config["VAULT_CACERT"]) + removeCertFiles() time.Sleep(time.Second * 6) assert.NoFileExists(t, config["VAULT_CACERT"]) assert.NoFileExists(t, config["VAULT_CLIENT_CERT"]) assert.NoFileExists(t, config["VAULT_CLIENT_KEY"]) - close(removeCertFiles) }) }