Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport: enforce recursion depth checking for unknown fields (v3.11.x) #7210

Merged
merged 1 commit into from Feb 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 62 additions & 1 deletion csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs
Expand Up @@ -33,6 +33,7 @@
using System;
using System.IO;
using Google.Protobuf.TestProtos;
using Proto2 = Google.Protobuf.TestProtos.Proto2;
using NUnit.Framework;

namespace Google.Protobuf
Expand Down Expand Up @@ -337,6 +338,66 @@ public void MaliciousRecursion()
CodedInputStream input = CodedInputStream.CreateWithLimits(new MemoryStream(atRecursiveLimit.ToByteArray()), 1000000, CodedInputStream.DefaultRecursionLimit - 1);
Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(input));
}

private static byte[] MakeMaliciousRecursionUnknownFieldsPayload(int recursionDepth)
{
// generate recursively nested groups that will be parsed as unknown fields
int unknownFieldNumber = 14; // an unused field number
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
for (int i = 0; i < recursionDepth; i++)
{
output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.StartGroup));
}
for (int i = 0; i < recursionDepth; i++)
{
output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.EndGroup));
}
output.Flush();
return ms.ToArray();
}

[Test]
public void MaliciousRecursion_UnknownFields()
{
byte[] payloadAtRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit);
byte[] payloadBeyondRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit + 1);

Assert.DoesNotThrow(() => TestRecursiveMessage.Parser.ParseFrom(payloadAtRecursiveLimit));
Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payloadBeyondRecursiveLimit));
}

[Test]
public void ReadGroup_WrongEndGroupTag()
{
int groupFieldNumber = Proto2.TestAllTypes.OptionalGroupFieldNumber;

// write Proto2.TestAllTypes with "optional_group" set, but use wrong EndGroup closing tag
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
output.WriteTag(WireFormat.MakeTag(groupFieldNumber, WireFormat.WireType.StartGroup));
output.WriteGroup(new Proto2.TestAllTypes.Types.OptionalGroup { A = 12345 });
// end group with different field number
output.WriteTag(WireFormat.MakeTag(groupFieldNumber + 1, WireFormat.WireType.EndGroup));
output.Flush();
var payload = ms.ToArray();

Assert.Throws<InvalidProtocolBufferException>(() => Proto2.TestAllTypes.Parser.ParseFrom(payload));
}

[Test]
public void ReadGroup_UnknownFields_WrongEndGroupTag()
{
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
output.WriteTag(WireFormat.MakeTag(14, WireFormat.WireType.StartGroup));
// end group with different field number
output.WriteTag(WireFormat.MakeTag(15, WireFormat.WireType.EndGroup));
output.Flush();
var payload = ms.ToArray();

Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payload));
}

[Test]
public void SizeLimit()
Expand Down Expand Up @@ -735,4 +796,4 @@ public override int Read(byte[] buffer, int offset, int count)
}
}
}
}
}
31 changes: 29 additions & 2 deletions csharp/src/Google.Protobuf/CodedInputStream.cs
Expand Up @@ -307,10 +307,17 @@ internal void CheckReadEndOfStreamTag()
throw InvalidProtocolBufferException.MoreDataAvailable();
}
}
#endregion

internal void CheckLastTagWas(uint expectedTag)
{
if (lastTag != expectedTag) {
throw InvalidProtocolBufferException.InvalidEndTag();
}
}
#endregion

#region Reading of tags etc

/// <summary>
/// Peeks at the next field tag. This is like calling <see cref="ReadTag"/>, but the
/// tag is not consumed. (So a subsequent call to <see cref="ReadTag"/> will return the
Expand Down Expand Up @@ -636,7 +643,27 @@ public void ReadGroup(IMessage builder)
throw InvalidProtocolBufferException.RecursionLimitExceeded();
}
++recursionDepth;

uint tag = lastTag;
int fieldNumber = WireFormat.GetTagFieldNumber(tag);

builder.MergeFrom(this);
CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
--recursionDepth;
}

/// <summary>
/// Reads an embedded group unknown field from the stream.
/// </summary>
internal void ReadGroup(int fieldNumber, UnknownFieldSet set)
{
if (recursionDepth >= recursionLimit)
{
throw InvalidProtocolBufferException.RecursionLimitExceeded();
}
++recursionDepth;
set.MergeGroupFrom(this);
CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
--recursionDepth;
}

Expand Down
22 changes: 17 additions & 5 deletions csharp/src/Google.Protobuf/UnknownFieldSet.cs
Expand Up @@ -215,12 +215,8 @@ private bool MergeFieldFrom(CodedInputStream input)
}
case WireFormat.WireType.StartGroup:
{
uint endTag = WireFormat.MakeTag(number, WireFormat.WireType.EndGroup);
UnknownFieldSet set = new UnknownFieldSet();
while (input.ReadTag() != endTag)
{
set.MergeFieldFrom(input);
}
input.ReadGroup(number, set);
GetOrAddField(number).AddGroup(set);
return true;
}
Expand All @@ -233,6 +229,22 @@ private bool MergeFieldFrom(CodedInputStream input)
}
}

internal void MergeGroupFrom(CodedInputStream input)
{
while (true)
{
uint tag = input.ReadTag();
if (tag == 0)
{
break;
}
if (!MergeFieldFrom(input))
{
break;
}
}
}

/// <summary>
/// Create a new UnknownFieldSet if unknownFields is null.
/// Parse a single field from <paramref name="input"/> and merge it
Expand Down