diff --git a/decode_test.go b/decode_test.go index fb00c364..a0d4edad 100644 --- a/decode_test.go +++ b/decode_test.go @@ -152,6 +152,7 @@ func Test_Decoder(t *testing.T) { B string `json:"str"` C bool D *T + E func() } content := []byte(` { @@ -162,7 +163,8 @@ func Test_Decoder(t *testing.T) { "aa": 2, "bb": "world", "cc": true - } + }, + "e" : null }`) assertErr(t, json.Unmarshal(content, &v)) assertEq(t, "struct.A", 123, v.A) @@ -171,6 +173,7 @@ func Test_Decoder(t *testing.T) { assertEq(t, "struct.D.AA", 2, v.D.AA) assertEq(t, "struct.D.BB", "world", v.D.BB) assertEq(t, "struct.D.CC", true, v.D.CC) + assertEq(t, "struct.E", true, v.E == nil) t.Run("struct.field null", func(t *testing.T) { var v struct { A string @@ -179,8 +182,9 @@ func Test_Decoder(t *testing.T) { D map[string]interface{} E [2]string F interface{} + G func() } - assertErr(t, json.Unmarshal([]byte(`{"a":null,"b":null,"c":null,"d":null,"e":null,"f":null}`), &v)) + assertErr(t, json.Unmarshal([]byte(`{"a":null,"b":null,"c":null,"d":null,"e":null,"f":null,"g":null}`), &v)) assertEq(t, "string", v.A, "") assertNeq(t, "[]string", v.B, nil) assertEq(t, "[]string", len(v.B), 0) @@ -191,6 +195,7 @@ func Test_Decoder(t *testing.T) { assertNeq(t, "array", v.E, nil) assertEq(t, "array", len(v.E), 2) assertEq(t, "interface{}", v.F, nil) + assertEq(t, "nilfunc", true, v.G == nil) }) }) t.Run("interface", func(t *testing.T) { @@ -239,6 +244,11 @@ func Test_Decoder(t *testing.T) { assertEq(t, "interface", nil, v) }) }) + t.Run("func", func(t *testing.T) { + var v func() + assertErr(t, json.Unmarshal([]byte(`null`), &v)) + assertEq(t, "nilfunc", true, v == nil) + }) } func TestIssue98(t *testing.T) { diff --git a/internal/decoder/compile.go b/internal/decoder/compile.go index bd566870..08dd044e 100644 --- a/internal/decoder/compile.go +++ b/internal/decoder/compile.go @@ -123,11 +123,15 @@ func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecode return compileFloat32(structName, fieldName) case reflect.Float64: return compileFloat64(structName, fieldName) + case reflect.Func: + return compileFunc(typ, structName, fieldName) } return nil, &errors.UnmarshalTypeError{ Value: "object", Type: runtime.RType2Type(typ), Offset: 0, + Struct: structName, + Field: fieldName, } } @@ -178,6 +182,8 @@ ERROR: Value: "object", Type: runtime.RType2Type(typ), Offset: 0, + Struct: structName, + Field: fieldName, } } @@ -312,6 +318,10 @@ func compileInterface(typ *runtime.Type, structName, fieldName string) (Decoder, return newInterfaceDecoder(typ, structName, fieldName), nil } +func compileFunc(typ *runtime.Type, strutName, fieldName string) (Decoder, error) { + return newFuncDecoder(typ, strutName, fieldName), nil +} + func removeConflictFields(fieldMap map[string]*structFieldSet, conflictedMap map[string]struct{}, dec *structDecoder, field reflect.StructField) { for k, v := range dec.fieldMap { if _, exists := conflictedMap[k]; exists { diff --git a/internal/decoder/func.go b/internal/decoder/func.go new file mode 100644 index 00000000..75afe75c --- /dev/null +++ b/internal/decoder/func.go @@ -0,0 +1,141 @@ +package decoder + +import ( + "bytes" + "unsafe" + + "github.com/goccy/go-json/internal/errors" + "github.com/goccy/go-json/internal/runtime" +) + +type funcDecoder struct { + typ *runtime.Type + structName string + fieldName string +} + +func newFuncDecoder(typ *runtime.Type, structName, fieldName string) *funcDecoder { + fnDecoder := &funcDecoder{typ, structName, fieldName} + return fnDecoder +} + +func (d *funcDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) error { + s.skipWhiteSpace() + start := s.cursor + if err := s.skipValue(depth); err != nil { + return err + } + src := s.buf[start:s.cursor] + if len(src) > 0 { + switch src[0] { + case '"': + return &errors.UnmarshalTypeError{ + Value: "string", + Type: runtime.RType2Type(d.typ), + Offset: s.totalOffset(), + } + case '[': + return &errors.UnmarshalTypeError{ + Value: "array", + Type: runtime.RType2Type(d.typ), + Offset: s.totalOffset(), + } + case '{': + return &errors.UnmarshalTypeError{ + Value: "object", + Type: runtime.RType2Type(d.typ), + Offset: s.totalOffset(), + } + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + return &errors.UnmarshalTypeError{ + Value: "number", + Type: runtime.RType2Type(d.typ), + Offset: s.totalOffset(), + } + case 'n': + if err := nullBytes(s); err != nil { + return err + } + *(*unsafe.Pointer)(p) = nil + return nil + case 't': + if err := trueBytes(s); err == nil { + return &errors.UnmarshalTypeError{ + Value: "boolean", + Type: runtime.RType2Type(d.typ), + Offset: s.totalOffset(), + } + } + case 'f': + if err := falseBytes(s); err == nil { + return &errors.UnmarshalTypeError{ + Value: "boolean", + Type: runtime.RType2Type(d.typ), + Offset: s.totalOffset(), + } + } + } + } + return errors.ErrNotAtBeginningOfValue(start) +} + +func (d *funcDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.Pointer) (int64, error) { + buf := ctx.Buf + cursor = skipWhiteSpace(buf, cursor) + start := cursor + end, err := skipValue(buf, cursor, depth) + if err != nil { + return 0, err + } + src := buf[start:end] + if len(src) > 0 { + switch src[0] { + case '"': + return 0, &errors.UnmarshalTypeError{ + Value: "string", + Type: runtime.RType2Type(d.typ), + Offset: start, + } + case '[': + return 0, &errors.UnmarshalTypeError{ + Value: "array", + Type: runtime.RType2Type(d.typ), + Offset: start, + } + case '{': + return 0, &errors.UnmarshalTypeError{ + Value: "object", + Type: runtime.RType2Type(d.typ), + Offset: start, + } + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + return 0, &errors.UnmarshalTypeError{ + Value: "number", + Type: runtime.RType2Type(d.typ), + Offset: start, + } + case 'n': + if bytes.Equal(src, nullbytes) { + *(*unsafe.Pointer)(p) = nil + return end, nil + } + case 't': + if err := validateTrue(buf, start); err == nil { + return 0, &errors.UnmarshalTypeError{ + Value: "boolean", + Type: runtime.RType2Type(d.typ), + Offset: start, + } + } + case 'f': + if err := validateFalse(buf, start); err == nil { + return 0, &errors.UnmarshalTypeError{ + Value: "boolean", + Type: runtime.RType2Type(d.typ), + Offset: start, + } + } + } + } + return 0, errors.ErrNotAtBeginningOfValue(start) +}