Skip to content

Commit

Permalink
all: implement depth limit for unmarshaling
Browse files Browse the repository at this point in the history
+ This change introduce a default and configurable depth limit for
  proto.Unmarshal. If a message is nested deeper than the limit,
  unmarshaling will fail. There are two ways to nest messages. Either by
  having fields which are message types itself or by using groups.
+ The default limit is 10,000 for now. This might change in the future
  to align it with other language implementation (C++ and Java use 100
  as limit).
+ If pure groups (groups that don't contain message fields) are nested
  deeper than the default limit the unmarshaling fails with:
  proto: cannot parse invalid wire-format data
+ Note: the configured limit does not apply to pure groups.
+ This change is introduced to improve security and robustness. Because
  unmarshaling is implemented using recursion it can lead to stack overflows
  for certain inputs. The introduced limit protects against this.
+ A secondary motivation for this limit is the alignment with other
  languages. Protocol buffers are a language interoperability mechanism
  and thus either all implementations should accept the input or all
  implementation should reject the input.

Change-Id: I14bdb44d06e4bd1aa90d6336c2cf6446003b2037
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/385854
Trust: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Trust: Damien Neil <dneil@google.com>
Reviewed-by: Nicolas Hillegeer <aktau@google.com>
Reviewed-by: Chressie Himpel <chressie@google.com>
  • Loading branch information
lfolger authored and neild committed Feb 17, 2022
1 parent e5db296 commit 3992ea8
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 7 deletions.
19 changes: 14 additions & 5 deletions encoding/protowire/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ import (
type Number int32

const (
MinValidNumber Number = 1
FirstReservedNumber Number = 19000
LastReservedNumber Number = 19999
MaxValidNumber Number = 1<<29 - 1
MinValidNumber Number = 1
FirstReservedNumber Number = 19000
LastReservedNumber Number = 19999
MaxValidNumber Number = 1<<29 - 1
DefaultRecursionLimit = 10000
)

// IsValid reports whether the field number is semantically valid.
Expand Down Expand Up @@ -55,6 +56,7 @@ const (
errCodeOverflow
errCodeReserved
errCodeEndGroup
errCodeRecursionDepth
)

var (
Expand Down Expand Up @@ -112,6 +114,10 @@ func ConsumeField(b []byte) (Number, Type, int) {
// When parsing a group, the length includes the end group marker and
// the end group is verified to match the starting field number.
func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
return consumeFieldValueD(num, typ, b, DefaultRecursionLimit)
}

func consumeFieldValueD(num Number, typ Type, b []byte, depth int) (n int) {
switch typ {
case VarintType:
_, n = ConsumeVarint(b)
Expand All @@ -126,6 +132,9 @@ func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
_, n = ConsumeBytes(b)
return n
case StartGroupType:
if depth < 0 {
return errCodeRecursionDepth
}
n0 := len(b)
for {
num2, typ2, n := ConsumeTag(b)
Expand All @@ -140,7 +149,7 @@ func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
return n0 - len(b)
}

n = ConsumeFieldValue(num2, typ2, b)
n = consumeFieldValueD(num2, typ2, b, depth-1)
if n < 0 {
return n // forward error code
}
Expand Down
2 changes: 1 addition & 1 deletion internal/fuzz/wirefuzz/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func Fuzz(data []byte) (score int) {
// Unmarshal, Validate, and CheckInitialized should agree about initialization.
checkInit := proto.CheckInitialized(m1) == nil
methods := m1.ProtoReflect().ProtoMethods()
in := piface.UnmarshalInput{Message: mt.New(), Resolver: protoregistry.GlobalTypes}
in := piface.UnmarshalInput{Message: mt.New(), Resolver: protoregistry.GlobalTypes, Depth: 10000}
if checkInit {
// If the message initialized, the both Unmarshal and Validate should
// report it as such. False negatives are tolerated, but have a
Expand Down
8 changes: 8 additions & 0 deletions internal/impl/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ import (
)

var errDecode = errors.New("cannot parse invalid wire-format data")
var errRecursionDepth = errors.New("exceeded maximum recursion depth")

type unmarshalOptions struct {
flags protoiface.UnmarshalInputFlags
resolver interface {
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
}
depth int
}

func (o unmarshalOptions) Options() proto.UnmarshalOptions {
Expand All @@ -44,6 +46,7 @@ func (o unmarshalOptions) IsDefault() bool {

var lazyUnmarshalOptions = unmarshalOptions{
resolver: preg.GlobalTypes,
depth: protowire.DefaultRecursionLimit,
}

type unmarshalOutput struct {
Expand All @@ -62,6 +65,7 @@ func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutp
out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
flags: in.Flags,
resolver: in.Resolver,
depth: in.Depth,
})
var flags piface.UnmarshalOutputFlags
if out.initialized {
Expand All @@ -82,6 +86,10 @@ var errUnknown = errors.New("unknown")

func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
mi.init()
opts.depth--
if opts.depth < 0 {
return out, errRecursionDepth
}
if flags.ProtoLegacy && mi.isMessageSet {
return unmarshalMessageSet(mi, b, p, opts)
}
Expand Down
17 changes: 16 additions & 1 deletion proto/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,25 @@ type UnmarshalOptions struct {
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
}

// RecursionLimit limits how deeply messages may be nested.
// If zero, a default limit is applied.
RecursionLimit int
}

// Unmarshal parses the wire-format message in b and places the result in m.
// The provided message must be mutable (e.g., a non-nil pointer to a message).
func Unmarshal(b []byte, m Message) error {
_, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
_, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
return err
}

// Unmarshal parses the wire-format message in b and places the result in m.
// The provided message must be mutable (e.g., a non-nil pointer to a message).
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
if o.RecursionLimit == 0 {
o.RecursionLimit = protowire.DefaultRecursionLimit
}
_, err := o.unmarshal(b, m.ProtoReflect())
return err
}
Expand All @@ -63,6 +70,9 @@ func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
// This method permits fine-grained control over the unmarshaler.
// Most users should use Unmarshal instead.
func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
if o.RecursionLimit == 0 {
o.RecursionLimit = protowire.DefaultRecursionLimit
}
return o.unmarshal(in.Buf, in.Message)
}

Expand All @@ -86,12 +96,17 @@ func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out proto
Message: m,
Buf: b,
Resolver: o.Resolver,
Depth: o.RecursionLimit,
}
if o.DiscardUnknown {
in.Flags |= protoiface.UnmarshalDiscardUnknown
}
out, err = methods.Unmarshal(in)
} else {
o.RecursionLimit--
if o.RecursionLimit < 0 {
return out, errors.New("exceeded max recursion depth")
}
err = o.unmarshalMessageSlow(b, m)
}
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions reflect/protoreflect/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type (
FindExtensionByName(field FullName) (ExtensionType, error)
FindExtensionByNumber(message FullName, field FieldNumber) (ExtensionType, error)
}
Depth int
}
unmarshalOutput = struct {
pragma.NoUnkeyedLiterals
Expand Down
1 change: 1 addition & 0 deletions runtime/protoiface/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ type UnmarshalInput = struct {
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
}
Depth int
}

// UnmarshalOutput is output from the Unmarshal method.
Expand Down

0 comments on commit 3992ea8

Please sign in to comment.