Skip to content

Commit

Permalink
fix spf13#1106 UnmarshalKey looses some data if a subitem is overridden
Browse files Browse the repository at this point in the history
  • Loading branch information
fishautumn committed Jan 30, 2023
1 parent 5182412 commit 6e88dd6
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
8 changes: 4 additions & 4 deletions overrides_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ func TestNestedOverrides(t *testing.T) {
deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)

// Case 4: key:value overridden by a map
v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10}
assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable
assert.Equal(10, v.Get("tom.age")) // new value should be there
deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there
v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10, "size": 4}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10}
assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable
assert.Equal(10, v.Get("tom.age")) // new value should be there
deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there
v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10})
assert.Nil(v.Get("tom.size"))
assert.Equal(10, v.Get("tom.age"))
Expand Down
30 changes: 30 additions & 0 deletions viper.go
Original file line number Diff line number Diff line change
Expand Up @@ -887,13 +887,43 @@ func GetViper() *Viper {
// Get returns an interface. For a specific value use one of the Get____ methods.
func Get(key string) interface{} { return v.Get(key) }

func isStringMapInterface(val interface{}) bool {
vt := reflect.TypeOf(val)
return vt.Kind() == reflect.Map &&
vt.Key().Kind() == reflect.String &&
vt.Elem().Kind() == reflect.Interface
}

func (v *Viper) Get(key string) interface{} {
lcaseKey := strings.ToLower(key)
val := v.find(lcaseKey, true)
if val == nil {
return nil
}

// when section is partially overridden,
// make sure to return the complete map.
if isStringMapInterface(val) {
val := val.(map[string]interface{})
prefix := lcaseKey + v.keyDelim
keys := v.AllKeys()
for _, key := range keys {
if !strings.HasPrefix(key, prefix) {
continue
}
mk := strings.TrimPrefix(key, prefix)
mk = strings.Split(mk, v.keyDelim)[0]
if _, exists := val[mk]; exists {
continue
}
mv := v.Get(lcaseKey + v.keyDelim + mk)
if mv == nil {
continue
}
val[mk] = mv
}
}

if v.typeByDefValue {
// TODO(bep) this branch isn't covered by a single test.
valType := val
Expand Down

0 comments on commit 6e88dd6

Please sign in to comment.