diff --git a/internal/config/cache.go b/internal/config/cache.go index 534f2e9582..fe9b6cd1fd 100644 --- a/internal/config/cache.go +++ b/internal/config/cache.go @@ -25,30 +25,29 @@ func (c *CacheConfig) setDefaults(v *viper.Viper) (warnings []string) { "backend": CacheMemory, "ttl": 1 * time.Minute, "redis": map[string]any{ - "host": "localhost", - "port": 6379, + "host": "localhost", + "port": 6379, + "password": "", + "db": 0, }, "memory": map[string]any{ + "enabled": false, // deprecated (see below) "eviction_interval": 5 * time.Minute, }, }) - if mem := v.Sub("cache.memory"); mem != nil { - mem.SetDefault("eviction_interval", 5*time.Minute) - // handle legacy memory structure - if mem.GetBool("enabled") { - warnings = append(warnings, deprecatedMsgMemoryEnabled) - // forcibly set top-level `enabled` to true - v.Set("cache.enabled", true) - // ensure ttl is mapped to the value at memory.expiration - v.RegisterAlias("cache.ttl", "cache.memory.expiration") - // ensure ttl default is set - v.SetDefault("cache.memory.expiration", 1*time.Minute) - } + if v.GetBool("cache.memory.enabled") { + warnings = append(warnings, deprecatedMsgMemoryEnabled) + // forcibly set top-level `enabled` to true + v.Set("cache.enabled", true) + // ensure ttl is mapped to the value at memory.expiration + v.RegisterAlias("cache.ttl", "cache.memory.expiration") + // ensure ttl default is set + v.SetDefault("cache.memory.expiration", 1*time.Minute) + } - if mem.IsSet("expiration") { - warnings = append(warnings, deprecatedMsgMemoryExpiration) - } + if v.IsSet("cache.memory.expiration") { + warnings = append(warnings, deprecatedMsgMemoryExpiration) } return diff --git a/internal/config/config.go b/internal/config/config.go index 3becdc8c38..f6800766e1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -59,21 +59,16 @@ func Load(path string) (*Config, error) { } var ( - cfg = &Config{} - fields = cfg.fields() + cfg = &Config{} + validators = cfg.prepare(v) ) - // set viper defaults per field - for _, defaulter := range fields.defaulters { - cfg.Warnings = append(cfg.Warnings, defaulter.setDefaults(v)...) - } - if err := v.Unmarshal(cfg, viper.DecodeHook(decodeHooks)); err != nil { return nil, err } // run any validation steps - for _, validator := range fields.validators { + for _, validator := range validators { if err := validator.validate(); err != nil { return nil, err } @@ -90,28 +85,67 @@ type validator interface { validate() error } -type fields struct { - defaulters []defaulter - validators []validator -} +func (c *Config) prepare(v *viper.Viper) (validators []validator) { + val := reflect.ValueOf(c).Elem() + for i := 0; i < val.NumField(); i++ { + // search for all expected env vars since Viper cannot + // infer when doing Unmarshal + AutomaticEnv. + // see: https://github.com/spf13/viper/issues/761 + bindEnvVars(v, "", val.Type().Field(i)) -func (c *Config) fields() (fields fields) { - structVal := reflect.ValueOf(c).Elem() - for i := 0; i < structVal.NumField(); i++ { - field := structVal.Field(i).Addr().Interface() + field := val.Field(i).Addr().Interface() + // for-each defaulter implementing fields we invoke + // setting any defaults during this prepare stage + // on the supplied viper. if defaulter, ok := field.(defaulter); ok { - fields.defaulters = append(fields.defaulters, defaulter) + c.Warnings = append(c.Warnings, defaulter.setDefaults(v)...) } + // for-each validator implementing field we collect + // them up and return them to be validated after + // unmarshalling. if validator, ok := field.(validator); ok { - fields.validators = append(fields.validators, validator) + validators = append(validators, validator) } } return } +// bindEnvVars descends into the provided struct field binding any expected +// environment variable keys it finds reflecting struct and field tags. +func bindEnvVars(v *viper.Viper, prefix string, field reflect.StructField) { + tag := field.Tag.Get("mapstructure") + if tag == "" { + tag = strings.ToLower(field.Name) + } + + var ( + key = prefix + tag + typ = field.Type + ) + + // descend through pointers + if typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + + // descend into struct fields + if typ.Kind() == reflect.Struct { + for i := 0; i < field.Type.NumField(); i++ { + structField := field.Type.Field(i) + + // key becomes prefix for sub-fields + bindEnvVars(v, key+".", structField) + } + + return + } + + v.MustBindEnv(key) +} + func (c *Config) ServeHTTP(w http.ResponseWriter, r *http.Request) { var ( out []byte diff --git a/internal/config/config_test.go b/internal/config/config_test.go index e060bced23..b65ce78fb3 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -6,12 +6,15 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "os" + "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber/jaeger-client-go" + "gopkg.in/yaml.v2" ) func TestScheme(t *testing.T) { @@ -236,7 +239,7 @@ func TestLoad(t *testing.T) { cfg := defaultConfig() cfg.Cache.Enabled = true cfg.Cache.Backend = CacheMemory - cfg.Cache.TTL = -1 + cfg.Cache.TTL = -time.Second cfg.Warnings = append(cfg.Warnings, deprecatedMsgMemoryEnabled, deprecatedMsgMemoryExpiration) return cfg }, @@ -415,7 +418,7 @@ func TestLoad(t *testing.T) { expected = tt.expected() } - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.name+" (YAML)", func(t *testing.T) { cfg, err := Load(path) if wantErr != nil { @@ -429,6 +432,39 @@ func TestLoad(t *testing.T) { assert.NotNil(t, cfg) assert.Equal(t, expected, cfg) }) + + t.Run(tt.name+" (ENV)", func(t *testing.T) { + // backup and restore environment + backup := os.Environ() + defer func() { + os.Clearenv() + for _, env := range backup { + key, value, _ := strings.Cut(env, "=") + os.Setenv(key, value) + } + }() + + // read the input config file into equivalent envs + envs := readYAMLIntoEnv(t, path) + for _, env := range envs { + t.Logf("Setting env '%s=%s'\n", env[0], env[1]) + os.Setenv(env[0], env[1]) + } + + // load default (empty) config + cfg, err := Load("./testdata/default.yml") + + if wantErr != nil { + t.Log(err) + require.ErrorIs(t, err, wantErr) + return + } + + require.NoError(t, err) + + assert.NotNil(t, cfg) + assert.Equal(t, expected, cfg) + }) } } @@ -449,3 +485,35 @@ func TestServeHTTP(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) assert.NotEmpty(t, body) } + +// readyYAMLIntoEnv parses the file provided at path as YAML. +// It walks the keys and values and builds up a set of environment variables +// compatible with viper's expectations for automatic env capability. +func readYAMLIntoEnv(t *testing.T, path string) [][2]string { + t.Helper() + + configFile, err := os.ReadFile(path) + require.NoError(t, err) + + var config map[any]any + err = yaml.Unmarshal(configFile, &config) + require.NoError(t, err) + + return getEnvVars("flipt", config) +} + +func getEnvVars(prefix string, v map[any]any) (vals [][2]string) { + for key, value := range v { + switch v := value.(type) { + case map[any]any: + vals = append(vals, getEnvVars(fmt.Sprintf("%s_%v", prefix, key), v)...) + default: + vals = append(vals, [2]string{ + fmt.Sprintf("%s_%s", strings.ToUpper(prefix), strings.ToUpper(fmt.Sprintf("%v", key))), + fmt.Sprintf("%v", value), + }) + } + } + + return +} diff --git a/internal/config/testdata/deprecated/cache_memory_enabled.yml b/internal/config/testdata/deprecated/cache_memory_enabled.yml index 2463e071f1..4c3d9d488e 100644 --- a/internal/config/testdata/deprecated/cache_memory_enabled.yml +++ b/internal/config/testdata/deprecated/cache_memory_enabled.yml @@ -1,4 +1,4 @@ cache: memory: enabled: true - expiration: -1 + expiration: -1s