Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(config): ensure all expected env vars are bound #1113

Merged
merged 2 commits into from Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 16 additions & 17 deletions internal/config/cache.go
Expand Up @@ -25,30 +25,29 @@ func (c *CacheConfig) setDefaults(v *viper.Viper) (warnings []string) {
"backend": CacheMemory,
"ttl": 1 * time.Minute,
"redis": map[string]any{
"host": "localhost",
"port": 6379,
"host": "localhost",
"port": 6379,
"password": "",
"db": 0,
},
"memory": map[string]any{
"enabled": false, // deprecated (see below)
"eviction_interval": 5 * time.Minute,
},
})

if mem := v.Sub("cache.memory"); mem != nil {
mem.SetDefault("eviction_interval", 5*time.Minute)
// handle legacy memory structure
if mem.GetBool("enabled") {
warnings = append(warnings, deprecatedMsgMemoryEnabled)
// forcibly set top-level `enabled` to true
v.Set("cache.enabled", true)
// ensure ttl is mapped to the value at memory.expiration
v.RegisterAlias("cache.ttl", "cache.memory.expiration")
// ensure ttl default is set
v.SetDefault("cache.memory.expiration", 1*time.Minute)
}
if v.GetBool("cache.memory.enabled") {
warnings = append(warnings, deprecatedMsgMemoryEnabled)
// forcibly set top-level `enabled` to true
v.Set("cache.enabled", true)
// ensure ttl is mapped to the value at memory.expiration
v.RegisterAlias("cache.ttl", "cache.memory.expiration")
// ensure ttl default is set
v.SetDefault("cache.memory.expiration", 1*time.Minute)
}

if mem.IsSet("expiration") {
warnings = append(warnings, deprecatedMsgMemoryExpiration)
}
if v.IsSet("cache.memory.expiration") {
warnings = append(warnings, deprecatedMsgMemoryExpiration)
}

return
Expand Down
70 changes: 52 additions & 18 deletions internal/config/config.go
Expand Up @@ -59,21 +59,16 @@ func Load(path string) (*Config, error) {
}

var (
cfg = &Config{}
fields = cfg.fields()
cfg = &Config{}
validators = cfg.prepare(v)
)

// set viper defaults per field
for _, defaulter := range fields.defaulters {
cfg.Warnings = append(cfg.Warnings, defaulter.setDefaults(v)...)
}

if err := v.Unmarshal(cfg, viper.DecodeHook(decodeHooks)); err != nil {
return nil, err
}

// run any validation steps
for _, validator := range fields.validators {
for _, validator := range validators {
if err := validator.validate(); err != nil {
return nil, err
}
Expand All @@ -90,28 +85,67 @@ type validator interface {
validate() error
}

type fields struct {
defaulters []defaulter
validators []validator
}
func (c *Config) prepare(v *viper.Viper) (validators []validator) {
val := reflect.ValueOf(c).Elem()
for i := 0; i < val.NumField(); i++ {
// search for all expected env vars since Viper cannot
// infer when doing Unmarshal + AutomaticEnv.
// see: https://github.com/spf13/viper/issues/761
bindEnvVars(v, "", val.Type().Field(i))

func (c *Config) fields() (fields fields) {
structVal := reflect.ValueOf(c).Elem()
for i := 0; i < structVal.NumField(); i++ {
field := structVal.Field(i).Addr().Interface()
field := val.Field(i).Addr().Interface()

// for-each defaulter implementing fields we invoke
// setting any defaults during this prepare stage
// on the supplied viper.
if defaulter, ok := field.(defaulter); ok {
fields.defaulters = append(fields.defaulters, defaulter)
c.Warnings = append(c.Warnings, defaulter.setDefaults(v)...)
}

// for-each validator implementing field we collect
// them up and return them to be validated after
// unmarshalling.
if validator, ok := field.(validator); ok {
fields.validators = append(fields.validators, validator)
validators = append(validators, validator)
}
}

return
}

// bindEnvVars descends into the provided struct field binding any expected
// environment variable keys it finds reflecting struct and field tags.
func bindEnvVars(v *viper.Viper, prefix string, field reflect.StructField) {
tag := field.Tag.Get("mapstructure")
if tag == "" {
tag = strings.ToLower(field.Name)
}

var (
key = prefix + tag
typ = field.Type
)

// descend through pointers
if typ.Kind() == reflect.Pointer {
typ = typ.Elem()
}

// descend into struct fields
if typ.Kind() == reflect.Struct {
for i := 0; i < field.Type.NumField(); i++ {
structField := field.Type.Field(i)

// key becomes prefix for sub-fields
bindEnvVars(v, key+".", structField)
}

return
}

v.MustBindEnv(key)
}

func (c *Config) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var (
out []byte
Expand Down
72 changes: 70 additions & 2 deletions internal/config/config_test.go
Expand Up @@ -6,12 +6,15 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/uber/jaeger-client-go"
"gopkg.in/yaml.v2"
)

func TestScheme(t *testing.T) {
Expand Down Expand Up @@ -236,7 +239,7 @@ func TestLoad(t *testing.T) {
cfg := defaultConfig()
cfg.Cache.Enabled = true
cfg.Cache.Backend = CacheMemory
cfg.Cache.TTL = -1
cfg.Cache.TTL = -time.Second
cfg.Warnings = append(cfg.Warnings, deprecatedMsgMemoryEnabled, deprecatedMsgMemoryExpiration)
return cfg
},
Expand Down Expand Up @@ -415,7 +418,7 @@ func TestLoad(t *testing.T) {
expected = tt.expected()
}

t.Run(tt.name, func(t *testing.T) {
t.Run(tt.name+" (YAML)", func(t *testing.T) {
cfg, err := Load(path)

if wantErr != nil {
Expand All @@ -429,6 +432,39 @@ func TestLoad(t *testing.T) {
assert.NotNil(t, cfg)
assert.Equal(t, expected, cfg)
})

t.Run(tt.name+" (ENV)", func(t *testing.T) {
// backup and restore environment
backup := os.Environ()
defer func() {
os.Clearenv()
for _, env := range backup {
key, value, _ := strings.Cut(env, "=")
os.Setenv(key, value)
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️ love all this

}()

// read the input config file into equivalent envs
envs := readYAMLIntoEnv(t, path)
for _, env := range envs {
t.Logf("Setting env '%s=%s'\n", env[0], env[1])
os.Setenv(env[0], env[1])
}

// load default (empty) config
cfg, err := Load("./testdata/default.yml")

if wantErr != nil {
t.Log(err)
require.ErrorIs(t, err, wantErr)
return
}

require.NoError(t, err)

assert.NotNil(t, cfg)
assert.Equal(t, expected, cfg)
})
}
}

Expand All @@ -449,3 +485,35 @@ func TestServeHTTP(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotEmpty(t, body)
}

// readyYAMLIntoEnv parses the file provided at path as YAML.
// It walks the keys and values and builds up a set of environment variables
// compatible with viper's expectations for automatic env capability.
func readYAMLIntoEnv(t *testing.T, path string) [][2]string {
t.Helper()

configFile, err := os.ReadFile(path)
require.NoError(t, err)

var config map[any]any
err = yaml.Unmarshal(configFile, &config)
require.NoError(t, err)

return getEnvVars("flipt", config)
}

func getEnvVars(prefix string, v map[any]any) (vals [][2]string) {
for key, value := range v {
switch v := value.(type) {
case map[any]any:
vals = append(vals, getEnvVars(fmt.Sprintf("%s_%v", prefix, key), v)...)
default:
vals = append(vals, [2]string{
fmt.Sprintf("%s_%s", strings.ToUpper(prefix), strings.ToUpper(fmt.Sprintf("%v", key))),
fmt.Sprintf("%v", value),
})
}
}

return
}
@@ -1,4 +1,4 @@
cache:
memory:
enabled: true
expiration: -1
expiration: -1s