From bfcd6476a38e41247d6bb43dc8f00b23ec9fffc2 Mon Sep 17 00:00:00 2001 From: Josh Humphries Date: Thu, 21 Dec 2023 13:55:43 -0500 Subject: [PATCH] protojson: configurable recursion limit when unmarshalling Fixes golang/protobuf#1583 and golang/protobuf#1584 Limits the level of recursion when parsing JSON to avoid fatal stack overflow errors if input uses pathologically deep nesting. This is already a feature of the binary format, and this adds that feature to the JSON format. This also re-implements how JSON values are discarded to be more efficient (and not use recursion). Change-Id: I4026b739abe0335387209a43645f65e4b6e43409 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/552255 LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase Auto-Submit: Lasse Folger Reviewed-by: Lasse Folger --- encoding/protojson/decode.go | 12 ++++ encoding/protojson/decode_test.go | 81 ++++++++++++++++++++++++++ encoding/protojson/well_known_types.go | 55 +++++------------ 3 files changed, 109 insertions(+), 39 deletions(-) diff --git a/encoding/protojson/decode.go b/encoding/protojson/decode.go index 503390627..f47902371 100644 --- a/encoding/protojson/decode.go +++ b/encoding/protojson/decode.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" + "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/internal/encoding/json" "google.golang.org/protobuf/internal/encoding/messageset" "google.golang.org/protobuf/internal/errors" @@ -47,6 +48,10 @@ type UnmarshalOptions struct { protoregistry.MessageTypeResolver protoregistry.ExtensionTypeResolver } + + // RecursionLimit limits how deeply messages may be nested. + // If zero, a default limit is applied. + RecursionLimit int } // Unmarshal reads the given []byte and populates the given [proto.Message] @@ -67,6 +72,9 @@ func (o UnmarshalOptions) unmarshal(b []byte, m proto.Message) error { if o.Resolver == nil { o.Resolver = protoregistry.GlobalTypes } + if o.RecursionLimit == 0 { + o.RecursionLimit = protowire.DefaultRecursionLimit + } dec := decoder{json.NewDecoder(b), o} if err := dec.unmarshalMessage(m.ProtoReflect(), false); err != nil { @@ -114,6 +122,10 @@ func (d decoder) syntaxError(pos int, f string, x ...interface{}) error { // unmarshalMessage unmarshals a message into the given protoreflect.Message. func (d decoder) unmarshalMessage(m protoreflect.Message, skipTypeURL bool) error { + d.opts.RecursionLimit-- + if d.opts.RecursionLimit < 0 { + return errors.New("exceeded max recursion depth") + } if unmarshal := wellKnownTypeUnmarshaler(m.Descriptor().FullName()); unmarshal != nil { return unmarshal(d, m) } diff --git a/encoding/protojson/decode_test.go b/encoding/protojson/decode_test.go index 417da1dea..4355b7546 100644 --- a/encoding/protojson/decode_test.go +++ b/encoding/protojson/decode_test.go @@ -2489,6 +2489,87 @@ func TestUnmarshal(t *testing.T) { inputText: `{"weak_message1":{"a":1}, "weak_message2":{"a":1}}`, wantErr: `unknown field "weak_message2"`, // weak_message2 is unknown since the package containing it is not imported skip: !flags.ProtoLegacy, + }, { + desc: "just at recursion limit: nested messages", + inputMessage: &testpb.TestAllTypes{}, + inputText: `{"optionalNestedMessage":{"corecursive":{"optionalNestedMessage":{"corecursive":{}}}}}`, + umo: protojson.UnmarshalOptions{RecursionLimit: 5}, + }, { + desc: "exceed recursion limit: nested messages", + inputMessage: &testpb.TestAllTypes{}, + inputText: `{"optionalNestedMessage":{"corecursive":{"optionalNestedMessage":{"corecursive":{"optionalNestedMessage":{}}}}}}`, + umo: protojson.UnmarshalOptions{RecursionLimit: 5}, + wantErr: "exceeded max recursion depth", + }, { + + desc: "just at recursion limit: maps", + inputMessage: &testpb.TestAllTypes{}, + inputText: `{"mapStringNestedMessage":{"key1":{"corecursive":{"mapStringNestedMessage":{}}}}}`, + umo: protojson.UnmarshalOptions{RecursionLimit: 3}, + }, { + desc: "exceed recursion limit: maps", + inputMessage: &testpb.TestAllTypes{}, + inputText: `{"mapStringNestedMessage":{"key1":{"corecursive":{"mapStringNestedMessage":{}}}}}`, + umo: protojson.UnmarshalOptions{RecursionLimit: 2}, + wantErr: "exceeded max recursion depth", + }, { + desc: "just at recursion limit: arrays", + inputMessage: &testpb.TestAllTypes{}, + inputText: `{"repeatedNestedMessage":[{"corecursive":{"repeatedInt32":[1,2,3]}}]}`, + umo: protojson.UnmarshalOptions{RecursionLimit: 3}, + }, { + desc: "exceed recursion limit: arrays", + inputMessage: &testpb.TestAllTypes{}, + inputText: `{"repeatedNestedMessage":[{"corecursive":{"repeatedNestedMessage":[{}]}}]}`, + umo: protojson.UnmarshalOptions{RecursionLimit: 3}, + wantErr: "exceeded max recursion depth", + }, { + desc: "just at recursion limit: value", + inputMessage: &structpb.Value{}, + inputText: `{"a":{"b":{"c":{"d":{}}}}}`, + umo: protojson.UnmarshalOptions{RecursionLimit: 5}, + }, { + desc: "exceed recursion limit: value", + inputMessage: &structpb.Value{}, + inputText: `{"a":{"b":{"c":{"d":{"e":[]}}}}}`, + umo: protojson.UnmarshalOptions{RecursionLimit: 5}, + wantErr: "exceeded max recursion depth", + }, { + desc: "just at recursion limit: list value", + inputMessage: &structpb.ListValue{}, + inputText: `[[[[[1, 2, 3, 4]]]]]`, + // Note: the JSON appears to have recursion of only 5. But it's actually 6 because the + // first leaf value (1) is actually a message (google.protobuf.Value), even though the + // JSON doesn't use an open brace. + umo: protojson.UnmarshalOptions{RecursionLimit: 6}, + }, { + desc: "exceed recursion limit: list value", + inputMessage: &structpb.ListValue{}, + inputText: `[[[[[1, 2, 3, 4, ["a", "b"]]]]]]`, + umo: protojson.UnmarshalOptions{RecursionLimit: 6}, + wantErr: "exceeded max recursion depth", + }, { + desc: "just at recursion limit: struct value", + inputMessage: &structpb.Struct{}, + inputText: `{"a":{"b":{"c":{"d":{}}}}}`, + umo: protojson.UnmarshalOptions{RecursionLimit: 5}, + }, { + desc: "exceed recursion limit: struct value", + inputMessage: &structpb.Struct{}, + inputText: `{"a":{"b":{"c":{"d":{"e":{}]}}}}}`, + umo: protojson.UnmarshalOptions{RecursionLimit: 5}, + wantErr: "exceeded max recursion depth", + }, { + desc: "just at recursion limit: skip unknown", + inputMessage: &testpb.TestAllTypes{}, + inputText: `{"foo":{"bar":[{"baz":{}}]}}`, + umo: protojson.UnmarshalOptions{RecursionLimit: 5, DiscardUnknown: true}, + }, { + desc: "exceed recursion limit: skip unknown", + inputMessage: &testpb.TestAllTypes{}, + inputText: `{"foo":{"bar":[{"baz":[{}]]}}`, + umo: protojson.UnmarshalOptions{RecursionLimit: 5, DiscardUnknown: true}, + wantErr: "exceeded max recursion depth", }} for _, tt := range tests { diff --git a/encoding/protojson/well_known_types.go b/encoding/protojson/well_known_types.go index 6c37d4174..25329b769 100644 --- a/encoding/protojson/well_known_types.go +++ b/encoding/protojson/well_known_types.go @@ -176,7 +176,7 @@ func (d decoder) unmarshalAny(m protoreflect.Message) error { // Use another decoder to parse the unread bytes for @type field. This // avoids advancing a read from current decoder because the current JSON // object may contain the fields of the embedded type. - dec := decoder{d.Clone(), UnmarshalOptions{}} + dec := decoder{d.Clone(), UnmarshalOptions{RecursionLimit: d.opts.RecursionLimit}} tok, err := findTypeURL(dec) switch err { case errEmptyObject: @@ -308,48 +308,25 @@ Loop: // array) in order to advance the read to the next JSON value. It relies on // the decoder returning an error if the types are not in valid sequence. func (d decoder) skipJSONValue() error { - tok, err := d.Read() - if err != nil { - return err - } - // Only need to continue reading for objects and arrays. - switch tok.Kind() { - case json.ObjectOpen: - for { - tok, err := d.Read() - if err != nil { - return err - } - switch tok.Kind() { - case json.ObjectClose: - return nil - case json.Name: - // Skip object field value. - if err := d.skipJSONValue(); err != nil { - return err - } - } + var open int + for { + tok, err := d.Read() + if err != nil { + return err } - - case json.ArrayOpen: - for { - tok, err := d.Peek() - if err != nil { - return err - } - switch tok.Kind() { - case json.ArrayClose: - d.Read() - return nil - default: - // Skip array item. - if err := d.skipJSONValue(); err != nil { - return err - } + switch tok.Kind() { + case json.ObjectClose, json.ArrayClose: + open-- + case json.ObjectOpen, json.ArrayOpen: + open++ + if open > d.opts.RecursionLimit { + return errors.New("exceeded max recursion depth") } } + if open == 0 { + return nil + } } - return nil } // unmarshalAnyValue unmarshals the given custom-type message from the JSON