Skip to content

Commit

Permalink
enforce recursion depth checking for unknown fields (#7210)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtattermusch committed Feb 14, 2020
1 parent d314101 commit 0e8f69e
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 8 deletions.
63 changes: 62 additions & 1 deletion csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs
Expand Up @@ -33,6 +33,7 @@
using System;
using System.IO;
using Google.Protobuf.TestProtos;
using Proto2 = Google.Protobuf.TestProtos.Proto2;
using NUnit.Framework;

namespace Google.Protobuf
Expand Down Expand Up @@ -337,6 +338,66 @@ public void MaliciousRecursion()
CodedInputStream input = CodedInputStream.CreateWithLimits(new MemoryStream(atRecursiveLimit.ToByteArray()), 1000000, CodedInputStream.DefaultRecursionLimit - 1);
Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(input));
}

private static byte[] MakeMaliciousRecursionUnknownFieldsPayload(int recursionDepth)
{
// generate recursively nested groups that will be parsed as unknown fields
int unknownFieldNumber = 14; // an unused field number
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
for (int i = 0; i < recursionDepth; i++)
{
output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.StartGroup));
}
for (int i = 0; i < recursionDepth; i++)
{
output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.EndGroup));
}
output.Flush();
return ms.ToArray();
}

[Test]
public void MaliciousRecursion_UnknownFields()
{
byte[] payloadAtRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit);
byte[] payloadBeyondRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit + 1);

Assert.DoesNotThrow(() => TestRecursiveMessage.Parser.ParseFrom(payloadAtRecursiveLimit));
Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payloadBeyondRecursiveLimit));
}

[Test]
public void ReadGroup_WrongEndGroupTag()
{
int groupFieldNumber = Proto2.TestAllTypes.OptionalGroupFieldNumber;

// write Proto2.TestAllTypes with "optional_group" set, but use wrong EndGroup closing tag
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
output.WriteTag(WireFormat.MakeTag(groupFieldNumber, WireFormat.WireType.StartGroup));
output.WriteGroup(new Proto2.TestAllTypes.Types.OptionalGroup { A = 12345 });
// end group with different field number
output.WriteTag(WireFormat.MakeTag(groupFieldNumber + 1, WireFormat.WireType.EndGroup));
output.Flush();
var payload = ms.ToArray();

Assert.Throws<InvalidProtocolBufferException>(() => Proto2.TestAllTypes.Parser.ParseFrom(payload));
}

[Test]
public void ReadGroup_UnknownFields_WrongEndGroupTag()
{
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
output.WriteTag(WireFormat.MakeTag(14, WireFormat.WireType.StartGroup));
// end group with different field number
output.WriteTag(WireFormat.MakeTag(15, WireFormat.WireType.EndGroup));
output.Flush();
var payload = ms.ToArray();

Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payload));
}

[Test]
public void SizeLimit()
Expand Down Expand Up @@ -735,4 +796,4 @@ public override int Read(byte[] buffer, int offset, int count)
}
}
}
}
}
31 changes: 29 additions & 2 deletions csharp/src/Google.Protobuf/CodedInputStream.cs
Expand Up @@ -307,10 +307,17 @@ internal void CheckReadEndOfStreamTag()
throw InvalidProtocolBufferException.MoreDataAvailable();
}
}
#endregion

internal void CheckLastTagWas(uint expectedTag)
{
if (lastTag != expectedTag) {
throw InvalidProtocolBufferException.InvalidEndTag();
}
}
#endregion

#region Reading of tags etc


/// <summary>
/// Peeks at the next field tag. This is like calling <see cref="ReadTag"/>, but the
/// tag is not consumed. (So a subsequent call to <see cref="ReadTag"/> will return the
Expand Down Expand Up @@ -636,7 +643,27 @@ public void ReadGroup(IMessage builder)
throw InvalidProtocolBufferException.RecursionLimitExceeded();
}
++recursionDepth;

uint tag = lastTag;
int fieldNumber = WireFormat.GetTagFieldNumber(tag);

builder.MergeFrom(this);
CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
--recursionDepth;
}

/// <summary>
/// Reads an embedded group unknown field from the stream.
/// </summary>
internal void ReadGroup(int fieldNumber, UnknownFieldSet set)
{
if (recursionDepth >= recursionLimit)
{
throw InvalidProtocolBufferException.RecursionLimitExceeded();
}
++recursionDepth;
set.MergeGroupFrom(this);
CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
--recursionDepth;
}

Expand Down
22 changes: 17 additions & 5 deletions csharp/src/Google.Protobuf/UnknownFieldSet.cs
Expand Up @@ -215,12 +215,8 @@ private bool MergeFieldFrom(CodedInputStream input)
}
case WireFormat.WireType.StartGroup:
{
uint endTag = WireFormat.MakeTag(number, WireFormat.WireType.EndGroup);
UnknownFieldSet set = new UnknownFieldSet();
while (input.ReadTag() != endTag)
{
set.MergeFieldFrom(input);
}
input.ReadGroup(number, set);
GetOrAddField(number).AddGroup(set);
return true;
}
Expand All @@ -233,6 +229,22 @@ private bool MergeFieldFrom(CodedInputStream input)
}
}

internal void MergeGroupFrom(CodedInputStream input)
{
while (true)
{
uint tag = input.ReadTag();
if (tag == 0)
{
break;
}
if (!MergeFieldFrom(input))
{
break;
}
}
}

/// <summary>
/// Create a new UnknownFieldSet if unknownFields is null.
/// Parse a single field from <paramref name="input"/> and merge it
Expand Down

0 comments on commit 0e8f69e

Please sign in to comment.