diff --git a/internal/go-json/decoder/compile.go b/internal/go-json/decoder/compile.go index 0e461622d9..02644b3474 100644 --- a/internal/go-json/decoder/compile.go +++ b/internal/go-json/decoder/compile.go @@ -9,7 +9,6 @@ import ( "unicode" "unsafe" - "github.com/gofiber/fiber/v2/internal/go-json/errors" "github.com/gofiber/fiber/v2/internal/go-json/runtime" ) @@ -126,13 +125,7 @@ func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecode case reflect.Func: return compileFunc(typ, structName, fieldName) } - return nil, &errors.UnmarshalTypeError{ - Value: "object", - Type: runtime.RType2Type(typ), - Offset: 0, - Struct: structName, - Field: fieldName, - } + return newInvalidDecoder(typ, structName, fieldName), nil } func isStringTagSupportedType(typ *runtime.Type) bool { @@ -174,17 +167,9 @@ func compileMapKey(typ *runtime.Type, structName, fieldName string, structTypeTo case *ptrDecoder: dec = t.dec default: - goto ERROR + return newInvalidDecoder(typ, structName, fieldName), nil } } -ERROR: - return nil, &errors.UnmarshalTypeError{ - Value: "object", - Type: runtime.RType2Type(typ), - Offset: 0, - Struct: structName, - Field: fieldName, - } } func compilePtr(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) { @@ -322,64 +307,21 @@ func compileFunc(typ *runtime.Type, strutName, fieldName string) (Decoder, error return newFuncDecoder(typ, strutName, fieldName), nil } -func removeConflictFields(fieldMap map[string]*structFieldSet, conflictedMap map[string]struct{}, dec *structDecoder, field reflect.StructField) { - for k, v := range dec.fieldMap { - if _, exists := conflictedMap[k]; exists { - // already conflicted key - continue - } - set, exists := fieldMap[k] - if !exists { - fieldSet := &structFieldSet{ - dec: v.dec, - offset: field.Offset + v.offset, - isTaggedKey: v.isTaggedKey, - key: k, - keyLen: int64(len(k)), - } - fieldMap[k] = fieldSet - lower := strings.ToLower(k) - if _, exists := fieldMap[lower]; !exists { - fieldMap[lower] = fieldSet - } +func typeToStructTags(typ *runtime.Type) runtime.StructTags { + tags := runtime.StructTags{} + fieldNum := typ.NumField() + for i := 0; i < fieldNum; i++ { + field := typ.Field(i) + if runtime.IsIgnoredStructField(field) { continue } - if set.isTaggedKey { - if v.isTaggedKey { - // conflict tag key - delete(fieldMap, k) - delete(fieldMap, strings.ToLower(k)) - conflictedMap[k] = struct{}{} - conflictedMap[strings.ToLower(k)] = struct{}{} - } - } else { - if v.isTaggedKey { - fieldSet := &structFieldSet{ - dec: v.dec, - offset: field.Offset + v.offset, - isTaggedKey: v.isTaggedKey, - key: k, - keyLen: int64(len(k)), - } - fieldMap[k] = fieldSet - lower := strings.ToLower(k) - if _, exists := fieldMap[lower]; !exists { - fieldMap[lower] = fieldSet - } - } else { - // conflict tag key - delete(fieldMap, k) - delete(fieldMap, strings.ToLower(k)) - conflictedMap[k] = struct{}{} - conflictedMap[strings.ToLower(k)] = struct{}{} - } - } + tags = append(tags, runtime.StructTagFromField(field)) } + return tags } func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) { fieldNum := typ.NumField() - conflictedMap := map[string]struct{}{} fieldMap := map[string]*structFieldSet{} typeptr := uintptr(unsafe.Pointer(typ)) if dec, exists := structTypeToDecoder[typeptr]; exists { @@ -388,6 +330,8 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo structDec := newStructDecoder(structName, fieldName, fieldMap) structTypeToDecoder[typeptr] = structDec structName = typ.Name() + tags := typeToStructTags(typ) + allFields := []*structFieldSet{} for i := 0; i < fieldNum; i++ { field := typ.Field(i) if runtime.IsIgnoredStructField(field) { @@ -405,7 +349,19 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo // recursive definition continue } - removeConflictFields(fieldMap, conflictedMap, stDec, field) + for k, v := range stDec.fieldMap { + if tags.ExistsKey(k) { + continue + } + fieldSet := &structFieldSet{ + dec: v.dec, + offset: field.Offset + v.offset, + isTaggedKey: v.isTaggedKey, + key: k, + keyLen: int64(len(k)), + } + allFields = append(allFields, fieldSet) + } } else if pdec, ok := dec.(*ptrDecoder); ok { contentDec := pdec.contentDecoder() if pdec.typ == typ { @@ -421,58 +377,18 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo } if dec, ok := contentDec.(*structDecoder); ok { for k, v := range dec.fieldMap { - if _, exists := conflictedMap[k]; exists { - // already conflicted key - continue - } - set, exists := fieldMap[k] - if !exists { - fieldSet := &structFieldSet{ - dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec), - offset: field.Offset, - isTaggedKey: v.isTaggedKey, - key: k, - keyLen: int64(len(k)), - err: fieldSetErr, - } - fieldMap[k] = fieldSet - lower := strings.ToLower(k) - if _, exists := fieldMap[lower]; !exists { - fieldMap[lower] = fieldSet - } + if tags.ExistsKey(k) { continue } - if set.isTaggedKey { - if v.isTaggedKey { - // conflict tag key - delete(fieldMap, k) - delete(fieldMap, strings.ToLower(k)) - conflictedMap[k] = struct{}{} - conflictedMap[strings.ToLower(k)] = struct{}{} - } - } else { - if v.isTaggedKey { - fieldSet := &structFieldSet{ - dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec), - offset: field.Offset, - isTaggedKey: v.isTaggedKey, - key: k, - keyLen: int64(len(k)), - err: fieldSetErr, - } - fieldMap[k] = fieldSet - lower := strings.ToLower(k) - if _, exists := fieldMap[lower]; !exists { - fieldMap[lower] = fieldSet - } - } else { - // conflict tag key - delete(fieldMap, k) - delete(fieldMap, strings.ToLower(k)) - conflictedMap[k] = struct{}{} - conflictedMap[strings.ToLower(k)] = struct{}{} - } + fieldSet := &structFieldSet{ + dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec), + offset: field.Offset, + isTaggedKey: v.isTaggedKey, + key: k, + keyLen: int64(len(k)), + err: fieldSetErr, } + allFields = append(allFields, fieldSet) } } } @@ -493,11 +409,15 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo key: key, keyLen: int64(len(key)), } - fieldMap[key] = fieldSet - lower := strings.ToLower(key) - if _, exists := fieldMap[lower]; !exists { - fieldMap[lower] = fieldSet - } + allFields = append(allFields, fieldSet) + } + } + for _, set := range filterDuplicatedFields(allFields) { + fieldMap[set.key] = set + lower := strings.ToLower(set.key) + if _, exists := fieldMap[lower]; !exists { + // first win + fieldMap[lower] = set } } delete(structTypeToDecoder, typeptr) @@ -505,6 +425,42 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo return structDec, nil } +func filterDuplicatedFields(allFields []*structFieldSet) []*structFieldSet { + fieldMap := map[string][]*structFieldSet{} + for _, field := range allFields { + fieldMap[field.key] = append(fieldMap[field.key], field) + } + duplicatedFieldMap := map[string]struct{}{} + for k, sets := range fieldMap { + sets = filterFieldSets(sets) + if len(sets) != 1 { + duplicatedFieldMap[k] = struct{}{} + } + } + + filtered := make([]*structFieldSet, 0, len(allFields)) + for _, field := range allFields { + if _, exists := duplicatedFieldMap[field.key]; exists { + continue + } + filtered = append(filtered, field) + } + return filtered +} + +func filterFieldSets(sets []*structFieldSet) []*structFieldSet { + if len(sets) == 1 { + return sets + } + filtered := make([]*structFieldSet, 0, len(sets)) + for _, set := range sets { + if set.isTaggedKey { + filtered = append(filtered, set) + } + } + return filtered +} + func implementsUnmarshalJSONType(typ *runtime.Type) bool { return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType) } diff --git a/internal/go-json/decoder/invalid.go b/internal/go-json/decoder/invalid.go new file mode 100644 index 0000000000..94f754e62e --- /dev/null +++ b/internal/go-json/decoder/invalid.go @@ -0,0 +1,45 @@ +package decoder + +import ( + "reflect" + "unsafe" + + "github.com/gofiber/fiber/v2/internal/go-json/errors" + "github.com/gofiber/fiber/v2/internal/go-json/runtime" +) + +type invalidDecoder struct { + typ *runtime.Type + kind reflect.Kind + structName string + fieldName string +} + +func newInvalidDecoder(typ *runtime.Type, structName, fieldName string) *invalidDecoder { + return &invalidDecoder{ + typ: typ, + kind: typ.Kind(), + structName: structName, + fieldName: fieldName, + } +} + +func (d *invalidDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) error { + return &errors.UnmarshalTypeError{ + Value: "object", + Type: runtime.RType2Type(d.typ), + Offset: s.totalOffset(), + Struct: d.structName, + Field: d.fieldName, + } +} + +func (d *invalidDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.Pointer) (int64, error) { + return 0, &errors.UnmarshalTypeError{ + Value: "object", + Type: runtime.RType2Type(d.typ), + Offset: cursor, + Struct: d.structName, + Field: d.fieldName, + } +}