Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Wrong target type restriction when decoding struct to map #313

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 19 additions & 6 deletions mapstructure.go
Original file line number Diff line number Diff line change
Expand Up @@ -918,9 +918,6 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re
// Next get the actual value of this field and verify it is assignable
// to the map value.
v := dataVal.Field(i)
if !v.Type().AssignableTo(valMap.Type().Elem()) {
return fmt.Errorf("cannot assign type '%s' to map value field of type '%s'", v.Type(), valMap.Type().Elem())
}

tagValue := f.Tag.Get(d.config.TagName)
keyName := f.Name
Expand Down Expand Up @@ -973,9 +970,21 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re
x := reflect.New(v.Type())
x.Elem().Set(v)

vType := valMap.Type()
vKeyType := vType.Key()
vElemType := vType.Elem()
var vKeyType reflect.Type
var vElemType reflect.Type
switch valMap.Type().Elem().Kind() {
case reflect.Map:
// When the target field is a typed map, use the map type
vType := valMap.Type().Elem()
vKeyType = vType.Key()
vElemType = vType.Elem()

default:
// For any other target field type, use the root map type (map[string]interface{})
vKeyType = valMap.Type().Key()
vElemType = valMap.Type().Elem()
}

mType := reflect.MapOf(vKeyType, vElemType)
vMap := reflect.MakeMap(mType)

Expand Down Expand Up @@ -1004,6 +1013,10 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re
}

default:
if !v.Type().AssignableTo(valMap.Type().Elem()) {
return fmt.Errorf("cannot assign type '%s' to map value field of type '%s'", v.Type(), valMap.Type().Elem())
}

valMap.SetMapIndex(reflect.ValueOf(keyName), v)
}
}
Expand Down
88 changes: 88 additions & 0 deletions mapstructure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2788,6 +2788,94 @@ func testArrayInput(t *testing.T, input map[string]interface{}, expected *Array)
}
}

func TestDecode_structToGenericMap(t *testing.T) {
type SourceChild struct {
String string `mapstructure:"string"`
Int int `mapstructure:"int"`
Map map[string]float32 `mapstructure:"map"`
}

type SourceParent struct {
Child SourceChild `mapstructure:"child"`
}

var target map[string]interface{}

source := SourceParent{
Child: SourceChild{
String: "hello",
Int: 1,
Map: map[string]float32{
"one": 1.0,
"two": 2.0,
},
},
}

if err := Decode(source, &target); err != nil {
t.Fatalf("got error: %s", err)
}

expected := map[string]interface{}{
"child": map[string]interface{}{
"string": "hello",
"int": 1,
"map": map[string]float32{
"one": 1.0,
"two": 2.0,
},
},
}

if !reflect.DeepEqual(target, expected) {
t.Fatalf("bad: \nexpected: %#v\nresult: %#v", expected, target)
}
}

func TestDecode_structToTypedMap(t *testing.T) {
type SourceChild struct {
String string `mapstructure:"string"`
Int int `mapstructure:"int"`
Map map[string]float32 `mapstructure:"map"`
}

type SourceParent struct {
Child SourceChild `mapstructure:"child"`
}

var target map[string]map[string]interface{}

source := SourceParent{
Child: SourceChild{
String: "hello",
Int: 1,
Map: map[string]float32{
"one": 1.0,
"two": 2.0,
},
},
}

if err := Decode(source, &target); err != nil {
t.Fatalf("got error: %s", err)
}

expected := map[string]map[string]interface{}{
"child": {
"string": "hello",
"int": 1,
"map": map[string]float32{
"one": 1.0,
"two": 2.0,
},
},
}

if !reflect.DeepEqual(target, expected) {
t.Fatalf("bad: \nexpected: %#v\nresult: %#v", expected, target)
}
}

func stringPtr(v string) *string { return &v }
func intPtr(v int) *int { return &v }
func uintPtr(v uint) *uint { return &v }
Expand Down