Skip to content

Commit

Permalink
GODRIVER-1808 Fix unmarshaling BSON into interfaces containing concre…
Browse files Browse the repository at this point in the history
…te values.
  • Loading branch information
qingyang-hu committed Mar 20, 2024
1 parent 7252a3c commit 1db6753
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 12 deletions.
7 changes: 4 additions & 3 deletions bson/bsoncodec/bsoncodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,13 @@ var _ typeDecoder = decodeAdapter{}
// t and calls decoder.DecodeValue on it.
func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
td, _ := decoder.(typeDecoder)
return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t, true)
val := reflect.New(t).Elem()
return decodeTypeOrValueWithInfo(decoder, td, dc, vr, val, true)
}

func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type, convert bool) (reflect.Value, error) {
func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value, convert bool) (reflect.Value, error) {
if td != nil {
t := val.Type()
val, err := td.decodeType(dc, vr, t)
if err == nil && convert && val.Type() != t {
// This conversion step is necessary for slices and maps. If a user declares variables like:
Expand All @@ -366,7 +368,6 @@ func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext
return val, err
}

val := reflect.New(t).Elem()
err := vd.DecodeValue(dc, vr, val)
return val, err
}
Expand Down
44 changes: 36 additions & 8 deletions bson/bsoncodec/default_value_decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueRe
}

// Pass false for convert because we don't need to call reflect.Value.Convert for tEmpty.
elem, err := decodeTypeOrValueWithInfo(decoder, tEmptyTypeDecoder, dc, elemVr, tEmpty, false)
elem, err := decodeTypeOrValueWithInfo(decoder, tEmptyTypeDecoder, dc, elemVr, reflect.New(tEmpty).Elem(), false)
if err != nil {
return err
}
Expand Down Expand Up @@ -1666,11 +1666,15 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR

eType := val.Type().Elem()

decoder, err := dc.LookupDecoder(eType)
if err != nil {
return nil, err
var vDecoder ValueDecoder
var tDecoder typeDecoder
if !(eType.Kind() == reflect.Interface && val.Len() > 0) {
vDecoder, err = dc.LookupDecoder(eType)
if err != nil {
return nil, err
}
tDecoder, _ = vDecoder.(typeDecoder)
}
eTypeDecoder, _ := decoder.(typeDecoder)

idx := 0
for {
Expand All @@ -1682,10 +1686,34 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR
return nil, err
}

elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
if err != nil {
return nil, newDecodeError(strconv.Itoa(idx), err)
var elem reflect.Value
if vDecoder == nil {
e := val.Index(idx).Elem()
valueDecoder, err := dc.LookupDecoder(e.Type())
if err != nil {
return nil, err
}
typeDecoder, _ := valueDecoder.(typeDecoder)
if e.Kind() == reflect.Ptr {
v := reflect.New(e.Type()).Elem()
v.Set(e)
val.Index(idx).Set(v)
e = v
}
elem, err = decodeTypeOrValueWithInfo(valueDecoder, typeDecoder, dc, vr, e, true)
if err != nil {
return nil, newDecodeError(strconv.Itoa(idx), err)
}
if e.Kind() == reflect.Ptr && e.IsZero() {
elem = reflect.Zero(val.Index(idx).Type())
}
} else {
elem, err = decodeTypeOrValueWithInfo(vDecoder, tDecoder, dc, vr, reflect.New(eType).Elem(), true)
if err != nil {
return nil, newDecodeError(strconv.Itoa(idx), err)
}
}

elems = append(elems, elem)
idx++
}
Expand Down
2 changes: 1 addition & 1 deletion bson/bsoncodec/map_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref
return err
}

elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, reflect.New(eType).Elem(), true)
if err != nil {
return newDecodeError(key, err)
}
Expand Down
18 changes: 18 additions & 0 deletions bson/bsoncodec/struct_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,24 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val
}
}

if field.Kind() == reflect.Interface && !field.IsNil() && field.Elem().Kind() == reflect.Ptr {
decoder, err = dc.LookupDecoder(field.Elem().Type())
if err != nil {
return err
}
v := reflect.New(field.Elem().Type()).Elem()
v.Set(field.Elem())
field.Set(v)
err = decoder.DecodeValue(dc, vr, v)
if err != nil {
return newDecodeError(fd.name, err)
}
if v.IsZero() {
field.Set(reflect.Zero(field.Type()))
}
continue
}

if !field.CanSet() { // Being settable is a super set of being addressable.
innerErr := fmt.Errorf("field %v is not settable", field)
return newDecodeError(fd.name, innerErr)
Expand Down
156 changes: 156 additions & 0 deletions bson/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,162 @@ func TestBasicDecode(t *testing.T) {
}
}

func TestDecodingInterfaces(t *testing.T) {
t.Parallel()

type testCase struct {
name string
stub func() ([]byte, interface{}, func(*testing.T))
}
testCases := []testCase{
{
name: "struct with interface containing a concrete value",
stub: func() ([]byte, interface{}, func(*testing.T)) {
type testStruct struct {
Value interface{}
}
var value string

data := docToBytes(struct {
Value string
}{
Value: "foo",
})

receiver := testStruct{&value}

check := func(t *testing.T) {
t.Helper()
assert.Equal(t, "foo", value)
}

return data, &receiver, check
},
},
{
name: "struct with interface containing a struct",
stub: func() ([]byte, interface{}, func(*testing.T)) {
type demo struct {
Data string
}

type testStruct struct {
Value interface{}
}
var value demo

data := docToBytes(struct {
Value demo
}{
Value: demo{"foo"},
})

receiver := testStruct{&value}

check := func(t *testing.T) {
t.Helper()
assert.Equal(t, "foo", value.Data)
}

return data, &receiver, check
},
},
{
name: "struct with interface containing a slice",
stub: func() ([]byte, interface{}, func(*testing.T)) {
type testStruct struct {
Values interface{}
}
var values []string

data := docToBytes(struct {
Values []string
}{
Values: []string{"foo", "bar"},
})

receiver := testStruct{&values}

check := func(t *testing.T) {
t.Helper()
assert.Equal(t, []string{"foo", "bar"}, values)
}

return data, &receiver, check
},
},
{
name: "struct with interface containing an array",
stub: func() ([]byte, interface{}, func(*testing.T)) {
type testStruct struct {
Values interface{}
}
var values [2]string

data := docToBytes(struct {
Values []string
}{
Values: []string{"foo", "bar"},
})

receiver := testStruct{&values}

check := func(t *testing.T) {
t.Helper()
assert.Equal(t, [2]string{"foo", "bar"}, values)
}

return data, &receiver, check
},
},
{
name: "struct with interface array containing concrete values",
stub: func() ([]byte, interface{}, func(*testing.T)) {
type testStruct struct {
Values [3]interface{}
}
var str string
var i, j int

data := docToBytes(struct {
Values []interface{}
}{
Values: []interface{}{"foo", 42, nil},
})

receiver := testStruct{[3]interface{}{&str, &i, &j}}

check := func(t *testing.T) {
t.Helper()
assert.Equal(t, "foo", str)
assert.Equal(t, 42, i)
assert.Equal(t, 0, j)
assert.Equal(t, testStruct{[3]interface{}{&str, &i, nil}}, receiver)
}

return data, &receiver, check
},
},
}
for _, tc := range testCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

data, receiver, check := tc.stub()
got := reflect.ValueOf(receiver).Elem()
vr := bsonrw.NewValueReader(data)
reg := DefaultRegistry
decoder, err := reg.LookupDecoder(got.Type())
noerr(t, err)
err = decoder.DecodeValue(bsoncodec.DecodeContext{Registry: reg}, vr, got)
noerr(t, err)
check(t)
})
}
}

func TestDecoderv2(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 1db6753

Please sign in to comment.