Skip to content

Commit

Permalink
fix: panic on embedded struct with recursive
Browse files Browse the repository at this point in the history
fixes goccy#459
  • Loading branch information
NgoKimPhu committed Nov 3, 2023
1 parent df897ae commit dd524a4
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 26 deletions.
7 changes: 0 additions & 7 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,6 @@ func Test_Decoder(t *testing.T) {
assertEq(t, "interface{}", v.F, nil)
assertEq(t, "nilfunc", true, v.G == nil)
})
t.Run("struct.pointer must be nil", func(t *testing.T) {
var v struct {
A *int
}
json.Unmarshal([]byte(`{"a": "alpha"}`), &v)
assertEq(t, "struct.A", v.A, (*int)(nil))
})
})
t.Run("interface", func(t *testing.T) {
t.Run("number", func(t *testing.T) {
Expand Down
1 change: 0 additions & 1 deletion internal/decoder/ptr.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ func (d *ptrDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.P
}
c, err := d.dec.Decode(ctx, cursor, depth, newptr)
if err != nil {
*(*unsafe.Pointer)(p) = nil
return 0, err
}
cursor = c
Expand Down
1 change: 1 addition & 0 deletions internal/encoder/code.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ func (c *StructCode) ToAnonymousOpcode(ctx *compileContext) Opcodes {
prevField = firstField
codes = codes.Add(fieldCodes...)
}
ctx.structTypeToCodes[uintptr(unsafe.Pointer(c.typ))] = codes
return codes
}

Expand Down
42 changes: 24 additions & 18 deletions internal/encoder/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,13 @@ func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*Opcod

type Compiler struct {
structTypeToCode map[uintptr]*StructCode
anonymousStructTypeToCode map[uintptr]*StructCode
}

func newCompiler() *Compiler {
return &Compiler{
structTypeToCode: map[uintptr]*StructCode{},
anonymousStructTypeToCode: map[uintptr]*StructCode{},
}
}

Expand Down Expand Up @@ -169,11 +171,11 @@ func (c *Compiler) typeToCode(typ *runtime.Type) (Code, error) {
return c.sliceCode(typ)
case reflect.Map:
if isPtr {
return c.ptrCode(runtime.PtrTo(typ))
return c.ptrCode(runtime.PtrTo(typ), false)
}
return c.mapCode(typ)
case reflect.Struct:
return c.structCode(typ, isPtr)
return c.structCode(typ, isPtr, false)
case reflect.Int:
return c.intCode(typ, isPtr)
case reflect.Int8:
Expand Down Expand Up @@ -208,11 +210,11 @@ func (c *Compiler) typeToCode(typ *runtime.Type) (Code, error) {
if isPtr && typ.Implements(marshalTextType) {
typ = orgType
}
return c.typeToCodeWithPtr(typ, isPtr)
return c.typeToCodeWithPtr(typ, isPtr, false)
}
}

func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error) {
func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr, isAnonymous bool) (Code, error) {
switch {
case c.implementsMarshalJSON(typ):
return c.marshalJSONCode(typ)
Expand All @@ -221,7 +223,7 @@ func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error
}
switch typ.Kind() {
case reflect.Ptr:
return c.ptrCode(typ)
return c.ptrCode(typ, isAnonymous)
case reflect.Slice:
elem := typ.Elem()
if elem.Kind() == reflect.Uint8 {
Expand All @@ -236,7 +238,7 @@ func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error
case reflect.Map:
return c.mapCode(typ)
case reflect.Struct:
return c.structCode(typ, isPtr)
return c.structCode(typ, isPtr, isAnonymous)
case reflect.Interface:
return c.interfaceCode(typ, false)
case reflect.Int:
Expand Down Expand Up @@ -424,8 +426,8 @@ func (c *Compiler) marshalTextCode(typ *runtime.Type) (*MarshalTextCode, error)
}, nil
}

func (c *Compiler) ptrCode(typ *runtime.Type) (*PtrCode, error) {
code, err := c.typeToCodeWithPtr(typ.Elem(), true)
func (c *Compiler) ptrCode(typ *runtime.Type, isAnonymous bool) (*PtrCode, error) {
code, err := c.typeToCodeWithPtr(typ.Elem(), true, isAnonymous)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -485,12 +487,12 @@ func (c *Compiler) listElemCode(typ *runtime.Type) (Code, error) {
case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType):
return c.marshalTextCode(typ)
case typ.Kind() == reflect.Map:
return c.ptrCode(runtime.PtrTo(typ))
return c.ptrCode(runtime.PtrTo(typ), false)
default:
// isPtr was originally used to indicate whether the type of top level is pointer.
// However, since the slice/array element is a specification that can get the pointer address, explicitly set isPtr to true.
// See here for related issues: https://github.com/goccy/go-json/issues/370
code, err := c.typeToCodeWithPtr(typ, true)
code, err := c.typeToCodeWithPtr(typ, true, false)
if err != nil {
return nil, err
}
Expand All @@ -511,7 +513,7 @@ func (c *Compiler) mapKeyCode(typ *runtime.Type) (Code, error) {
}
switch typ.Kind() {
case reflect.Ptr:
return c.ptrCode(typ)
return c.ptrCode(typ, false)
case reflect.String:
return c.stringCode(typ, false)
case reflect.Int:
Expand Down Expand Up @@ -543,9 +545,9 @@ func (c *Compiler) mapKeyCode(typ *runtime.Type) (Code, error) {
func (c *Compiler) mapValueCode(typ *runtime.Type) (Code, error) {
switch typ.Kind() {
case reflect.Map:
return c.ptrCode(runtime.PtrTo(typ))
return c.ptrCode(runtime.PtrTo(typ), false)
default:
code, err := c.typeToCodeWithPtr(typ, false)
code, err := c.typeToCodeWithPtr(typ, false, false)
if err != nil {
return nil, err
}
Expand All @@ -559,16 +561,20 @@ func (c *Compiler) mapValueCode(typ *runtime.Type) (Code, error) {
}
}

func (c *Compiler) structCode(typ *runtime.Type, isPtr bool) (*StructCode, error) {
func (c *Compiler) structCode(typ *runtime.Type, isPtr, isAnonymous bool) (*StructCode, error) {
typeptr := uintptr(unsafe.Pointer(typ))
if code, exists := c.structTypeToCode[typeptr]; exists {
structTypeToCode := c.structTypeToCode
if isAnonymous {
structTypeToCode = c.anonymousStructTypeToCode
}
if code, exists := structTypeToCode[typeptr]; exists {
derefCode := *code
derefCode.isRecursive = true
return &derefCode, nil
}
indirect := runtime.IfaceIndir(typ)
code := &StructCode{typ: typ, isPtr: isPtr, isIndirect: indirect}
c.structTypeToCode[typeptr] = code
structTypeToCode[typeptr] = code

fieldNum := typ.NumField()
tags := c.typeToStructTags(typ)
Expand Down Expand Up @@ -613,7 +619,7 @@ func (c *Compiler) structCode(typ *runtime.Type, isPtr bool) (*StructCode, error
if !code.disableIndirectConversion && !indirect && isPtr {
code.enableIndirect()
}
delete(c.structTypeToCode, typeptr)
delete(structTypeToCode, typeptr)
return code, nil
}

Expand Down Expand Up @@ -680,7 +686,7 @@ func (c *Compiler) structFieldCode(structCode *StructCode, tag *runtime.StructTa
fieldCode.isAddrForMarshaler = true
fieldCode.isNilCheck = false
default:
code, err := c.typeToCodeWithPtr(fieldType, isPtr)
code, err := c.typeToCodeWithPtr(fieldType, isPtr, fieldCode.isAnonymous)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit dd524a4

Please sign in to comment.