Skip to content

Commit

Permalink
Merge pull request #946 from AArnott/fix924
Browse files Browse the repository at this point in the history
Fix MessagePackStreamReader reading when string or binary headers are incomplete
  • Loading branch information
AArnott committed Jun 12, 2020
2 parents bac7020 + 7c07da2 commit c2ee6bc
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ public MessagePackReader Clone(in ReadOnlySequence<byte> readOnlySequence) => ne
/// <remarks>
/// The entire primitive is skipped, including content of maps or arrays, or any other type with payloads.
/// To get the raw MessagePack sequence that was skipped, use <see cref="ReadRaw()"/> instead.
/// WARNING: when false is returned, the position of the reader is undefined.
/// </remarks>
internal bool TrySkip()
{
Expand Down Expand Up @@ -187,11 +188,11 @@ internal bool TrySkip()
case MessagePackCode.Str8:
case MessagePackCode.Str16:
case MessagePackCode.Str32:
return this.reader.TryAdvance(this.GetStringLengthInBytes());
return this.TryGetStringLengthInBytes(out int length) && this.reader.TryAdvance(length);
case MessagePackCode.Bin8:
case MessagePackCode.Bin16:
case MessagePackCode.Bin32:
return this.reader.TryAdvance(this.GetBytesLength());
return this.TryGetBytesLength(out length) && this.reader.TryAdvance(length);
case MessagePackCode.FixExt1:
case MessagePackCode.FixExt2:
case MessagePackCode.FixExt4:
Expand Down Expand Up @@ -220,7 +221,7 @@ internal bool TrySkip()

if (code >= MessagePackCode.MinFixStr && code <= MessagePackCode.MaxFixStr)
{
return this.reader.TryAdvance(this.GetStringLengthInBytes());
return this.TryGetStringLengthInBytes(out length) && this.reader.TryAdvance(length);
}

// We don't actually expect to ever hit this point, since every code is supported.
Expand Down Expand Up @@ -956,77 +957,135 @@ private static void ThrowInsufficientBufferUnless(bool condition)

private int GetBytesLength()
{
ThrowInsufficientBufferUnless(this.reader.TryRead(out byte code));
ThrowInsufficientBufferUnless(this.TryGetBytesLength(out int length));
return length;
}

private bool TryGetBytesLength(out int length)
{
if (!this.reader.TryRead(out byte code))
{
length = 0;
return false;
}

// In OldSpec mode, Bin didn't exist, so Str was used. Str8 didn't exist either.
int length;
switch (code)
{
case MessagePackCode.Bin8:
ThrowInsufficientBufferUnless(this.reader.TryRead(out byte byteLength));
length = byteLength;
if (this.reader.TryRead(out byte byteLength))
{
length = byteLength;
return true;
}

break;
case MessagePackCode.Bin16:
case MessagePackCode.Str16: // OldSpec compatibility
ThrowInsufficientBufferUnless(this.reader.TryReadBigEndian(out short shortLength));
length = unchecked((ushort)shortLength);
if (this.reader.TryReadBigEndian(out short shortLength))
{
length = unchecked((ushort)shortLength);
return true;
}

break;
case MessagePackCode.Bin32:
case MessagePackCode.Str32: // OldSpec compatibility
ThrowInsufficientBufferUnless(this.reader.TryReadBigEndian(out length));
if (this.reader.TryReadBigEndian(out length))
{
return true;
}

break;
default:
// OldSpec compatibility
if (code >= MessagePackCode.MinFixStr && code <= MessagePackCode.MaxFixStr)
{
length = code & 0x1F;
break;
return true;
}

throw ThrowInvalidCode(code);
}

return length;
length = 0;
return false;
}

/// <summary>
/// Gets the length of the next string.
/// </summary>
/// <returns>The length of the next string.</returns>
/// <param name="length">Receives the length of the next string, if there were enough bytes to read it.</param>
/// <returns><c>true</c> if there were enough bytes to read the length of the next string; <c>false</c> otherwise.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private int GetStringLengthInBytes()
private bool TryGetStringLengthInBytes(out int length)
{
ThrowInsufficientBufferUnless(this.reader.TryRead(out byte code));
if (!this.reader.TryRead(out byte code))
{
length = 0;
return false;
}

if (code >= MessagePackCode.MinFixStr && code <= MessagePackCode.MaxFixStr)
{
return code & 0x1F;
length = code & 0x1F;
return true;
}

return this.GetStringLengthInBytesSlow(code);
return this.TryGetStringLengthInBytesSlow(code, out length);
}

private int GetStringLengthInBytesSlow(byte code)
/// <summary>
/// Gets the length of the next string.
/// </summary>
/// <returns>The length of the next string.</returns>
private int GetStringLengthInBytes()
{
ThrowInsufficientBufferUnless(this.TryGetStringLengthInBytes(out int length));
return length;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private bool TryGetStringLengthInBytesSlow(byte code, out int length)
{
switch (code)
{
case MessagePackCode.Str8:
ThrowInsufficientBufferUnless(this.reader.TryRead(out byte byteValue));
return byteValue;
if (this.reader.TryRead(out byte byteValue))
{
length = byteValue;
return true;
}

break;
case MessagePackCode.Str16:
ThrowInsufficientBufferUnless(this.reader.TryReadBigEndian(out short shortValue));
return unchecked((ushort)shortValue);
if (this.reader.TryReadBigEndian(out short shortValue))
{
length = unchecked((ushort)shortValue);
return true;
}

break;
case MessagePackCode.Str32:
ThrowInsufficientBufferUnless(this.reader.TryReadBigEndian(out int intValue));
return intValue;
if (this.reader.TryReadBigEndian(out int intValue))
{
length = intValue;
return true;
}

break;
default:
if (code >= MessagePackCode.MinFixStr && code <= MessagePackCode.MaxFixStr)
{
return code & 0x1F;
length = code & 0x1F;
return true;
}

throw ThrowInvalidCode(code);
}

length = 0;
return false;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ public MessagePackStreamReaderTests()
positions.Add(sequence.AsReadOnlySequence.End);

// Second message is more interesting.
writer.WriteArrayHeader(2);
writer.WriteArrayHeader(4);
writer.Write("Hi");
writer.Write("There");
writer.Write("There + " + new string('3', 300)); // a long enough string that a multi-byte header is required.
writer.Write(new byte[300]); // a long enough byte array that a multi-byte header is required.
writer.WriteExtensionFormat(new ExtensionResult(1, new byte[300]));
writer.Flush();
positions.Add(sequence.AsReadOnlySequence.End);

Expand Down

0 comments on commit c2ee6bc

Please sign in to comment.