Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Optimize MapField serialization by removing MessageAdapter #8143

Merged
merged 4 commits into from Jan 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions csharp/src/Google.Protobuf.Test/Collections/MapFieldTest.cs
Expand Up @@ -611,6 +611,32 @@ public void AddEntriesFrom_CodedInputStream()
Assert.IsTrue(input.IsAtEnd);
}

// Verifies that a map entry whose wire data contains only the value field is still
// added to the map, under the key codec's default value (the empty string here).
[Test]
public void AddEntriesFrom_CodedInputStream_MissingKey()
{
    // The map has string keys and string values; the encoded entry deliberately omits the key field.
    var keyTag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);
    var valueTag = WireFormat.MakeTag(2, WireFormat.WireType.LengthDelimited);

    var memoryStream = new MemoryStream();
    var output = new CodedOutputStream(memoryStream);
    // Entry length: 1-byte value tag + 1-byte string length prefix + 9 bytes for "the_value".
    output.WriteLength(11);
    output.WriteTag(valueTag);
    output.WriteString("the_value");
    output.Flush();

    var field = new MapField<string, string>();
    var mapCodec = new MapField<string, string>.Codec(FieldCodec.ForString(keyTag, ""), FieldCodec.ForString(valueTag, ""), 10);
    var input = new CodedInputStream(memoryStream.ToArray());

    field.AddEntriesFrom(input, mapCodec);

    // The single parsed entry must use the default key and the value read from the stream,
    // and the parser must have consumed the whole input.
    CollectionAssert.AreEquivalent(new[] { "" }, field.Keys);
    CollectionAssert.AreEquivalent(new[] { "the_value" }, field.Values);
    Assert.IsTrue(input.IsAtEnd);
}

#if !NET35
[Test]
public void IDictionaryKeys_Equals_IReadOnlyDictionaryKeys()
Expand Down
123 changes: 22 additions & 101 deletions csharp/src/Google.Protobuf/Collections/MapField.cs
Expand Up @@ -448,12 +448,10 @@ public void AddEntriesFrom(CodedInputStream input, Codec codec)
[SecuritySafeCritical]
public void AddEntriesFrom(ref ParseContext ctx, Codec codec)
{
var adapter = new Codec.MessageAdapter(codec);
do
{
adapter.Reset();
ctx.ReadMessage(adapter);
this[adapter.Key] = adapter.Value;
KeyValuePair<TKey, TValue> entry = ParsingPrimitivesMessages.ReadMapEntry(ref ctx, codec);
this[entry.Key] = entry.Value;
} while (ParsingPrimitives.MaybeConsumeTag(ref ctx.buffer, ref ctx.state, codec.MapTag));
}

Expand Down Expand Up @@ -485,13 +483,13 @@ public void WriteTo(CodedOutputStream output, Codec codec)
[SecuritySafeCritical]
public void WriteTo(ref WriteContext ctx, Codec codec)
{
var message = new Codec.MessageAdapter(codec);
foreach (var entry in list)
{
message.Key = entry.Key;
message.Value = entry.Value;
ctx.WriteTag(codec.MapTag);
ctx.WriteMessage(message);

WritingPrimitives.WriteLength(ref ctx.buffer, ref ctx.state, CalculateEntrySize(codec, entry));
codec.KeyCodec.WriteTagAndValue(ref ctx, entry.Key);
codec.ValueCodec.WriteTagAndValue(ref ctx, entry.Value);
}
}

Expand All @@ -506,18 +504,22 @@ public int CalculateSize(Codec codec)
{
return 0;
}
var message = new Codec.MessageAdapter(codec);
int size = 0;
foreach (var entry in list)
{
message.Key = entry.Key;
message.Value = entry.Value;
int entrySize = CalculateEntrySize(codec, entry);

size += CodedOutputStream.ComputeRawVarint32Size(codec.MapTag);
size += CodedOutputStream.ComputeMessageSize(message);
size += CodedOutputStream.ComputeLengthSize(entrySize) + entrySize;
}
return size;
}

/// <summary>
/// Computes the size of a single map entry's payload: the key written with its tag
/// plus the value written with its tag. Only these two terms are summed here.
/// </summary>
/// <param name="codec">The map codec supplying the key and value codecs.</param>
/// <param name="entry">The key/value pair to measure.</param>
/// <returns>The combined size reported by the key and value codecs.</returns>
private static int CalculateEntrySize(Codec codec, KeyValuePair<TKey, TValue> entry)
{
    int keySizeWithTag = codec.KeyCodec.CalculateSizeWithTag(entry.Key);
    int valueSizeWithTag = codec.ValueCodec.CalculateSizeWithTag(entry.Value);
    return keySizeWithTag + valueSizeWithTag;
}

/// <summary>
/// Returns a string representation of this repeated field, in the same
/// way as it would be represented by the default JSON formatter.
Expand Down Expand Up @@ -655,100 +657,19 @@ public Codec(FieldCodec<TKey> keyCodec, FieldCodec<TValue> valueCodec, uint mapT
}

/// <summary>
/// The tag used in the enclosing message to indicate map entries.
/// The key codec.
/// </summary>
internal uint MapTag { get { return mapTag; } }
internal FieldCodec<TKey> KeyCodec => keyCodec;

/// <summary>
/// A mutable message class, used for parsing and serializing. This
/// delegates the work to a codec, but implements the <see cref="IMessage"/> interface
/// for interop with <see cref="CodedInputStream"/> and <see cref="CodedOutputStream"/>.
/// This is nested inside Codec as it's tightly coupled to the associated codec,
/// and it's simpler if it has direct access to all its fields.
/// The value codec.
/// </summary>
internal class MessageAdapter : IMessage, IBufferMessage
{
private static readonly byte[] ZeroLengthMessageStreamData = new byte[] { 0 };

private readonly Codec codec;
internal TKey Key { get; set; }
internal TValue Value { get; set; }

internal MessageAdapter(Codec codec)
{
this.codec = codec;
}

internal void Reset()
{
Key = codec.keyCodec.DefaultValue;
Value = codec.valueCodec.DefaultValue;
}

public void MergeFrom(CodedInputStream input)
{
// Message adapter is an internal class and we know that all the parsing will happen via InternalMergeFrom.
throw new NotImplementedException();
}
internal FieldCodec<TValue> ValueCodec => valueCodec;

[SecuritySafeCritical]
public void InternalMergeFrom(ref ParseContext ctx)
{
uint tag;
while ((tag = ctx.ReadTag()) != 0)
{
if (tag == codec.keyCodec.Tag)
{
Key = codec.keyCodec.Read(ref ctx);
}
else if (tag == codec.valueCodec.Tag)
{
Value = codec.valueCodec.Read(ref ctx);
}
else
{
ParsingPrimitivesMessages.SkipLastField(ref ctx.buffer, ref ctx.state);
}
}

// Corner case: a map entry with a key but no value, where the value type is a message.
// Read it as if we'd seen input with no data (i.e. create a "default" message).
if (Value == null)
{
if (ctx.state.CodedInputStream != null)
{
// the decoded message might not support parsing from ParseContext, so
// we need to allow fallback to the legacy MergeFrom(CodedInputStream) parsing.
Value = codec.valueCodec.Read(new CodedInputStream(ZeroLengthMessageStreamData));
}
else
{
ParseContext.Initialize(new ReadOnlySequence<byte>(ZeroLengthMessageStreamData), out ParseContext zeroLengthCtx);
Value = codec.valueCodec.Read(ref zeroLengthCtx);
}
}
}

public void WriteTo(CodedOutputStream output)
{
// Message adapter is an internal class and we know that all the writing will happen via InternalWriteTo.
throw new NotImplementedException();
}

[SecuritySafeCritical]
public void InternalWriteTo(ref WriteContext ctx)
{
codec.keyCodec.WriteTagAndValue(ref ctx, Key);
codec.valueCodec.WriteTagAndValue(ref ctx, Value);
}

public int CalculateSize()
{
return codec.keyCodec.CalculateSizeWithTag(Key) + codec.valueCodec.CalculateSizeWithTag(Value);
}

MessageDescriptor IMessage.Descriptor { get { return null; } }
}
/// <summary>
/// The tag used in the enclosing message to indicate map entries.
/// </summary>
internal uint MapTag => mapTag;
}

private class MapView<T> : ICollection<T>, ICollection
Expand Down
63 changes: 63 additions & 0 deletions csharp/src/Google.Protobuf/ParsingPrimitivesMessages.cs
Expand Up @@ -32,9 +32,11 @@

using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Security;
using Google.Protobuf.Collections;

namespace Google.Protobuf
{
Expand All @@ -44,6 +46,8 @@ namespace Google.Protobuf
[SecuritySafeCritical]
internal static class ParsingPrimitivesMessages
{
private static readonly byte[] ZeroLengthMessageStreamData = new byte[] { 0 };

public static void SkipLastField(ref ReadOnlySpan<byte> buffer, ref ParserInternalState state)
{
if (state.lastTag == 0)
Expand Down Expand Up @@ -134,6 +138,65 @@ public static void ReadMessage(ref ParseContext ctx, IMessage message)
SegmentedBufferHelper.PopLimit(ref ctx.state, oldLimit);
}

/// <summary>
/// Reads one length-prefixed map entry from <paramref name="ctx"/> and returns it as a key/value pair.
/// The key and value start at their codecs' default values and are overwritten when a field
/// with the matching tag is read; fields with any other tag are passed to <c>SkipLastField</c>.
/// </summary>
/// <param name="ctx">The parse context to read from; its limit and recursion depth are adjusted while reading.</param>
/// <param name="codec">The map codec supplying the key and value codecs.</param>
/// <returns>The parsed key and value.</returns>
/// <remarks>
/// Throws the exception produced by <c>InvalidProtocolBufferException.RecursionLimitExceeded()</c> when the
/// recursion depth has already reached the limit, and the one produced by
/// <c>InvalidProtocolBufferException.TruncatedMessage()</c> when the pushed limit was not reached after parsing.
/// </remarks>
public static KeyValuePair<TKey, TValue> ReadMapEntry<TKey, TValue>(ref ParseContext ctx, MapField<TKey, TValue>.Codec codec)
{
    int entryLength = ParsingPrimitives.ParseLength(ref ctx.buffer, ref ctx.state);
    if (ctx.state.recursionDepth >= ctx.state.recursionLimit)
    {
        throw InvalidProtocolBufferException.RecursionLimitExceeded();
    }
    int previousLimit = SegmentedBufferHelper.PushLimit(ref ctx.state, entryLength);
    ctx.state.recursionDepth++;

    var keyCodec = codec.KeyCodec;
    var valueCodec = codec.ValueCodec;

    // Start from the defaults so that an entry missing its key or value still yields a result.
    TKey key = keyCodec.DefaultValue;
    TValue value = valueCodec.DefaultValue;

    // A tag of zero ends the loop.
    for (uint tag = ctx.ReadTag(); tag != 0; tag = ctx.ReadTag())
    {
        if (tag == keyCodec.Tag)
        {
            key = keyCodec.Read(ref ctx);
        }
        else if (tag == valueCodec.Tag)
        {
            value = valueCodec.Read(ref ctx);
        }
        else
        {
            SkipLastField(ref ctx.buffer, ref ctx.state);
        }
    }

    // Corner case: a map entry with a key but no value, where the value type is a message.
    // Parse a zero-length message instead, which produces a "default" message instance.
    if (value == null)
    {
        if (ctx.state.CodedInputStream == null)
        {
            ParseContext.Initialize(new ReadOnlySequence<byte>(ZeroLengthMessageStreamData), out ParseContext zeroLengthCtx);
            value = valueCodec.Read(ref zeroLengthCtx);
        }
        else
        {
            // The decoded message might not support parsing from ParseContext, so fall back
            // to the legacy MergeFrom(CodedInputStream) parsing path.
            value = valueCodec.Read(new CodedInputStream(ZeroLengthMessageStreamData));
        }
    }

    CheckReadEndOfStreamTag(ref ctx.state);
    // The entry must have consumed exactly the number of bytes announced by its length prefix.
    if (!SegmentedBufferHelper.IsReachedLimit(ref ctx.state))
    {
        throw InvalidProtocolBufferException.TruncatedMessage();
    }
    ctx.state.recursionDepth--;
    SegmentedBufferHelper.PopLimit(ref ctx.state, previousLimit);

    return new KeyValuePair<TKey, TValue>(key, value);
}

public static void ReadGroup(ref ParseContext ctx, IMessage message)
{
if (ctx.state.recursionDepth >= ctx.state.recursionLimit)
Expand Down