Skip to content

Commit

Permalink
new-approach
Browse files Browse the repository at this point in the history
Signed-off-by: Sébastien Han <seb@redhat.com>
  • Loading branch information
leseb committed Oct 1, 2021
1 parent 922335e commit 1691be9
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 105 deletions.
27 changes: 3 additions & 24 deletions pkg/daemon/ceph/osd/kms/kms.go
Expand Up @@ -83,14 +83,7 @@ func (c *Config) PutSecret(secretName, secretValue string) error {
}
if c.IsVault() {
// Store the secret in Vault
removeCertFiles := make(chan bool, 1)
// Remove cert files from operator's filesystem
defer func() {
removeCertFiles <- true
close(removeCertFiles)
}()

v, err := InitVault(c.context, c.clusterInfo.Namespace, c.clusterSpec.Security.KeyManagementService.ConnectionDetails, removeCertFiles)
v, err := InitVault(c.context, c.clusterInfo.Namespace, c.clusterSpec.Security.KeyManagementService.ConnectionDetails)
if err != nil {
return errors.Wrap(err, "failed to init vault kms")
}
Expand All @@ -109,14 +102,7 @@ func (c *Config) GetSecret(secretName string) (string, error) {
var value string
if c.IsVault() {
// Store the secret in Vault
removeCertFiles := make(chan bool, 1)
// Remove cert files from operator's filesystem
defer func() {
removeCertFiles <- true
close(removeCertFiles)
}()

v, err := InitVault(c.context, c.clusterInfo.Namespace, c.clusterSpec.Security.KeyManagementService.ConnectionDetails, removeCertFiles)
v, err := InitVault(c.context, c.clusterInfo.Namespace, c.clusterSpec.Security.KeyManagementService.ConnectionDetails)
if err != nil {
return "", errors.Wrap(err, "failed to get secret in vault")
}
Expand All @@ -135,14 +121,7 @@ func (c *Config) GetSecret(secretName string) (string, error) {
func (c *Config) DeleteSecret(secretName string) error {
if c.IsVault() {
// Store the secret in Vault
removeCertFiles := make(chan bool, 1)
// Remove cert files from operator's filesystem
defer func() {
removeCertFiles <- true
close(removeCertFiles)
}()

v, err := InitVault(c.context, c.clusterInfo.Namespace, c.clusterSpec.Security.KeyManagementService.ConnectionDetails, removeCertFiles)
v, err := InitVault(c.context, c.clusterInfo.Namespace, c.clusterSpec.Security.KeyManagementService.ConnectionDetails)
if err != nil {
return errors.Wrap(err, "failed to delete secret in vault")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/daemon/ceph/osd/kms/kms_test.go
Expand Up @@ -122,7 +122,7 @@ func TestValidateConnectionDetails(t *testing.T) {
vault.TestWaitActive(t, core)
client := cluster.Cores[0].Client
// Mock the client here
vaultClient = func(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string, removeCertFiles chan bool) (*api.Client, error) {
vaultClient = func(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string) (*api.Client, error) {
return client, nil
}
if err := client.Sys().Mount("rook/", &api.MountInput{
Expand Down
63 changes: 25 additions & 38 deletions pkg/daemon/ceph/osd/kms/vault.go
Expand Up @@ -21,7 +21,6 @@ import (
"io/ioutil"
"os"
"strings"
"time"

"github.com/hashicorp/vault/api"
"github.com/libopenstorage/secrets"
Expand All @@ -47,6 +46,8 @@ var (
vaultMandatoryConnectionDetails = []string{api.EnvVaultAddress}
)

type RemoveCertFilesFunction func()

/* VAULT API INTERNAL VALUES
// Refer to https://pkg.golangclub.com/github.com/hashicorp/vault/api?tab=doc#pkg-constants
const EnvVaultAddress = "VAULT_ADDR"
Expand All @@ -68,7 +69,7 @@ var (
*/

// InitVault inits the secret store
func InitVault(context *clusterd.Context, namespace string, config map[string]string, removeCertFiles chan bool) (secrets.Secrets, error) {
func InitVault(context *clusterd.Context, namespace string, config map[string]string) (secrets.Secrets, error) {
c := make(map[string]interface{})

// So that we don't alter the content of c.config for later iterations
Expand All @@ -79,10 +80,11 @@ func InitVault(context *clusterd.Context, namespace string, config map[string]st
}

// Populate TLS config
newConfigWithTLS, err := configTLS(context, namespace, oriConfig, removeCertFiles)
newConfigWithTLS, removeCertFiles, err := configTLS(context, namespace, oriConfig)
if err != nil {
return nil, errors.Wrap(err, "failed to initialize vault tls configuration")
}
defer removeCertFiles()

// Populate TLS config
for key, value := range newConfigWithTLS {
Expand All @@ -98,9 +100,10 @@ func InitVault(context *clusterd.Context, namespace string, config map[string]st
return v, nil
}

func configTLS(clusterdContext *clusterd.Context, namespace string, config map[string]string, removeCertFiles chan bool) (map[string]string, error) {
func configTLS(clusterdContext *clusterd.Context, namespace string, config map[string]string) (map[string]string, RemoveCertFilesFunction, error) {
ctx := context.TODO()
var filesToRemove []*os.File
var removeCertFiles RemoveCertFilesFunction
for _, tlsOption := range cephv1.VaultTLSConnectionDetails {
tlsSecretName := GetParam(config, tlsOption)
if tlsSecretName == "" {
Expand All @@ -110,19 +113,19 @@ func configTLS(clusterdContext *clusterd.Context, namespace string, config map[s
if !strings.Contains(tlsSecretName, EtcVaultDir) {
secret, err := clusterdContext.Clientset.CoreV1().Secrets(namespace).Get(ctx, tlsSecretName, v1.GetOptions{})
if err != nil {
return nil, errors.Wrapf(err, "failed to fetch tls k8s secret %q", tlsSecretName)
return nil, removeCertFiles, errors.Wrapf(err, "failed to fetch tls k8s secret %q", tlsSecretName)
}

// Generate a temp file
file, err := ioutil.TempFile("", "")
if err != nil {
return nil, errors.Wrapf(err, "failed to generate temp file for k8s secret %q content", tlsSecretName)
return nil, removeCertFiles, errors.Wrapf(err, "failed to generate temp file for k8s secret %q content", tlsSecretName)
}

// Write into a file
err = ioutil.WriteFile(file.Name(), secret.Data[tlsSecretKeyToCheck(tlsOption)], 0444)
if err != nil {
return nil, errors.Wrapf(err, "failed to write k8s secret %q content to a file", tlsSecretName)
return nil, removeCertFiles, errors.Wrapf(err, "failed to write k8s secret %q content to a file", tlsSecretName)
}

logger.Debugf("replacing %q current content %q with %q", tlsOption, config[tlsOption], file.Name())
Expand All @@ -137,41 +140,25 @@ func configTLS(clusterdContext *clusterd.Context, namespace string, config map[s
}
}

// Run a goroutine that waits for a confirmation to remove the certificates from the filesystem
go func(filesToRemove []*os.File) {
logger.Debugf("files to remove: %+v", filesToRemove)
rmCertFiles := func() {
for _, file := range filesToRemove {
logger.Debugf("closing %q", file.Name())
err := file.Close()
if err != nil {
logger.Errorf("failed to close file %q. %v", file.Name(), err)
}
logger.Debugf("closed %q", file.Name())
logger.Debugf("removing %q", file.Name())
err = os.Remove(file.Name())
if err != nil {
logger.Errorf("failed to remove file %q. %v", file.Name(), err)
}
logger.Debugf("removed %q", file.Name())
removeCertFiles = RemoveCertFilesFunction(func() {
filesToRemove := filesToRemove
for _, file := range filesToRemove {
logger.Debugf("closing %q", file.Name())
err := file.Close()
if err != nil {
logger.Errorf("failed to close file %q. %v", file.Name(), err)
}
}
for {
select {
case <-removeCertFiles:
logger.Debug("received confirmation from channel to remove vault certificates")
rmCertFiles()
logger.Info("successfully removed secret files")
return
case <-time.After(time.Second * 5):
logger.Info("never received anything from the channel removing vault cert files anyway")
rmCertFiles()
return
logger.Debugf("closed %q", file.Name())
logger.Debugf("removing %q", file.Name())
err = os.Remove(file.Name())
if err != nil {
logger.Errorf("failed to remove file %q. %v", file.Name(), err)
}
logger.Debugf("removed %q", file.Name())
}
}(filesToRemove)
})

return config, nil
return config, removeCertFiles, nil
}

func put(v secrets.Secrets, secretName, secretValue string, keyContext map[string]string) error {
Expand Down
14 changes: 4 additions & 10 deletions pkg/daemon/ceph/osd/kms/vault_api.go
Expand Up @@ -39,7 +39,7 @@ var vaultClient = newVaultClient

// newVaultClient returns a vault client, there is no need for any secretConfig validation
// Since this is called after an already validated call InitVault()
func newVaultClient(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string, removeCertFiles chan bool) (*api.Client, error) {
func newVaultClient(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string) (*api.Client, error) {
// DefaultConfig uses the environment variables if present.
config := api.DefaultConfig()

Expand All @@ -57,10 +57,11 @@ func newVaultClient(clusterdContext *clusterd.Context, namespace string, secretC
}

// Populate TLS config
newConfigWithTLS, err := configTLS(clusterdContext, namespace, localSecretConfig, removeCertFiles)
newConfigWithTLS, removeCertFiles, err := configTLS(clusterdContext, namespace, localSecretConfig)
if err != nil {
return nil, errors.Wrap(err, "failed to initialize vault tls configuration")
}
defer removeCertFiles()

// Populate TLS config
for key, value := range newConfigWithTLS {
Expand Down Expand Up @@ -109,15 +110,8 @@ func BackendVersion(clusterdContext *clusterd.Context, namespace string, secretC
logger.Info("vault kv secret engine version set to v2")
return v2, nil
default:
removeCertFiles := make(chan bool, 1)
// Remove cert files from operator's filesystem
defer func() {
removeCertFiles <- true
close(removeCertFiles)
}()

// Initialize Vault client
vaultClient, err := vaultClient(clusterdContext, namespace, secretConfig, removeCertFiles)
vaultClient, err := vaultClient(clusterdContext, namespace, secretConfig)
if err != nil {
return "", errors.Wrap(err, "failed to initialize vault client")
}
Expand Down
18 changes: 6 additions & 12 deletions pkg/daemon/ceph/osd/kms/vault_api_test.go
Expand Up @@ -42,7 +42,7 @@ func TestBackendVersion(t *testing.T) {
client := cluster.Cores[0].Client

// Mock the client here
vaultClient = func(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string, removeCertFiles chan bool) (*api.Client, error) {
vaultClient = func(clusterdContext *clusterd.Context, namespace string, secretConfig map[string]string) (*api.Client, error) {
return client, nil
}

Expand Down Expand Up @@ -173,28 +173,22 @@ cjLxGL8tcZjHKg==
}

for _, s := range []*v1.Secret{sCa, sClCert, sClKey} {
if _, err := context.Clientset.CoreV1().Secrets(ns).Create(ctx, s, metav1.CreateOptions{}); err != nil {
if secret, err := context.Clientset.CoreV1().Secrets(ns).Create(ctx, s, metav1.CreateOptions{}); err != nil {
t.Fatal(err)
} else {
defer func() {
err := context.Clientset.CoreV1().Secrets(ns).Delete(ctx, s.Name, metav1.DeleteOptions{})
err := context.Clientset.CoreV1().Secrets(ns).Delete(ctx, secret.Name, metav1.DeleteOptions{})
if err != nil {
logger.Errorf("failed to delete secret %s: %v", s.Name, err)
logger.Errorf("failed to delete secret %s: %v", secret.Name, err)
}
}()
}
}

removeCertFiles := make(chan bool, 1)
// Remove cert files from operator's filesystem
defer func() {
removeCertFiles <- true
close(removeCertFiles)
}()

// Populate TLS config
newConfigWithTLS, err := configTLS(context, ns, secretConfig, removeCertFiles)
newConfigWithTLS, removeCertFiles, err := configTLS(context, ns, secretConfig)
assert.NoError(t, err)
defer removeCertFiles()

// Populate TLS config
for key, value := range newConfigWithTLS {
Expand Down
31 changes: 11 additions & 20 deletions pkg/daemon/ceph/osd/kms/vault_test.go
Expand Up @@ -63,14 +63,12 @@ func Test_configTLS(t *testing.T) {
"VAULT_BACKEND_PATH": "vault",
}
// No tls config
removeCertFiles := make(chan bool, 1)
_, err := configTLS(context, ns, config, removeCertFiles)
_, removeCertFiles, err := configTLS(context, ns, config)
assert.NoError(t, err)
close(removeCertFiles)
defer removeCertFiles()
})

t.Run("TLS config with already populated cert path", func(t *testing.T) {
removeCertFiles := make(chan bool, 1)
config := map[string]string{
"foo": "bar",
"KMS_PROVIDER": "vault",
Expand All @@ -79,14 +77,13 @@ func Test_configTLS(t *testing.T) {
"VAULT_CACERT": "/etc/vault/cacert",
"VAULT_SKIP_VERIFY": "false",
}
config, err := configTLS(context, ns, config, removeCertFiles)
config, removeCertFiles, err := configTLS(context, ns, config)
assert.NoError(t, err)
assert.Equal(t, "/etc/vault/cacert", config["VAULT_CACERT"])
close(removeCertFiles)
defer removeCertFiles()
})

t.Run("TLS config but no secret", func(t *testing.T) {
removeCertFiles := make(chan bool, 1)
config := map[string]string{
"foo": "bar",
"KMS_PROVIDER": "vault",
Expand All @@ -95,14 +92,12 @@ func Test_configTLS(t *testing.T) {
"VAULT_CACERT": "vault-ca-cert",
"VAULT_SKIP_VERIFY": "false",
}
_, err := configTLS(context, ns, config, removeCertFiles)
_, _, err := configTLS(context, ns, config)
assert.Error(t, err)
assert.EqualError(t, err, "failed to fetch tls k8s secret \"vault-ca-cert\": secrets \"vault-ca-cert\" not found")
close(removeCertFiles)
})

t.Run("TLS config success!", func(t *testing.T) {
removeCertFiles := make(chan bool, 1)
config := map[string]string{
"foo": "bar",
"KMS_PROVIDER": "vault",
Expand All @@ -120,17 +115,15 @@ func Test_configTLS(t *testing.T) {
}
_, err := context.Clientset.CoreV1().Secrets(ns).Create(ctx, s, metav1.CreateOptions{})
assert.NoError(t, err)
config, err = configTLS(context, ns, config, removeCertFiles)
config, removeCertFiles, err := configTLS(context, ns, config)
defer removeCertFiles()
assert.NoError(t, err)
assert.NotEqual(t, "vault-ca-cert", config["VAULT_CACERT"])
err = context.Clientset.CoreV1().Secrets(ns).Delete(ctx, s.Name, metav1.DeleteOptions{})
assert.NoError(t, err)
removeCertFiles <- true
close(removeCertFiles)
})

t.Run("advanced TLS config success!", func(t *testing.T) {
removeCertFiles := make(chan bool, 1)
config := map[string]string{
"foo": "bar",
"KMS_PROVIDER": "vault",
Expand Down Expand Up @@ -167,22 +160,20 @@ func Test_configTLS(t *testing.T) {
assert.NoError(t, err)
_, err = context.Clientset.CoreV1().Secrets(ns).Create(ctx, sClKey, metav1.CreateOptions{})
assert.NoError(t, err)
config, err = configTLS(context, ns, config, removeCertFiles)
config, removeCertFiles, err := configTLS(context, ns, config)
assert.NoError(t, err)
assert.NotEqual(t, "vault-ca-cert", config["VAULT_CACERT"])
assert.NotEqual(t, "vault-client-cert", config["VAULT_CLIENT_CERT"])
assert.NotEqual(t, "vault-client-key", config["VAULT_CLIENT_KEY"])
assert.FileExists(t, config["VAULT_CACERT"])
removeCertFiles <- true
close(removeCertFiles)
removeCertFiles()
time.Sleep(time.Second)
assert.NoFileExists(t, config["VAULT_CACERT"])
assert.NoFileExists(t, config["VAULT_CLIENT_CERT"])
assert.NoFileExists(t, config["VAULT_CLIENT_KEY"])
})

t.Run("advanced TLS config success with timeout!", func(t *testing.T) {
removeCertFiles := make(chan bool, 1)
config := map[string]string{
"foo": "bar",
"KMS_PROVIDER": "vault",
Expand All @@ -192,17 +183,17 @@ func Test_configTLS(t *testing.T) {
"VAULT_CLIENT_CERT": "vault-client-cert",
"VAULT_CLIENT_KEY": "vault-client-key",
}
config, err := configTLS(context, ns, config, removeCertFiles)
config, removeCertFiles, err := configTLS(context, ns, config)
assert.NoError(t, err)
assert.NotEqual(t, "vault-ca-cert", config["VAULT_CACERT"])
assert.NotEqual(t, "vault-client-cert", config["VAULT_CLIENT_CERT"])
assert.NotEqual(t, "vault-client-key", config["VAULT_CLIENT_KEY"])
assert.FileExists(t, config["VAULT_CACERT"])
removeCertFiles()
time.Sleep(time.Second * 6)
assert.NoFileExists(t, config["VAULT_CACERT"])
assert.NoFileExists(t, config["VAULT_CLIENT_CERT"])
assert.NoFileExists(t, config["VAULT_CLIENT_KEY"])
close(removeCertFiles)
})
}

Expand Down

0 comments on commit 1691be9

Please sign in to comment.