Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep original reference of slice element #229

Merged
merged 1 commit into from May 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 35 additions & 14 deletions decode_slice.go
Expand Up @@ -49,9 +49,20 @@ func newSliceDecoder(dec decoder, elemType *rtype, size uintptr, structName, fie
}
}

func (d *sliceDecoder) newSlice() *sliceHeader {
func (d *sliceDecoder) newSlice(src *sliceHeader) *sliceHeader {
slice := d.arrayPool.Get().(*sliceHeader)
slice.len = 0
if src.len > 0 {
// copy original elem
if slice.cap < src.cap {
data := newArray(d.elemType, src.cap)
slice = &sliceHeader{data: data, len: src.len, cap: src.cap}
} else {
slice.len = src.len
}
copySlice(d.elemType, *slice, *src)
} else {
slice.len = 0
}
return slice
}

Expand Down Expand Up @@ -109,7 +120,8 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er
return nil
}
idx := 0
slice := d.newSlice()
slice := d.newSlice((*sliceHeader)(p))
srcLen := slice.len
capacity := slice.cap
data := slice.data
for {
Expand All @@ -121,12 +133,17 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er
copySlice(d.elemType, dst, src)
}
ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size)
if d.isElemPointerType {
**(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer
} else {
// assign new element to the slice
typedmemmove(d.elemType, ep, unsafe_New(d.elemType))

// if srcLen is greater than idx, keep the original reference
if srcLen <= idx {
if d.isElemPointerType {
**(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer
} else {
// assign new element to the slice
typedmemmove(d.elemType, ep, unsafe_New(d.elemType))
}
}

if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil {
return err
}
Expand Down Expand Up @@ -212,7 +229,8 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer)
return cursor, nil
}
idx := 0
slice := d.newSlice()
slice := d.newSlice((*sliceHeader)(p))
srcLen := slice.len
capacity := slice.cap
data := slice.data
for {
Expand All @@ -224,11 +242,14 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer)
copySlice(d.elemType, dst, src)
}
ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size)
if d.isElemPointerType {
**(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer
} else {
// assign new element to the slice
typedmemmove(d.elemType, ep, unsafe_New(d.elemType))
// if srcLen is greater than idx, keep the original reference
if srcLen <= idx {
if d.isElemPointerType {
**(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer
} else {
// assign new element to the slice
typedmemmove(d.elemType, ep, unsafe_New(d.elemType))
}
}
c, err := d.valueDecoder.decode(buf, cursor, depth, ep)
if err != nil {
Expand Down
35 changes: 30 additions & 5 deletions decode_test.go
Expand Up @@ -3052,7 +3052,7 @@ func TestMultipleDecodeWithRawMessage(t *testing.T) {
type intUnmarshaler int

func (u *intUnmarshaler) UnmarshalJSON(b []byte) error {
if *u != 0 {
if *u != 0 && *u != 10 {
return fmt.Errorf("failed to decode of slice with int unmarshaler")
}
*u = 10
Expand All @@ -3062,7 +3062,7 @@ func (u *intUnmarshaler) UnmarshalJSON(b []byte) error {
type arrayUnmarshaler [5]int

func (u *arrayUnmarshaler) UnmarshalJSON(b []byte) error {
if (*u)[0] != 0 {
if (*u)[0] != 0 && (*u)[0] != 10 {
return fmt.Errorf("failed to decode of slice with array unmarshaler")
}
(*u)[0] = 10
Expand All @@ -3072,22 +3072,24 @@ func (u *arrayUnmarshaler) UnmarshalJSON(b []byte) error {
type mapUnmarshaler map[string]int

func (u *mapUnmarshaler) UnmarshalJSON(b []byte) error {
if len(*u) != 0 {
if len(*u) != 0 && len(*u) != 1 {
return fmt.Errorf("failed to decode of slice with map unmarshaler")
}
*u = map[string]int{"a": 10}
return nil
}

type structUnmarshaler struct {
A int
A int
notFirst bool
}

func (u *structUnmarshaler) UnmarshalJSON(b []byte) error {
if u.A != 0 {
if !u.notFirst && u.A != 0 {
return fmt.Errorf("failed to decode of slice with struct unmarshaler")
}
u.A = 10
u.notFirst = true
return nil
}

Expand Down Expand Up @@ -3199,6 +3201,29 @@ func TestSliceElemUnmarshaler(t *testing.T) {
})
}

type keepRefTest struct {
A int
B string
}

func (t *keepRefTest) UnmarshalJSON(data []byte) error {
v := []interface{}{&t.A, &t.B}
return json.Unmarshal(data, &v)
}

func TestKeepReferenceSlice(t *testing.T) {
var v keepRefTest
if err := json.Unmarshal([]byte(`[54,"hello"]`), &v); err != nil {
t.Fatal(err)
}
if v.A != 54 {
t.Fatal("failed to keep reference for slice")
}
if v.B != "hello" {
t.Fatal("failed to keep reference for slice")
}
}

func TestInvalidTopLevelValue(t *testing.T) {
t.Run("invalid end of buffer", func(t *testing.T) {
var v struct{}
Expand Down