diff --git a/pkg/daemon/ceph/osd/kms/kms.go b/pkg/daemon/ceph/osd/kms/kms.go index 018143da77cc..18e18b6d6521 100644 --- a/pkg/daemon/ceph/osd/kms/kms.go +++ b/pkg/daemon/ceph/osd/kms/kms.go @@ -202,7 +202,7 @@ func ValidateConnectionDetails(clusterdContext *clusterd.Context, securitySpec * case VaultKVSecretEngineKey: // Append Backend Version if not already present if GetParam(securitySpec.KeyManagementService.ConnectionDetails, vault.VaultBackendKey) == "" { - backendVersion, err := BackendVersion(securitySpec.KeyManagementService.ConnectionDetails) + backendVersion, err := BackendVersion(clusterdContext, ns, securitySpec.KeyManagementService.ConnectionDetails) if err != nil { return errors.Wrap(err, "failed to get backend version") } diff --git a/pkg/daemon/ceph/osd/kms/kms_test.go b/pkg/daemon/ceph/osd/kms/kms_test.go index 957f153fea23..b84a0dbd79d6 100644 --- a/pkg/daemon/ceph/osd/kms/kms_test.go +++ b/pkg/daemon/ceph/osd/kms/kms_test.go @@ -91,7 +91,7 @@ func TestValidateConnectionDetails(t *testing.T) { securitySpec.KeyManagementService.ConnectionDetails["VAULT_CACERT"] = "vault-ca-secret" err = ValidateConnectionDetails(context, securitySpec, ns) assert.Error(t, err, "") - assert.EqualError(t, err, "failed to validate vault connection details: failed to find TLS connection details k8s secret \"VAULT_CACERT\"") + assert.EqualError(t, err, "failed to validate vault connection details: failed to find TLS connection details k8s secret \"vault-ca-secret\"") // Error: TLS secret exists but empty key tlsSecret := &v1.Secret{ @@ -122,7 +122,9 @@ func TestValidateConnectionDetails(t *testing.T) { vault.TestWaitActive(t, core) client := cluster.Cores[0].Client // Mock the client here - vaultClient = func(secretConfig map[string]string) (*api.Client, error) { return client, nil } + vaultClient = func(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string) (*api.Client, error) { + return client, nil + } if err := client.Sys().Mount("rook/", &api.MountInput{ Type: "kv-v2", Options: map[string]string{"version": "2"}, diff --git a/pkg/daemon/ceph/osd/kms/vault.go b/pkg/daemon/ceph/osd/kms/vault.go index 2c27a3c77f83..182c55e1db40 100644 --- a/pkg/daemon/ceph/osd/kms/vault.go +++ b/pkg/daemon/ceph/osd/kms/vault.go @@ -19,6 +19,7 @@ package kms import ( "context" "io/ioutil" + "os" "strings" "github.com/hashicorp/vault/api" @@ -45,6 +46,14 @@ var ( vaultMandatoryConnectionDetails = []string{api.EnvVaultAddress} ) +// Used for unit tests mocking too as well as production code +var ( + createTmpFile = ioutil.TempFile + getRemoveCertFiles = getRemoveCertFilesFunc +) + +type removeCertFilesFunction func() + /* VAULT API INTERNAL VALUES // Refer to https://pkg.golangclub.com/github.com/hashicorp/vault/api?tab=doc#pkg-constants const EnvVaultAddress = "VAULT_ADDR" @@ -77,10 +86,11 @@ func InitVault(context *clusterd.Context, namespace string, config map[string]st } // Populate TLS config - newConfigWithTLS, err := configTLS(context, namespace, oriConfig) + newConfigWithTLS, removeCertFiles, err := configTLS(context, namespace, oriConfig) if err != nil { return nil, errors.Wrap(err, "failed to initialize vault tls configuration") } + defer removeCertFiles() // Populate TLS config for key, value := range newConfigWithTLS { @@ -96,8 +106,31 @@ func InitVault(context *clusterd.Context, namespace string, config map[string]st return v, nil } -func configTLS(clusterdContext *clusterd.Context, namespace string, config map[string]string) (map[string]string, error) { +// configTLS returns a map of TLS config that map physical files for the TLS library to load +// Also it returns a function to remove the temporary files (certs, keys) +// The signature has named result parameters to help building 'defer' statements especially for the +// content of removeCertFiles which needs to be populated by the files to remove if no errors and be +// nil on errors +func configTLS(clusterdContext *clusterd.Context, namespace string, config map[string]string) (newConfig map[string]string, removeCertFiles removeCertFilesFunction, retErr error) { ctx := context.TODO() + var filesToRemove []*os.File + + defer func() { + // Build the function that the caller should use to remove the temp files here + // create it when this function is returning based on the currently-recorded files + removeCertFiles = getRemoveCertFiles(filesToRemove) + if retErr != nil { + // If we encountered an error, remove the temp files + removeCertFiles() + + // Also return an empty function to remove the temp files + // It's fine to use nil here since the defer from the calling functions is only + // triggered after evaluating any error, if on error the defer is not triggered since we + // have returned already + removeCertFiles = nil + } + }() + for _, tlsOption := range cephv1.VaultTLSConnectionDetails { tlsSecretName := GetParam(config, tlsOption) if tlsSecretName == "" { @@ -107,31 +140,52 @@ func configTLS(clusterdContext *clusterd.Context, namespace string, config map[s if !strings.Contains(tlsSecretName, EtcVaultDir) { secret, err := clusterdContext.Clientset.CoreV1().Secrets(namespace).Get(ctx, tlsSecretName, v1.GetOptions{}) if err != nil { - return nil, errors.Wrapf(err, "failed to fetch tls k8s secret %q", tlsSecretName) + return nil, removeCertFiles, errors.Wrapf(err, "failed to fetch tls k8s secret %q", tlsSecretName) } - // Generate a temp file - file, err := ioutil.TempFile("", "") + file, err := createTmpFile("", "") if err != nil { - return nil, errors.Wrapf(err, "failed to generate temp file for k8s secret %q content", tlsSecretName) + return nil, removeCertFiles, errors.Wrapf(err, "failed to generate temp file for k8s secret %q content", tlsSecretName) } // Write into a file err = ioutil.WriteFile(file.Name(), secret.Data[tlsSecretKeyToCheck(tlsOption)], 0444) if err != nil { - return nil, errors.Wrapf(err, "failed to write k8s secret %q content to a file", tlsSecretName) + return nil, removeCertFiles, errors.Wrapf(err, "failed to write k8s secret %q content to a file", tlsSecretName) } logger.Debugf("replacing %q current content %q with %q", tlsOption, config[tlsOption], file.Name()) - // update the env var with the path + // Update the env var with the path config[tlsOption] = file.Name() + + // Add the file to the list of files to remove + filesToRemove = append(filesToRemove, file) } else { logger.Debugf("value of tlsOption %q tlsSecretName is already correct %q", tlsOption, tlsSecretName) } } - return config, nil + return config, removeCertFiles, nil +} + +func getRemoveCertFilesFunc(filesToRemove []*os.File) removeCertFilesFunction { + return removeCertFilesFunction(func() { + for _, file := range filesToRemove { + logger.Debugf("closing %q", file.Name()) + err := file.Close() + if err != nil { + logger.Errorf("failed to close file %q. %v", file.Name(), err) + } + logger.Debugf("closed %q", file.Name()) + logger.Debugf("removing %q", file.Name()) + err = os.Remove(file.Name()) + if err != nil { + logger.Errorf("failed to remove file %q. %v", file.Name(), err) + } + logger.Debugf("removed %q", file.Name()) + } + }) } func put(v secrets.Secrets, secretName, secretValue string, keyContext map[string]string) error { @@ -215,7 +269,7 @@ func validateVaultConnectionDetails(clusterdContext *clusterd.Context, ns string // Fetch the secret s, err := clusterdContext.Clientset.CoreV1().Secrets(ns).Get(ctx, tlsSecretName, v1.GetOptions{}) if err != nil { - return errors.Errorf("failed to find TLS connection details k8s secret %q", tlsOption) + return errors.Errorf("failed to find TLS connection details k8s secret %q", tlsSecretName) } // Check the Secret key and its content diff --git a/pkg/daemon/ceph/osd/kms/vault_api.go b/pkg/daemon/ceph/osd/kms/vault_api.go index 2e2cdb2a7788..9361a6c9ca68 100644 --- a/pkg/daemon/ceph/osd/kms/vault_api.go +++ b/pkg/daemon/ceph/osd/kms/vault_api.go @@ -23,6 +23,7 @@ import ( "github.com/libopenstorage/secrets/vault" "github.com/libopenstorage/secrets/vault/utils" "github.com/pkg/errors" + "github.com/rook/rook/pkg/clusterd" "github.com/hashicorp/vault/api" ) @@ -38,16 +39,35 @@ var vaultClient = newVaultClient // newVaultClient returns a vault client, there is no need for any secretConfig validation // Since this is called after an already validated call InitVault() -func newVaultClient(secretConfig map[string]string) (*api.Client, error) { +func newVaultClient(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string) (*api.Client, error) { // DefaultConfig uses the environment variables if present. config := api.DefaultConfig() + // Always use a new map otherwise the map will mutate and subsequent calls will fail since the + // TLS content has been altered by the TLS config in vaultClient() + localSecretConfig := make(map[string]string) + for k, v := range secretConfig { + localSecretConfig[k] = v + } + // Convert map string to map interface c := make(map[string]interface{}) - for k, v := range secretConfig { + for k, v := range localSecretConfig { c[k] = v } + // Populate TLS config + newConfigWithTLS, removeCertFiles, err := configTLS(clusterdContext, namespace, localSecretConfig) + if err != nil { + return nil, errors.Wrap(err, "failed to initialize vault tls configuration") + } + defer removeCertFiles() + + // Populate TLS config + for key, value := range newConfigWithTLS { + c[key] = string(value) + } + // Configure TLS if err := utils.ConfigureTLS(config, c); err != nil { return nil, err @@ -64,7 +84,7 @@ func newVaultClient(secretConfig map[string]string) (*api.Client, error) { client.SetToken(strings.TrimSuffix(os.Getenv(api.EnvVaultToken), "\n")) // Set Vault address, was validated by ValidateConnectionDetails() - err = client.SetAddress(strings.TrimSuffix(secretConfig[api.EnvVaultAddress], "\n")) + err = client.SetAddress(strings.TrimSuffix(localSecretConfig[api.EnvVaultAddress], "\n")) if err != nil { return nil, err } @@ -72,7 +92,7 @@ func newVaultClient(secretConfig map[string]string) (*api.Client, error) { return client, nil } -func BackendVersion(secretConfig map[string]string) (string, error) { +func BackendVersion(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string) (string, error) { v1 := "v1" v2 := "v2" @@ -91,7 +111,7 @@ func BackendVersion(secretConfig map[string]string) (string, error) { return v2, nil default: // Initialize Vault client - vaultClient, err := vaultClient(secretConfig) + vaultClient, err := vaultClient(clusterdContext, namespace, secretConfig) if err != nil { return "", errors.Wrap(err, "failed to initialize vault client") } diff --git a/pkg/daemon/ceph/osd/kms/vault_api_test.go b/pkg/daemon/ceph/osd/kms/vault_api_test.go index 774298271a7e..50863abcfa57 100644 --- a/pkg/daemon/ceph/osd/kms/vault_api_test.go +++ b/pkg/daemon/ceph/osd/kms/vault_api_test.go @@ -17,6 +17,7 @@ limitations under the License. package kms import ( + "context" "testing" kv "github.com/hashicorp/vault-plugin-secrets-kv" @@ -24,6 +25,12 @@ import ( vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" + "github.com/libopenstorage/secrets/vault/utils" + "github.com/rook/rook/pkg/clusterd" + "github.com/rook/rook/pkg/operator/test" + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) func TestBackendVersion(t *testing.T) { @@ -35,7 +42,9 @@ func TestBackendVersion(t *testing.T) { client := cluster.Cores[0].Client // Mock the client here - vaultClient = func(secretConfig map[string]string) (*api.Client, error) { return client, nil } + vaultClient = func(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string) (*api.Client, error) { + return client, nil + } // Set up the kv store if err := client.Sys().Mount("rook/", &api.MountInput{ @@ -67,7 +76,7 @@ func TestBackendVersion(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := BackendVersion(tt.args.secretConfig) + got, err := BackendVersion(&clusterd.Context{}, "ns", tt.args.secretConfig) if (err != nil) != tt.wantErr { t.Errorf("BackendVersion() error = %v, wantErr %v", err, tt.wantErr) return @@ -91,3 +100,102 @@ func fakeVaultServer(t *testing.T) *vault.TestCluster { return cluster } + +func TestTLSConfig(t *testing.T) { + ns := "rook-ceph" + ctx := context.TODO() + context := &clusterd.Context{Clientset: test.New(t, 3)} + secretConfig := map[string]string{ + "foo": "bar", + "KMS_PROVIDER": "vault", + "VAULT_ADDR": "1.1.1.1", + "VAULT_BACKEND_PATH": "vault", + "VAULT_CACERT": "vault-ca-cert", + "VAULT_CLIENT_CERT": "vault-client-cert", + "VAULT_CLIENT_KEY": "vault-client-key", + } + + // DefaultConfig uses the environment variables if present. + config := api.DefaultConfig() + + // Convert map string to map interface + c := make(map[string]interface{}) + for k, v := range secretConfig { + c[k] = v + } + + sCa := &v1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "vault-ca-cert", + Namespace: ns, + }, + Data: map[string][]byte{"cert": []byte(`-----BEGIN CERTIFICATE----- +MIIBJTCB0AIJAPNFNz1CNlDOMA0GCSqGSIb3DQEBCwUAMBoxCzAJBgNVBAYTAkZS +MQswCQYDVQQIDAJGUjAeFw0yMTA5MzAwODAzNDBaFw0yNDA2MjYwODAzNDBaMBox +CzAJBgNVBAYTAkZSMQswCQYDVQQIDAJGUjBcMA0GCSqGSIb3DQEBAQUAA0sAMEgC +QQDHeZ47hVBcryl6SCghM8Zj3Q6DQzJzno1J7EjPXef5m+pIVAEylS9sQuwKtFZc +vv3qS/OVFExmMdbrvfKEIfbBAgMBAAEwDQYJKoZIhvcNAQELBQADQQAAnflLuUM3 +4Dq0v7If4cgae2mr7jj3U/lIpHVtFbF7kVjC/eqmeN1a9u0UbRHKkUr+X1mVX3rJ +BvjQDN6didwQ +-----END CERTIFICATE-----`)}, + } + + sClCert := &v1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "vault-client-cert", + Namespace: ns, + }, + Data: map[string][]byte{"cert": []byte(`-----BEGIN CERTIFICATE----- +MIIBEDCBuwIBATANBgkqhkiG9w0BAQUFADAaMQswCQYDVQQGEwJGUjELMAkGA1UE +CAwCRlIwHhcNMjEwOTMwMDgwNDA1WhcNMjQwNjI2MDgwNDA1WjANMQswCQYDVQQG +EwJGUjBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQCpWJqKhSES3BiFkt2M82xy3tkB +plDS8DM0s/+VkqfZlVG18KbbIVDHi1lsPjjs/Aja7lWymw0ycV4KGEcqxdmNAgMB +AAEwDQYJKoZIhvcNAQEFBQADQQC5esmoTqp4uEWyC+GKbTTFp8ngMUywAtZJs4nS +wdoF3ZJJzo4ps0saP1ww5LBdeeXUURscxyaFfCFmGODaHJJn +-----END CERTIFICATE-----`)}, + } + + sClKey := &v1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "vault-client-key", + Namespace: ns, + }, + Data: map[string][]byte{"key": []byte(`-----BEGIN PRIVATE KEY----- +MIIBVgIBADANBgkqhkiG9w0BAQEFAASCAUAwggE8AgEAAkEAqViaioUhEtwYhZLd +jPNsct7ZAaZQ0vAzNLP/lZKn2ZVRtfCm2yFQx4tZbD447PwI2u5VspsNMnFeChhH +KsXZjQIDAQABAkARlCv+oxEq1wQIoZUz83TXe8CFBlGvg9Wc6+5lBWM9F7K4by7i +IB5hQ2oaTNN+1Kxzf+XRM9R7sMPP9qFEp0LhAiEA0PzsQqbvNUVEx8X16Hed6V/Z +yvL1iZeHvc2QIbGjZGkCIQDPcM7U0frsFIPuMY4zpX2b6w4rpxZN7Kybp9/3l0tX +hQIhAJVWVsGeJksLr4WNuRYf+9BbNPdoO/rRNCd2L+tT060ZAiEAl0uontITl9IS +s0yTcZm29lxG9pGkE+uVrOWQ1W0Ud10CIQDJ/L+VCQgjO+SviUECc/nMwhWDMT+V +cjLxGL8tcZjHKg== +-----END PRIVATE KEY-----`)}, + } + + for _, s := range []*v1.Secret{sCa, sClCert, sClKey} { + if secret, err := context.Clientset.CoreV1().Secrets(ns).Create(ctx, s, metav1.CreateOptions{}); err != nil { + t.Fatal(err) + } else { + defer func() { + err := context.Clientset.CoreV1().Secrets(ns).Delete(ctx, secret.Name, metav1.DeleteOptions{}) + if err != nil { + logger.Errorf("failed to delete secret %s: %v", secret.Name, err) + } + }() + } + } + + // Populate TLS config + newConfigWithTLS, removeCertFiles, err := configTLS(context, ns, secretConfig) + assert.NoError(t, err) + defer removeCertFiles() + + // Populate TLS config + for key, value := range newConfigWithTLS { + c[key] = string(value) + } + + // Configure TLS + err = utils.ConfigureTLS(config, c) + assert.NoError(t, err) +} diff --git a/pkg/daemon/ceph/osd/kms/vault_test.go b/pkg/daemon/ceph/osd/kms/vault_test.go index 043462f52edc..adbe16148342 100644 --- a/pkg/daemon/ceph/osd/kms/vault_test.go +++ b/pkg/daemon/ceph/osd/kms/vault_test.go @@ -18,8 +18,12 @@ package kms import ( "context" + "io/ioutil" + "os" "testing" + "github.com/coreos/pkg/capnslog" + "github.com/pkg/errors" "github.com/rook/rook/pkg/clusterd" "github.com/rook/rook/pkg/operator/test" "github.com/stretchr/testify/assert" @@ -50,112 +54,204 @@ func Test_tlsSecretKeyToCheck(t *testing.T) { } func Test_configTLS(t *testing.T) { + // Set DEBUG logging + capnslog.SetGlobalLogLevel(capnslog.DEBUG) + os.Setenv("ROOK_LOG_LEVEL", "DEBUG") ctx := context.TODO() - config := map[string]string{ - "foo": "bar", - "KMS_PROVIDER": "vault", - "VAULT_ADDR": "1.1.1.1", - "VAULT_BACKEND_PATH": "vault", - } ns := "rook-ceph" context := &clusterd.Context{Clientset: test.New(t, 3)} - // No tls config - _, err := configTLS(context, ns, config) - assert.NoError(t, err) - - // TLS config with correct values - config = map[string]string{ - "foo": "bar", - "KMS_PROVIDER": "vault", - "VAULT_ADDR": "1.1.1.1", - "VAULT_BACKEND_PATH": "vault", - "VAULT_CACERT": "/etc/vault/cacert", - "VAULT_SKIP_VERIFY": "false", - } - config, err = configTLS(context, ns, config) - assert.NoError(t, err) - assert.Equal(t, "/etc/vault/cacert", config["VAULT_CACERT"]) - - // TLS config but no secret - config = map[string]string{ - "foo": "bar", - "KMS_PROVIDER": "vault", - "VAULT_ADDR": "1.1.1.1", - "VAULT_BACKEND_PATH": "vault", - "VAULT_CACERT": "vault-ca-cert", - "VAULT_SKIP_VERIFY": "false", - } - _, err = configTLS(context, ns, config) - assert.Error(t, err) - assert.EqualError(t, err, "failed to fetch tls k8s secret \"vault-ca-cert\": secrets \"vault-ca-cert\" not found") - - // TLS config success! - config = map[string]string{ - "foo": "bar", - "KMS_PROVIDER": "vault", - "VAULT_ADDR": "1.1.1.1", - "VAULT_BACKEND_PATH": "vault", - "VAULT_CACERT": "vault-ca-cert", - "VAULT_SKIP_VERIFY": "false", - } - s := &v1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: "vault-ca-cert", - Namespace: ns, - }, - Data: map[string][]byte{"cert": []byte("bar")}, - } - _, err = context.Clientset.CoreV1().Secrets(ns).Create(ctx, s, metav1.CreateOptions{}) - assert.NoError(t, err) - config, err = configTLS(context, ns, config) - assert.NoError(t, err) - assert.NotEqual(t, "vault-ca-cert", config["VAULT_CACERT"]) - err = context.Clientset.CoreV1().Secrets(ns).Delete(ctx, s.Name, metav1.DeleteOptions{}) - assert.NoError(t, err) - - // All TLS success! - config = map[string]string{ - "foo": "bar", - "KMS_PROVIDER": "vault", - "VAULT_ADDR": "1.1.1.1", - "VAULT_BACKEND_PATH": "vault", - "VAULT_CACERT": "vault-ca-cert", - "VAULT_CLIENT_CERT": "vault-client-cert", - "VAULT_CLIENT_KEY": "vault-client-key", - } - sCa := &v1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: "vault-ca-cert", - Namespace: ns, - }, - Data: map[string][]byte{"cert": []byte("bar")}, - } - sClCert := &v1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: "vault-client-cert", - Namespace: ns, - }, - Data: map[string][]byte{"cert": []byte("bar")}, - } - sClKey := &v1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: "vault-client-key", - Namespace: ns, - }, - Data: map[string][]byte{"key": []byte("bar")}, - } - _, err = context.Clientset.CoreV1().Secrets(ns).Create(ctx, sCa, metav1.CreateOptions{}) - assert.NoError(t, err) - _, err = context.Clientset.CoreV1().Secrets(ns).Create(ctx, sClCert, metav1.CreateOptions{}) - assert.NoError(t, err) - _, err = context.Clientset.CoreV1().Secrets(ns).Create(ctx, sClKey, metav1.CreateOptions{}) - assert.NoError(t, err) - config, err = configTLS(context, ns, config) - assert.NoError(t, err) - assert.NotEqual(t, "vault-ca-cert", config["VAULT_CACERT"]) - assert.NotEqual(t, "vault-client-cert", config["VAULT_CLIENT_CERT"]) - assert.NotEqual(t, "vault-client-key", config["VAULT_CLIENT_KEY"]) + t.Run("no TLS config", func(t *testing.T) { + config := map[string]string{ + "foo": "bar", + "KMS_PROVIDER": "vault", + "VAULT_ADDR": "1.1.1.1", + "VAULT_BACKEND_PATH": "vault", + } + // No tls config + _, removeCertFiles, err := configTLS(context, ns, config) + assert.NoError(t, err) + defer removeCertFiles() + }) + + t.Run("TLS config with already populated cert path", func(t *testing.T) { + config := map[string]string{ + "foo": "bar", + "KMS_PROVIDER": "vault", + "VAULT_ADDR": "1.1.1.1", + "VAULT_BACKEND_PATH": "vault", + "VAULT_CACERT": "/etc/vault/cacert", + "VAULT_SKIP_VERIFY": "false", + } + config, removeCertFiles, err := configTLS(context, ns, config) + assert.NoError(t, err) + assert.Equal(t, "/etc/vault/cacert", config["VAULT_CACERT"]) + defer removeCertFiles() + }) + + t.Run("TLS config but no secret", func(t *testing.T) { + config := map[string]string{ + "foo": "bar", + "KMS_PROVIDER": "vault", + "VAULT_ADDR": "1.1.1.1", + "VAULT_BACKEND_PATH": "vault", + "VAULT_CACERT": "vault-ca-cert", + "VAULT_SKIP_VERIFY": "false", + } + _, removeCertFiles, err := configTLS(context, ns, config) + assert.Error(t, err) + assert.EqualError(t, err, "failed to fetch tls k8s secret \"vault-ca-cert\": secrets \"vault-ca-cert\" not found") + assert.Nil(t, removeCertFiles) + }) + + t.Run("TLS config success!", func(t *testing.T) { + config := map[string]string{ + "foo": "bar", + "KMS_PROVIDER": "vault", + "VAULT_ADDR": "1.1.1.1", + "VAULT_BACKEND_PATH": "vault", + "VAULT_CACERT": "vault-ca-cert", + "VAULT_SKIP_VERIFY": "false", + } + s := &v1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "vault-ca-cert", + Namespace: ns, + }, + Data: map[string][]byte{"cert": []byte("bar")}, + } + _, err := context.Clientset.CoreV1().Secrets(ns).Create(ctx, s, metav1.CreateOptions{}) + assert.NoError(t, err) + config, removeCertFiles, err := configTLS(context, ns, config) + defer removeCertFiles() + assert.NoError(t, err) + assert.NotEqual(t, "vault-ca-cert", config["VAULT_CACERT"]) + err = context.Clientset.CoreV1().Secrets(ns).Delete(ctx, s.Name, metav1.DeleteOptions{}) + assert.NoError(t, err) + }) + + t.Run("advanced TLS config success!", func(t *testing.T) { + config := map[string]string{ + "foo": "bar", + "KMS_PROVIDER": "vault", + "VAULT_ADDR": "1.1.1.1", + "VAULT_BACKEND_PATH": "vault", + "VAULT_CACERT": "vault-ca-cert", + "VAULT_CLIENT_CERT": "vault-client-cert", + "VAULT_CLIENT_KEY": "vault-client-key", + } + sCa := &v1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "vault-ca-cert", + Namespace: ns, + }, + Data: map[string][]byte{"cert": []byte("bar")}, + } + sClCert := &v1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "vault-client-cert", + Namespace: ns, + }, + Data: map[string][]byte{"cert": []byte("bar")}, + } + sClKey := &v1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "vault-client-key", + Namespace: ns, + }, + Data: map[string][]byte{"key": []byte("bar")}, + } + _, err := context.Clientset.CoreV1().Secrets(ns).Create(ctx, sCa, metav1.CreateOptions{}) + assert.NoError(t, err) + _, err = context.Clientset.CoreV1().Secrets(ns).Create(ctx, sClCert, metav1.CreateOptions{}) + assert.NoError(t, err) + _, err = context.Clientset.CoreV1().Secrets(ns).Create(ctx, sClKey, metav1.CreateOptions{}) + assert.NoError(t, err) + config, removeCertFiles, err := configTLS(context, ns, config) + assert.NoError(t, err) + assert.NotEqual(t, "vault-ca-cert", config["VAULT_CACERT"]) + assert.NotEqual(t, "vault-client-cert", config["VAULT_CLIENT_CERT"]) + assert.NotEqual(t, "vault-client-key", config["VAULT_CLIENT_KEY"]) + assert.FileExists(t, config["VAULT_CACERT"]) + assert.FileExists(t, config["VAULT_CLIENT_CERT"]) + assert.FileExists(t, config["VAULT_CLIENT_KEY"]) + removeCertFiles() + assert.NoFileExists(t, config["VAULT_CACERT"]) + assert.NoFileExists(t, config["VAULT_CLIENT_CERT"]) + assert.NoFileExists(t, config["VAULT_CLIENT_KEY"]) + }) + + t.Run("advanced TLS config success with timeout!", func(t *testing.T) { + config := map[string]string{ + "foo": "bar", + "KMS_PROVIDER": "vault", + "VAULT_ADDR": "1.1.1.1", + "VAULT_BACKEND_PATH": "vault", + "VAULT_CACERT": "vault-ca-cert", + "VAULT_CLIENT_CERT": "vault-client-cert", + "VAULT_CLIENT_KEY": "vault-client-key", + } + config, removeCertFiles, err := configTLS(context, ns, config) + assert.NoError(t, err) + assert.NotEqual(t, "vault-ca-cert", config["VAULT_CACERT"]) + assert.NotEqual(t, "vault-client-cert", config["VAULT_CLIENT_CERT"]) + assert.NotEqual(t, "vault-client-key", config["VAULT_CLIENT_KEY"]) + assert.FileExists(t, config["VAULT_CACERT"]) + assert.FileExists(t, config["VAULT_CLIENT_CERT"]) + assert.FileExists(t, config["VAULT_CLIENT_KEY"]) + removeCertFiles() + assert.NoFileExists(t, config["VAULT_CACERT"]) + assert.NoFileExists(t, config["VAULT_CLIENT_CERT"]) + assert.NoFileExists(t, config["VAULT_CLIENT_KEY"]) + }) + + // This test verifies that if any of ioutil.TempFile or ioutil.WriteFile fail during the TLS + // config loop we cleanup the already generated files. For instance, let's say we are at the + // second iteration, a file has been created, and then ioutil.TempFile fails, we must cleanup + // the previous file. Essentially we are verifying that defer does what it is supposed to do. + // Also, in this situation the cleanup function will be 'nil' and the caller won't run it so the + // configTLS() must do its own cleanup. + t.Run("advanced TLS config with temp file creation error", func(t *testing.T) { + createTmpFile = func(dir string, pattern string) (f *os.File, err error) { + // Create a fake temp file + ff, err := ioutil.TempFile("", "") + if err != nil { + logger.Error(err) + return nil, err + } + + // Add the file to the list of files to remove + var fakeFilesToRemove []*os.File + fakeFilesToRemove = append(fakeFilesToRemove, ff) + getRemoveCertFiles = func(filesToRemove []*os.File) removeCertFilesFunction { + return func() { + filesToRemove = fakeFilesToRemove + for _, f := range filesToRemove { + t.Logf("removing file %q after failure from TempFile call", f.Name()) + f.Close() + os.Remove(f.Name()) + } + } + } + os.Setenv("ROOK_TMP_FILE", ff.Name()) + + return ff, errors.New("error creating tmp file") + } + config := map[string]string{ + "foo": "bar", + "KMS_PROVIDER": "vault", + "VAULT_ADDR": "1.1.1.1", + "VAULT_BACKEND_PATH": "vault", + "VAULT_CACERT": "vault-ca-cert", + "VAULT_CLIENT_CERT": "vault-client-cert", + "VAULT_CLIENT_KEY": "vault-client-key", + } + _, _, err := configTLS(context, ns, config) + assert.Error(t, err) + assert.EqualError(t, err, "failed to generate temp file for k8s secret \"vault-ca-cert\" content: error creating tmp file") + assert.NoFileExists(t, os.Getenv("ROOK_TMP_FILE")) + os.Unsetenv("ROOK_TMP_FILE") + }) } func Test_buildKeyContext(t *testing.T) { diff --git a/tests/manifests/test-kms-vault-spec.yaml b/tests/manifests/test-kms-vault-spec.yaml index d9541f960533..6848fe48d69b 100644 --- a/tests/manifests/test-kms-vault-spec.yaml +++ b/tests/manifests/test-kms-vault-spec.yaml @@ -7,4 +7,7 @@ spec: VAULT_BACKEND_PATH: rook/ver1 VAULT_SECRET_ENGINE: kv VAULT_SKIP_VERIFY: "true" + VAULT_CLIENT_KEY: "vault-client-key" + VAULT_CLIENT_CERT: "vault-client-cert" + VAULT_CACERT: "vault-ca-cert" tokenSecretName: rook-vault-token