Skip to content

Commit

Permalink
fix(config): ensure all expected env vars are bound
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeMac committed Nov 2, 2022
1 parent b5250a4 commit ced9841
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 38 deletions.
33 changes: 16 additions & 17 deletions internal/config/cache.go
Expand Up @@ -25,30 +25,29 @@ func (c *CacheConfig) setDefaults(v *viper.Viper) (warnings []string) {
"backend": CacheMemory,
"ttl": 1 * time.Minute,
"redis": map[string]any{
"host": "localhost",
"port": 6379,
"host": "localhost",
"port": 6379,
"password": "",
"db": 0,
},
"memory": map[string]any{
"enabled": false, // deprecated (see below)
"eviction_interval": 5 * time.Minute,
},
})

if mem := v.Sub("cache.memory"); mem != nil {
mem.SetDefault("eviction_interval", 5*time.Minute)
// handle legacy memory structure
if mem.GetBool("enabled") {
warnings = append(warnings, deprecatedMsgMemoryEnabled)
// forcibly set top-level `enabled` to true
v.Set("cache.enabled", true)
// ensure ttl is mapped to the value at memory.expiration
v.RegisterAlias("cache.ttl", "cache.memory.expiration")
// ensure ttl default is set
v.SetDefault("cache.memory.expiration", 1*time.Minute)
}
if v.GetBool("cache.memory.enabled") {
warnings = append(warnings, deprecatedMsgMemoryEnabled)
// forcibly set top-level `enabled` to true
v.Set("cache.enabled", true)
// ensure ttl is mapped to the value at memory.expiration
v.RegisterAlias("cache.ttl", "cache.memory.expiration")
// ensure ttl default is set
v.SetDefault("cache.memory.expiration", 1*time.Minute)
}

if mem.IsSet("expiration") {
warnings = append(warnings, deprecatedMsgMemoryExpiration)
}
if v.IsSet("cache.memory.expiration") {
warnings = append(warnings, deprecatedMsgMemoryExpiration)
}

return
Expand Down
70 changes: 52 additions & 18 deletions internal/config/config.go
Expand Up @@ -59,21 +59,16 @@ func Load(path string) (*Config, error) {
}

var (
cfg = &Config{}
fields = cfg.fields()
cfg = &Config{}
validators = cfg.prepare(v)
)

// set viper defaults per field
for _, defaulter := range fields.defaulters {
cfg.Warnings = append(cfg.Warnings, defaulter.setDefaults(v)...)
}

if err := v.Unmarshal(cfg, viper.DecodeHook(decodeHooks)); err != nil {
return nil, err
}

// run any validation steps
for _, validator := range fields.validators {
for _, validator := range validators {
if err := validator.validate(); err != nil {
return nil, err
}
Expand All @@ -90,28 +85,67 @@ type validator interface {
validate() error
}

type fields struct {
defaulters []defaulter
validators []validator
}
func (c *Config) prepare(v *viper.Viper) (validators []validator) {
val := reflect.ValueOf(c).Elem()
for i := 0; i < val.NumField(); i++ {
// search for all expected env vars since Viper cannot
// infer when doing Unmarshal + AutomaticEnv.
// see: https://github.com/spf13/viper/issues/761
bindEnvVars(v, "", val.Type().Field(i))

func (c *Config) fields() (fields fields) {
structVal := reflect.ValueOf(c).Elem()
for i := 0; i < structVal.NumField(); i++ {
field := structVal.Field(i).Addr().Interface()
field := val.Field(i).Addr().Interface()

// for-each defaulter implementing fields we invoke
// setting any defaults during this prepare stage
// on the supplied viper.
if defaulter, ok := field.(defaulter); ok {
fields.defaulters = append(fields.defaulters, defaulter)
c.Warnings = append(c.Warnings, defaulter.setDefaults(v)...)
}

// for-each validator implementing field we collect
// them up and return them to be validated after
// unmarshalling.
if validator, ok := field.(validator); ok {
fields.validators = append(fields.validators, validator)
validators = append(validators, validator)
}
}

return
}

// bindEnvVars descends into the provided struct field binding any expected
// environment variable keys it finds reflecting struct and field tags.
func bindEnvVars(v *viper.Viper, prefix string, field reflect.StructField) {
tag := field.Tag.Get("mapstructure")
if tag == "" {
tag = strings.ToLower(field.Name)
}

var (
key = prefix + tag
typ = field.Type
)

// descend through pointers
if typ.Kind() == reflect.Pointer {
typ = typ.Elem()
}

// descend into struct fields
if typ.Kind() == reflect.Struct {
for i := 0; i < field.Type.NumField(); i++ {
structField := field.Type.Field(i)

// key becomes prefix for sub-fields
bindEnvVars(v, key+".", structField)
}

return
}

v.BindEnv(key)
}

func (c *Config) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var (
out []byte
Expand Down
72 changes: 70 additions & 2 deletions internal/config/config_test.go
Expand Up @@ -6,12 +6,15 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/uber/jaeger-client-go"
"gopkg.in/yaml.v2"
)

func TestScheme(t *testing.T) {
Expand Down Expand Up @@ -236,7 +239,7 @@ func TestLoad(t *testing.T) {
cfg := defaultConfig()
cfg.Cache.Enabled = true
cfg.Cache.Backend = CacheMemory
cfg.Cache.TTL = -1
cfg.Cache.TTL = -time.Second
cfg.Warnings = append(cfg.Warnings, deprecatedMsgMemoryEnabled, deprecatedMsgMemoryExpiration)
return cfg
},
Expand Down Expand Up @@ -415,7 +418,7 @@ func TestLoad(t *testing.T) {
expected = tt.expected()
}

t.Run(tt.name, func(t *testing.T) {
t.Run(tt.name+" (YAML)", func(t *testing.T) {
cfg, err := Load(path)

if wantErr != nil {
Expand All @@ -429,6 +432,39 @@ func TestLoad(t *testing.T) {
assert.NotNil(t, cfg)
assert.Equal(t, expected, cfg)
})

t.Run(tt.name+" (ENV)", func(t *testing.T) {
// backup and restore environment
backup := os.Environ()
defer func() {
os.Clearenv()
for _, env := range backup {
key, value, _ := strings.Cut(env, "=")
os.Setenv(key, value)
}
}()

// read the input config file into equivalent envs
envs := readYAMLIntoEnv(t, path)
for _, env := range envs {
t.Logf("Setting env '%s=%s'\n", env[0], env[1])
os.Setenv(env[0], env[1])
}

// load default (empty) config
cfg, err := Load("./testdata/default.yml")

if wantErr != nil {
t.Log(err)
require.ErrorIs(t, err, wantErr)
return
}

require.NoError(t, err)

assert.NotNil(t, cfg)
assert.Equal(t, expected, cfg)
})
}
}

Expand All @@ -449,3 +485,35 @@ func TestServeHTTP(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NotEmpty(t, body)
}

// readyYAMLIntoEnv parses the file provided at path as YAML.
// It walks the keys and values and builds up a set of environment variables
// compatible with viper's expectations for automatic env capability.
func readYAMLIntoEnv(t *testing.T, path string) [][2]string {
t.Helper()

configFile, err := os.ReadFile(path)
require.NoError(t, err)

var config map[any]any
err = yaml.Unmarshal(configFile, &config)
require.NoError(t, err)

return getEnvVars("flipt", config)
}

func getEnvVars(prefix string, v map[any]any) (vals [][2]string) {
for key, value := range v {
switch v := value.(type) {
case map[any]any:
vals = append(vals, getEnvVars(fmt.Sprintf("%s_%v", prefix, key), v)...)
default:
vals = append(vals, [2]string{
fmt.Sprintf("%s_%s", strings.ToUpper(prefix), strings.ToUpper(fmt.Sprintf("%v", key))),
fmt.Sprintf("%v", value),
})
}
}

return
}
@@ -1,4 +1,4 @@
cache:
memory:
enabled: true
expiration: -1
expiration: -1s

0 comments on commit ced9841

Please sign in to comment.