diff --git a/viper.go b/viper.go index e5accf7ab..3481d1da1 100644 --- a/viper.go +++ b/viper.go @@ -1780,17 +1780,6 @@ func mergeMaps( svType := reflect.TypeOf(sv) tvType := reflect.TypeOf(tv) - if tvType != nil && svType != tvType { // Allow for the target to be nil - v.logger.Error( - "svType != tvType", - "key", sk, - "st", svType, - "tt", tvType, - "sv", sv, - "tv", tv, - ) - continue - } v.logger.Trace( "processing", @@ -1804,13 +1793,27 @@ func mergeMaps( switch ttv := tv.(type) { case map[interface{}]interface{}: v.logger.Trace("merging maps (must convert)") - tsv := sv.(map[interface{}]interface{}) + tsv, ok := sv.(map[interface{}]interface{}) + if !ok { + v.logger.Error( + "Could not cast sv to map[interface{}]interface{}; key=%s, st=%v, tt=%v, sv=%v, tv=%v", + sk, svType, tvType, sv, tv) + continue + } + ssv := castToMapStringInterface(tsv) stv := castToMapStringInterface(ttv) mergeMaps(ssv, stv, ttv) case map[string]interface{}: v.logger.Trace("merging maps") - mergeMaps(sv.(map[string]interface{}), ttv, nil) + tsv, ok := sv.(map[string]interface{}) + if !ok { + v.logger.Error( + "Could not cast sv to map[string]interface{}; key=%s, st=%v, tt=%v, sv=%v, tv=%v", + sk, svType, tvType, sv, tv) + continue + } + mergeMaps(tsv, ttv, nil) default: v.logger.Trace("setting value") tgt[tk] = sv diff --git a/viper_test.go b/viper_test.go index 8c864f117..8a5dec688 100644 --- a/viper_test.go +++ b/viper_test.go @@ -1912,6 +1912,24 @@ hello: fu: bar `) +var jsonMergeExampleTgt = []byte(` +{ + "hello": { + "foo": null, + "pop": 123456 + } +} +`) + +var jsonMergeExampleSrc = []byte(` +{ + "hello": { + "foo": "foo str", + "pop": "pop str" + } +} +`) + func TestMergeConfig(t *testing.T) { v := New() v.SetConfigType("yml") @@ -1984,6 +2002,26 @@ func TestMergeConfig(t *testing.T) { } } +func TestMergeConfigOverrideType(t *testing.T) { + v := New() + v.SetConfigType("json") + if err := v.ReadConfig(bytes.NewBuffer(jsonMergeExampleTgt)); err != nil { + t.Fatal(err) + } + + if err := v.MergeConfig(bytes.NewBuffer(jsonMergeExampleSrc)); err != nil { + t.Fatal(err) + } + + if pop := v.GetString("hello.pop"); pop != "pop str" { + t.Fatalf("pop != \"pop str\", = %s", pop) + } + + if foo := v.GetString("hello.foo"); foo != "foo str" { + t.Fatalf("foo != \"foo str\", = %s", foo) + } +} + func TestMergeConfigNoMerge(t *testing.T) { v := New() v.SetConfigType("yml")