diff --git a/issue66_test.go b/issue66_test.go index 9e4bcce..ebc4276 100644 --- a/issue66_test.go +++ b/issue66_test.go @@ -21,7 +21,7 @@ func TestPrivateSlice(t *testing.T) { t.Fatalf("Error during the merge: %v", err) } if len(p1.PublicStrings) != 3 { - t.Error("5 elements should be in 'PublicStrings' field") + t.Error("3 elements should be in 'PublicStrings' field, when no append") } if len(p1.privateStrings) != 2 { t.Error("2 elements should be in 'privateStrings' field") diff --git a/issue90_test.go b/issue90_test.go new file mode 100644 index 0000000..311bf78 --- /dev/null +++ b/issue90_test.go @@ -0,0 +1,50 @@ +package mergo + +import ( + "reflect" + "testing" +) + +type CustomStruct struct { + SomeMap map[string]string +} + +func TestMergoStructMap(t *testing.T) { + var testData = []struct { + name string + src map[string]CustomStruct + dst map[string]CustomStruct + exp map[string]CustomStruct + }{ + {name: "Normal", + dst: map[string]CustomStruct{"a": CustomStruct{SomeMap: map[string]string{"key1": "loosethis", "key2": "keepthis"}}}, + src: map[string]CustomStruct{"a": CustomStruct{SomeMap: map[string]string{"key1": "key10"}}}, + exp: map[string]CustomStruct{"a": CustomStruct{SomeMap: map[string]string{"key1": "key10", "key2": "keepthis"}}}, + }, + {name: "Init of struct key", dst: map[string]CustomStruct{"a": CustomStruct{SomeMap: map[string]string{}}}, + src: map[string]CustomStruct{"a": CustomStruct{SomeMap: map[string]string{"key1": "key10"}}}, + exp: map[string]CustomStruct{"a": CustomStruct{SomeMap: map[string]string{"key1": "key10"}}}, + }, + {name: "Not Init of struct key", dst: map[string]CustomStruct{}, + src: map[string]CustomStruct{"a": CustomStruct{SomeMap: map[string]string{"key1": "key10"}}}, + exp: map[string]CustomStruct{"a": CustomStruct{SomeMap: map[string]string{"key1": "key10"}}}, + }, + {name: "Nil struct key", dst: map[string]CustomStruct{"a": CustomStruct{SomeMap: nil}}, + src: map[string]CustomStruct{"a": CustomStruct{SomeMap: map[string]string{"key1": "key10"}}}, + exp: map[string]CustomStruct{"a": CustomStruct{SomeMap: map[string]string{"key1": "key10"}}}}, + } + + for _, data := range testData { + dst := data.dst + src := data.src + exp := data.exp + + err := Merge(&dst, src, WithAppendSlice, WithOverride) + if err != nil { + t.Errorf("mergo error was not nil, %v", err) + } + if !reflect.DeepEqual(dst, exp) { + t.Errorf("Actual: %#v did not match \nExpected: %#v", dst, exp) + } + } +} diff --git a/map.go b/map.go index 3f5afa8..d83258b 100644 --- a/map.go +++ b/map.go @@ -99,11 +99,11 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int, conf continue } if srcKind == dstKind { - if err = deepMerge(dstElement, srcElement, visited, depth+1, config); err != nil { + if _, err = deepMerge(dstElement, srcElement, visited, depth+1, config); err != nil { return } } else if dstKind == reflect.Interface && dstElement.Kind() == reflect.Interface { - if err = deepMerge(dstElement, srcElement, visited, depth+1, config); err != nil { + if _, err = deepMerge(dstElement, srcElement, visited, depth+1, config); err != nil { return } } else if srcKind == reflect.Map { @@ -157,7 +157,8 @@ func _map(dst, src interface{}, opts ...func(*Config)) error { // To be friction-less, we redirect equal-type arguments // to deepMerge. Only because arguments can be anything. if vSrc.Kind() == vDst.Kind() { - return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0, config) + _, err := deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0, config) + return err } switch vSrc.Kind() { case reflect.Struct: diff --git a/merge.go b/merge.go index 87eb70c..3332c9c 100644 --- a/merge.go +++ b/merge.go @@ -11,20 +11,32 @@ package mergo import ( "fmt" "reflect" + "unsafe" ) func hasExportedField(dst reflect.Value) (exported bool) { for i, n := 0, dst.NumField(); i < n; i++ { field := dst.Type().Field(i) - if field.Anonymous && dst.Field(i).Kind() == reflect.Struct { - exported = exported || hasExportedField(dst.Field(i)) - } else { - exported = exported || len(field.PkgPath) == 0 + if isExportedComponent(&field) { + return true } } return } +func isExportedComponent(field *reflect.StructField) bool { + name := field.Name + pkgPath := field.PkgPath + if len(pkgPath) > 0 { + return false + } + c := name[0] + if 'a' <= c && c <= 'z' || c == '_' { + return false + } + return true +} + type Config struct { Overwrite bool AppendSlice bool @@ -41,7 +53,8 @@ type Transformers interface { // Traverses recursively both values, assigning src's fields values to dst. // The map argument tracks comparisons that have already been seen, which allows // short circuiting on recursive types. -func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, config *Config) (err error) { +func deepMerge(dstIn, src reflect.Value, visited map[uintptr]*visit, depth int, config *Config) (dst reflect.Value, err error) { + dst = dstIn overwrite := config.Overwrite typeCheck := config.TypeCheck overwriteWithEmptySrc := config.overwriteWithEmptyValue @@ -50,6 +63,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co if !src.IsValid() { return } + if dst.CanAddr() { addr := dst.UnsafeAddr() h := 17 * addr @@ -57,7 +71,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co typ := dst.Type() for p := seen; p != nil; p = p.next { if p.ptr == addr && p.typ == typ { - return nil + return dst, nil } } // Remember, remember... @@ -71,114 +85,124 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co } } + if dst.IsValid() && src.IsValid() && src.Type() != dst.Type() { + err = fmt.Errorf("cannot append two different types (%s, %s)", src.Kind(), dst.Kind()) + return + } + switch dst.Kind() { case reflect.Struct: if hasExportedField(dst) { + dstCp := reflect.New(dst.Type()).Elem() for i, n := 0, dst.NumField(); i < n; i++ { - if err = deepMerge(dst.Field(i), src.Field(i), visited, depth+1, config); err != nil { + dstField := dst.Field(i) + structField := dst.Type().Field(i) + // copy un-exported struct fields + if !isExportedComponent(&structField) { + rf := dstCp.Field(i) + rf = reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem() //nolint:gosec + dstRF := dst.Field(i) + if !dst.Field(i).CanAddr() { + continue + } + + dstRF = reflect.NewAt(dstRF.Type(), unsafe.Pointer(dstRF.UnsafeAddr())).Elem() //nolint:gosec + rf.Set(dstRF) + continue + } + dstField, err = deepMerge(dstField, src.Field(i), visited, depth+1, config) + if err != nil { return } + dstCp.Field(i).Set(dstField) } + + if dst.CanSet() { + dst.Set(dstCp) + } else { + dst = dstCp + } + return } else { - if dst.CanSet() && (!isEmptyValue(src) || overwriteWithEmptySrc) && (overwrite || isEmptyValue(dst)) { - dst.Set(src) + if (isReflectNil(dst) || overwrite) && (!isEmptyValue(src) || overwriteWithEmptySrc) { + dst = src } } + case reflect.Map: if dst.IsNil() && !src.IsNil() { - dst.Set(reflect.MakeMap(dst.Type())) + if dst.CanSet() { + dst.Set(reflect.MakeMap(dst.Type())) + } else { + dst = src + return + } } for _, key := range src.MapKeys() { srcElement := src.MapIndex(key) + dstElement := dst.MapIndex(key) if !srcElement.IsValid() { continue } - dstElement := dst.MapIndex(key) - switch srcElement.Kind() { - case reflect.Chan, reflect.Func, reflect.Map, reflect.Interface, reflect.Slice: - if srcElement.IsNil() { - continue - } - fallthrough - default: - if !srcElement.CanInterface() { - continue - } - switch reflect.TypeOf(srcElement.Interface()).Kind() { - case reflect.Struct: - fallthrough - case reflect.Ptr: - fallthrough - case reflect.Map: - srcMapElm := srcElement - dstMapElm := dstElement - if srcMapElm.CanInterface() { - srcMapElm = reflect.ValueOf(srcMapElm.Interface()) - if dstMapElm.IsValid() { - dstMapElm = reflect.ValueOf(dstMapElm.Interface()) - } - } - if err = deepMerge(dstMapElm, srcMapElm, visited, depth+1, config); err != nil { - return - } - case reflect.Slice: - srcSlice := reflect.ValueOf(srcElement.Interface()) - - var dstSlice reflect.Value - if !dstElement.IsValid() || dstElement.IsNil() { - dstSlice = reflect.MakeSlice(srcSlice.Type(), 0, srcSlice.Len()) - } else { - dstSlice = reflect.ValueOf(dstElement.Interface()) - } - - if (!isEmptyValue(src) || overwriteWithEmptySrc || overwriteSliceWithEmptySrc) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice { - if typeCheck && srcSlice.Type() != dstSlice.Type() { - return fmt.Errorf("cannot override two slices with different type (%s, %s)", srcSlice.Type(), dstSlice.Type()) - } - dstSlice = srcSlice - } else if config.AppendSlice { - if srcSlice.Type() != dstSlice.Type() { - return fmt.Errorf("cannot append two slices with different type (%s, %s)", srcSlice.Type(), dstSlice.Type()) - } - dstSlice = reflect.AppendSlice(dstSlice, srcSlice) - } - dst.SetMapIndex(key, dstSlice) + if dst.MapIndex(key).IsValid() { + k := dstElement.Interface() + dstElement = reflect.ValueOf(k) + } + if isReflectNil(srcElement) { + if overwrite || isReflectNil(dstElement) { + dst.SetMapIndex(key, srcElement) } + continue } - if dstElement.IsValid() && !isEmptyValue(dstElement) && (reflect.TypeOf(srcElement.Interface()).Kind() == reflect.Map || reflect.TypeOf(srcElement.Interface()).Kind() == reflect.Slice) { + if !srcElement.CanInterface() { continue } - if srcElement.IsValid() && ((srcElement.Kind() != reflect.Ptr && overwrite) || !dstElement.IsValid() || isEmptyValue(dstElement)) { - if dst.IsNil() { - dst.Set(reflect.MakeMap(dst.Type())) + if srcElement.CanInterface() { + srcElement = reflect.ValueOf(srcElement.Interface()) + if dstElement.IsValid() { + dstElement = reflect.ValueOf(dstElement.Interface()) } - dst.SetMapIndex(key, srcElement) } + dstElement, err = deepMerge(dstElement, srcElement, visited, depth+1, config) + if err != nil { + return + } + dst.SetMapIndex(key, dstElement) + } case reflect.Slice: - if !dst.CanSet() { - break - } + newSlice := dst if (!isEmptyValue(src) || overwriteWithEmptySrc || overwriteSliceWithEmptySrc) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice { - dst.Set(src) + if typeCheck && src.Type() != dst.Type() { + return dst, fmt.Errorf("cannot override two slices with different type (%s, %s)", src.Type(), dst.Type()) + } + newSlice = src } else if config.AppendSlice { - if src.Type() != dst.Type() { - return fmt.Errorf("cannot append two slice with different type (%s, %s)", src.Type(), dst.Type()) + if typeCheck && src.Type() != dst.Type() { + err = fmt.Errorf("cannot append two slice with different type (%s, %s)", src.Type(), dst.Type()) + return } - dst.Set(reflect.AppendSlice(dst, src)) + newSlice = reflect.AppendSlice(dst, src) } - case reflect.Ptr: - fallthrough - case reflect.Interface: - if src.IsNil() { + if dst.CanSet() { + dst.Set(newSlice) + } else { + dst = newSlice + } + case reflect.Ptr, reflect.Interface: + if isReflectNil(src) { break } if dst.Kind() != reflect.Ptr && src.Type().AssignableTo(dst.Type()) { if dst.IsNil() || overwrite { - if dst.CanSet() && (overwrite || isEmptyValue(dst)) { - dst.Set(src) + if overwrite || isEmptyValue(dst) { + if dst.CanSet() { + dst.Set(src) + } else { + dst = src + } } } break @@ -190,28 +214,38 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co dst.Set(src) } } else if src.Kind() == reflect.Ptr { - if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, config); err != nil { + if dst, err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, config); err != nil { return } + dst = dst.Addr() } else if dst.Elem().Type() == src.Type() { - if err = deepMerge(dst.Elem(), src, visited, depth+1, config); err != nil { + if dst, err = deepMerge(dst.Elem(), src, visited, depth+1, config); err != nil { return } } else { - return ErrDifferentArgumentsTypes + return dst, ErrDifferentArgumentsTypes } break } if dst.IsNil() || overwrite { - if dst.CanSet() && (overwrite || isEmptyValue(dst)) { - dst.Set(src) + if (overwrite || isEmptyValue(dst)) && (overwriteWithEmptySrc || !isEmptyValue(src)) { + if dst.CanSet() { + dst.Set(src) + } else { + dst = src + } } - } else if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, config); err != nil { + } else if _, err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, config); err != nil { return } default: - if dst.CanSet() && (!isEmptyValue(src) || overwriteWithEmptySrc) && (overwrite || isEmptyValue(dst)) { - dst.Set(src) + overwriteFull := (!isEmptyValue(src) || overwriteWithEmptySrc) && (overwrite || isEmptyValue(dst)) + if overwriteFull { + if dst.CanSet() { + dst.Set(src) + } else { + dst = src + } } } @@ -280,8 +314,25 @@ func merge(dst, src interface{}, opts ...func(*Config)) error { if vDst, vSrc, err = resolveValues(dst, src); err != nil { return err } + if !vDst.CanSet() { + return fmt.Errorf("cannot set dst, needs reference") + } if vDst.Type() != vSrc.Type() { return ErrDifferentArgumentsTypes } - return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0, config) + _, err = deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0, config) + return err +} + +// IsReflectNil is the reflect value provided nil +func isReflectNil(v reflect.Value) bool { + k := v.Kind() + switch k { + case reflect.Interface, reflect.Slice, reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr: + // Both interface and slice are nil if first word is 0. + // Both are always bigger than a word; assume flagIndir. + return v.IsNil() + default: + return false + } } diff --git a/mergo_test.go b/mergo_test.go index 1786824..cb28e4d 100644 --- a/mergo_test.go +++ b/mergo_test.go @@ -129,10 +129,10 @@ func TestComplexStruct(t *testing.T) { } func TestComplexStructWithOverwrite(t *testing.T) { - a := complexTest{simpleTest{1}, 1, "do-not-overwrite-with-empty-value"} - b := complexTest{simpleTest{42}, 2, ""} + a := complexTest{St: simpleTest{1}, sz: 1, ID: "do-not-overwrite-with-empty-value"} + b := complexTest{St: simpleTest{42}, sz: 2, ID: ""} - expect := complexTest{simpleTest{42}, 1, "do-not-overwrite-with-empty-value"} + expect := complexTest{St: simpleTest{42}, sz: 1, ID: "do-not-overwrite-with-empty-value"} if err := MergeWithOverwrite(&a, b); err != nil { t.FailNow() } @@ -156,7 +156,7 @@ func TestPointerStruct(t *testing.T) { } type embeddingStruct struct { - embeddedStruct + A embeddedStruct } type embeddedStruct struct { @@ -345,13 +345,13 @@ func TestEmptyToNotEmptyMaps(t *testing.T) { } func TestMapsWithOverwrite(t *testing.T) { - m := map[string]simpleTest{ + dst := map[string]simpleTest{ "a": {}, // overwritten by 16 "b": {42}, // overwritten by 0, as map Value is not addressable and it doesn't check for b is set or not set in `n` "c": {13}, // overwritten by 12 "d": {61}, } - n := map[string]simpleTest{ + src := map[string]simpleTest{ "a": {16}, "b": {}, "c": {12}, @@ -359,18 +359,18 @@ func TestMapsWithOverwrite(t *testing.T) { } expect := map[string]simpleTest{ "a": {16}, - "b": {}, + "b": {42}, "c": {12}, "d": {61}, "e": {14}, } - if err := MergeWithOverwrite(&m, n); err != nil { + if err := MergeWithOverwrite(&dst, src); err != nil { t.Fatalf(err.Error()) } - if !reflect.DeepEqual(m, expect) { - t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", m, expect) + if !reflect.DeepEqual(dst, expect) { + t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", dst, expect) } } @@ -536,23 +536,13 @@ func TestMergeUsingStructAndMap(t *testing.T) { } func TestMaps(t *testing.T) { m := map[string]simpleTest{ - "a": {}, - "b": {42}, - "c": {13}, - "d": {61}, + "a": {0}, "b": {42}, "c": {13}, "d": {61}, } n := map[string]simpleTest{ - "a": {16}, - "b": {}, - "c": {12}, - "e": {14}, + "a": {16}, "b": {}, "c": {12}, "e": {14}, } expect := map[string]simpleTest{ - "a": {0}, - "b": {42}, - "c": {13}, - "d": {61}, - "e": {14}, + "a": {16}, "b": {42}, "c": {13}, "d": {61}, "e": {14}, } if err := Merge(&m, n); err != nil { @@ -562,7 +552,7 @@ func TestMaps(t *testing.T) { if !reflect.DeepEqual(m, expect) { t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", m, expect) } - if m["a"].Value != 0 { + if m["a"].Value != 16 { t.Fatalf(`n merged in m because I solved non-addressable map values TODO: m["a"].Value(%d) != n["a"].Value(%d)`, m["a"].Value, n["a"].Value) } if m["b"].Value != 42 { @@ -903,12 +893,12 @@ func TestMergeMapWithInnerSliceOfDifferentType(t *testing.T) { { "With override and append slice", []func(*Config){WithOverride, WithAppendSlice}, - "cannot append two slices with different type", + "cannot append two different types (slice, slice)", }, { "With override and type check", []func(*Config){WithOverride, WithTypeCheck}, - "cannot override two slices with different type", + "cannot append two different types (slice, slice)", }, } for _, tc := range testCases {