diff --git a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs index ba65b328e806..5f360ff46eaa 100644 --- a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs +++ b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs @@ -33,6 +33,7 @@ using System; using System.IO; using Google.Protobuf.TestProtos; +using Proto2 = Google.Protobuf.TestProtos.Proto2; using NUnit.Framework; namespace Google.Protobuf @@ -337,6 +338,66 @@ public void MaliciousRecursion() CodedInputStream input = CodedInputStream.CreateWithLimits(new MemoryStream(atRecursiveLimit.ToByteArray()), 1000000, CodedInputStream.DefaultRecursionLimit - 1); Assert.Throws(() => TestRecursiveMessage.Parser.ParseFrom(input)); } + + private static byte[] MakeMaliciousRecursionUnknownFieldsPayload(int recursionDepth) + { + // generate recursively nested groups that will be parsed as unknown fields + int unknownFieldNumber = 14; // an unused field number + MemoryStream ms = new MemoryStream(); + CodedOutputStream output = new CodedOutputStream(ms); + for (int i = 0; i < recursionDepth; i++) + { + output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.StartGroup)); + } + for (int i = 0; i < recursionDepth; i++) + { + output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.EndGroup)); + } + output.Flush(); + return ms.ToArray(); + } + + [Test] + public void MaliciousRecursion_UnknownFields() + { + byte[] payloadAtRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit); + byte[] payloadBeyondRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit + 1); + + Assert.DoesNotThrow(() => TestRecursiveMessage.Parser.ParseFrom(payloadAtRecursiveLimit)); + Assert.Throws(() => TestRecursiveMessage.Parser.ParseFrom(payloadBeyondRecursiveLimit)); + } + + [Test] + public void ReadGroup_WrongEndGroupTag() + { + int groupFieldNumber = Proto2.TestAllTypes.OptionalGroupFieldNumber; + + // write Proto2.TestAllTypes with "optional_group" set, but use wrong EndGroup closing tag + MemoryStream ms = new MemoryStream(); + CodedOutputStream output = new CodedOutputStream(ms); + output.WriteTag(WireFormat.MakeTag(groupFieldNumber, WireFormat.WireType.StartGroup)); + output.WriteGroup(new Proto2.TestAllTypes.Types.OptionalGroup { A = 12345 }); + // end group with different field number + output.WriteTag(WireFormat.MakeTag(groupFieldNumber + 1, WireFormat.WireType.EndGroup)); + output.Flush(); + var payload = ms.ToArray(); + + Assert.Throws(() => Proto2.TestAllTypes.Parser.ParseFrom(payload)); + } + + [Test] + public void ReadGroup_UnknownFields_WrongEndGroupTag() + { + MemoryStream ms = new MemoryStream(); + CodedOutputStream output = new CodedOutputStream(ms); + output.WriteTag(WireFormat.MakeTag(14, WireFormat.WireType.StartGroup)); + // end group with different field number + output.WriteTag(WireFormat.MakeTag(15, WireFormat.WireType.EndGroup)); + output.Flush(); + var payload = ms.ToArray(); + + Assert.Throws(() => TestRecursiveMessage.Parser.ParseFrom(payload)); + } [Test] public void SizeLimit() @@ -735,4 +796,4 @@ public override int Read(byte[] buffer, int offset, int count) } } } -} \ No newline at end of file +} diff --git a/csharp/src/Google.Protobuf/CodedInputStream.cs b/csharp/src/Google.Protobuf/CodedInputStream.cs index bea6bff34f2b..b9feda53cbc5 100644 --- a/csharp/src/Google.Protobuf/CodedInputStream.cs +++ b/csharp/src/Google.Protobuf/CodedInputStream.cs @@ -307,10 +307,17 @@ internal void CheckReadEndOfStreamTag() throw InvalidProtocolBufferException.MoreDataAvailable(); } } - #endregion + internal void CheckLastTagWas(uint expectedTag) + { + if (lastTag != expectedTag) { + throw InvalidProtocolBufferException.InvalidEndTag(); + } + } + #endregion + #region Reading of tags etc - + /// /// Peeks at the next field tag. This is like calling , but the /// tag is not consumed. (So a subsequent call to will return the @@ -636,7 +643,27 @@ public void ReadGroup(IMessage builder) throw InvalidProtocolBufferException.RecursionLimitExceeded(); } ++recursionDepth; + + uint tag = lastTag; + int fieldNumber = WireFormat.GetTagFieldNumber(tag); + builder.MergeFrom(this); + CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup)); + --recursionDepth; + } + + /// + /// Reads an embedded group unknown field from the stream. + /// + internal void ReadGroup(int fieldNumber, UnknownFieldSet set) + { + if (recursionDepth >= recursionLimit) + { + throw InvalidProtocolBufferException.RecursionLimitExceeded(); + } + ++recursionDepth; + set.MergeGroupFrom(this); + CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup)); --recursionDepth; } diff --git a/csharp/src/Google.Protobuf/UnknownFieldSet.cs b/csharp/src/Google.Protobuf/UnknownFieldSet.cs index d136cf1e6572..7a2b6a00d24a 100644 --- a/csharp/src/Google.Protobuf/UnknownFieldSet.cs +++ b/csharp/src/Google.Protobuf/UnknownFieldSet.cs @@ -215,12 +215,8 @@ private bool MergeFieldFrom(CodedInputStream input) } case WireFormat.WireType.StartGroup: { - uint endTag = WireFormat.MakeTag(number, WireFormat.WireType.EndGroup); UnknownFieldSet set = new UnknownFieldSet(); - while (input.ReadTag() != endTag) - { - set.MergeFieldFrom(input); - } + input.ReadGroup(number, set); GetOrAddField(number).AddGroup(set); return true; } @@ -233,6 +229,22 @@ private bool MergeFieldFrom(CodedInputStream input) } } + internal void MergeGroupFrom(CodedInputStream input) + { + while (true) + { + uint tag = input.ReadTag(); + if (tag == 0) + { + break; + } + if (!MergeFieldFrom(input)) + { + break; + } + } + } + /// /// Create a new UnknownFieldSet if unknownFields is null. /// Parse a single field from and merge it