Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement viper.BindStruct for automatic unmarshalling from environment variables #1429

Merged
merged 2 commits into from Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 31 additions & 2 deletions viper.go
Expand Up @@ -1111,7 +1111,32 @@ func Unmarshal(rawVal any, opts ...DecoderConfigOption) error {
}

func (v *Viper) Unmarshal(rawVal any, opts ...DecoderConfigOption) error {
return decode(v.AllSettings(), defaultDecoderConfig(rawVal, opts...))
// TODO: make this optional?
structKeys, err := v.decodeStructKeys(rawVal, opts...)
if err != nil {
return err
}

// TODO: struct keys should be enough?
return decode(v.getSettings(append(v.AllKeys(), structKeys...)), defaultDecoderConfig(rawVal, opts...))
}

func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]string, error) {
var structKeyMap map[string]any

err := decode(input, defaultDecoderConfig(&structKeyMap, opts...))
if err != nil {
return nil, err
}

flattenedStructKeyMap := v.flattenAndMergeMap(map[string]bool{}, structKeyMap, "")

r := make([]string, 0, len(flattenedStructKeyMap))
for v := range flattenedStructKeyMap {
r = append(r, v)
}

return r, nil
}

// defaultDecoderConfig returns default mapstructure.DecoderConfig with support
Expand Down Expand Up @@ -2098,9 +2123,13 @@ outer:
func AllSettings() map[string]any { return v.AllSettings() }

func (v *Viper) AllSettings() map[string]any {
return v.getSettings(v.AllKeys())
}

func (v *Viper) getSettings(keys []string) map[string]any {
m := map[string]any{}
// start from the list of keys, and construct the map one value at a time
for _, k := range v.AllKeys() {
for _, k := range keys {
value := v.Get(k)
if value == nil {
// should not happen, since AllKeys() returns only keys holding a value,
Expand Down
99 changes: 99 additions & 0 deletions viper_test.go
Expand Up @@ -948,6 +948,105 @@ func TestUnmarshalWithDecoderOptions(t *testing.T) {
}, &C)
}

func TestUnmarshalWithAutomaticEnv(t *testing.T) {
t.Setenv("PORT", "1313")
t.Setenv("NAME", "Steve")
t.Setenv("DURATION", "1s1ms")
t.Setenv("MODES", "1,2,3")
t.Setenv("SECRET", "42")
t.Setenv("FILESYSTEM_SIZE", "4096")

type AuthConfig struct {
Secret string `mapstructure:"secret"`
}

type StorageConfig struct {
Size int `mapstructure:"size"`
}

type Configuration struct {
Port int `mapstructure:"port"`
Name string `mapstructure:"name"`
Duration time.Duration `mapstructure:"duration"`

// Infer name from struct
Modes []int

// Squash nested struct (omit prefix)
Authentication AuthConfig `mapstructure:",squash"`

// Different key
Storage StorageConfig `mapstructure:"filesystem"`

// Omitted field
Flag bool `mapstructure:"flag"`
}

v := New()
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.AutomaticEnv()

t.Run("OK", func(t *testing.T) {
var config Configuration
if err := v.Unmarshal(&config); err != nil {
t.Fatalf("unable to decode into struct, %v", err)
}

assert.Equal(
t,
Configuration{
Name: "Steve",
Port: 1313,
Duration: time.Second + time.Millisecond,
Modes: []int{1, 2, 3},
Authentication: AuthConfig{
Secret: "42",
},
Storage: StorageConfig{
Size: 4096,
},
},
config,
)
})

t.Run("Precedence", func(t *testing.T) {
var config Configuration

v.Set("port", 1234)
if err := v.Unmarshal(&config); err != nil {
t.Fatalf("unable to decode into struct, %v", err)
}

assert.Equal(
t,
Configuration{
Name: "Steve",
Port: 1234,
Duration: time.Second + time.Millisecond,
Modes: []int{1, 2, 3},
Authentication: AuthConfig{
Secret: "42",
},
Storage: StorageConfig{
Size: 4096,
},
},
config,
)
})

t.Run("Unset", func(t *testing.T) {
var config Configuration

err := v.Unmarshal(&config, func(config *mapstructure.DecoderConfig) {
config.ErrorUnset = true
})

assert.Error(t, err, "expected viper.Unmarshal to return error due to unset field 'FLAG'")
})
}

func TestBindPFlags(t *testing.T) {
v := New() // create independent Viper object
flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError)
Expand Down