Skip to content

Commit

Permalink
Merge pull request #8143 from JamesNK/jamesnk/messageadapter
Browse files Browse the repository at this point in the history
Optimize MapField serialization by removing MessageAdapter
  • Loading branch information
jtattermusch committed Jan 26, 2021
2 parents 48234f5 + 69223b8 commit f6da785
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 101 deletions.
26 changes: 26 additions & 0 deletions csharp/src/Google.Protobuf.Test/Collections/MapFieldTest.cs
Expand Up @@ -611,6 +611,32 @@ public void AddEntriesFrom_CodedInputStream()
Assert.IsTrue(input.IsAtEnd);
}

[Test]
public void AddEntriesFrom_CodedInputStream_MissingKey()
{
// map will have string key and string value
var keyTag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);
var valueTag = WireFormat.MakeTag(2, WireFormat.WireType.LengthDelimited);

var memoryStream = new MemoryStream();
var output = new CodedOutputStream(memoryStream);
output.WriteLength(11); // total of valueTag + value
output.WriteTag(valueTag);
output.WriteString("the_value");
output.Flush();

Console.WriteLine(BitConverter.ToString(memoryStream.ToArray()));

var field = new MapField<string, string>();
var mapCodec = new MapField<string, string>.Codec(FieldCodec.ForString(keyTag, ""), FieldCodec.ForString(valueTag, ""), 10);
var input = new CodedInputStream(memoryStream.ToArray());

field.AddEntriesFrom(input, mapCodec);
CollectionAssert.AreEquivalent(new[] { "" }, field.Keys);
CollectionAssert.AreEquivalent(new[] { "the_value" }, field.Values);
Assert.IsTrue(input.IsAtEnd);
}

#if !NET35
[Test]
public void IDictionaryKeys_Equals_IReadOnlyDictionaryKeys()
Expand Down
123 changes: 22 additions & 101 deletions csharp/src/Google.Protobuf/Collections/MapField.cs
Expand Up @@ -448,12 +448,10 @@ public void AddEntriesFrom(CodedInputStream input, Codec codec)
[SecuritySafeCritical]
public void AddEntriesFrom(ref ParseContext ctx, Codec codec)
{
var adapter = new Codec.MessageAdapter(codec);
do
{
adapter.Reset();
ctx.ReadMessage(adapter);
this[adapter.Key] = adapter.Value;
KeyValuePair<TKey, TValue> entry = ParsingPrimitivesMessages.ReadMapEntry(ref ctx, codec);
this[entry.Key] = entry.Value;
} while (ParsingPrimitives.MaybeConsumeTag(ref ctx.buffer, ref ctx.state, codec.MapTag));
}

Expand Down Expand Up @@ -485,13 +483,13 @@ public void WriteTo(CodedOutputStream output, Codec codec)
[SecuritySafeCritical]
public void WriteTo(ref WriteContext ctx, Codec codec)
{
var message = new Codec.MessageAdapter(codec);
foreach (var entry in list)
{
message.Key = entry.Key;
message.Value = entry.Value;
ctx.WriteTag(codec.MapTag);
ctx.WriteMessage(message);

WritingPrimitives.WriteLength(ref ctx.buffer, ref ctx.state, CalculateEntrySize(codec, entry));
codec.KeyCodec.WriteTagAndValue(ref ctx, entry.Key);
codec.ValueCodec.WriteTagAndValue(ref ctx, entry.Value);
}
}

Expand All @@ -506,18 +504,22 @@ public int CalculateSize(Codec codec)
{
return 0;
}
var message = new Codec.MessageAdapter(codec);
int size = 0;
foreach (var entry in list)
{
message.Key = entry.Key;
message.Value = entry.Value;
int entrySize = CalculateEntrySize(codec, entry);

size += CodedOutputStream.ComputeRawVarint32Size(codec.MapTag);
size += CodedOutputStream.ComputeMessageSize(message);
size += CodedOutputStream.ComputeLengthSize(entrySize) + entrySize;
}
return size;
}

private static int CalculateEntrySize(Codec codec, KeyValuePair<TKey, TValue> entry)
{
return codec.KeyCodec.CalculateSizeWithTag(entry.Key) + codec.ValueCodec.CalculateSizeWithTag(entry.Value);
}

/// <summary>
/// Returns a string representation of this repeated field, in the same
/// way as it would be represented by the default JSON formatter.
Expand Down Expand Up @@ -655,100 +657,19 @@ public Codec(FieldCodec<TKey> keyCodec, FieldCodec<TValue> valueCodec, uint mapT
}

/// <summary>
/// The tag used in the enclosing message to indicate map entries.
/// The key codec.
/// </summary>
internal uint MapTag { get { return mapTag; } }
internal FieldCodec<TKey> KeyCodec => keyCodec;

/// <summary>
/// A mutable message class, used for parsing and serializing. This
/// delegates the work to a codec, but implements the <see cref="IMessage"/> interface
/// for interop with <see cref="CodedInputStream"/> and <see cref="CodedOutputStream"/>.
/// This is nested inside Codec as it's tightly coupled to the associated codec,
/// and it's simpler if it has direct access to all its fields.
/// The value codec.
/// </summary>
internal class MessageAdapter : IMessage, IBufferMessage
{
private static readonly byte[] ZeroLengthMessageStreamData = new byte[] { 0 };

private readonly Codec codec;
internal TKey Key { get; set; }
internal TValue Value { get; set; }

internal MessageAdapter(Codec codec)
{
this.codec = codec;
}

internal void Reset()
{
Key = codec.keyCodec.DefaultValue;
Value = codec.valueCodec.DefaultValue;
}

public void MergeFrom(CodedInputStream input)
{
// Message adapter is an internal class and we know that all the parsing will happen via InternalMergeFrom.
throw new NotImplementedException();
}
internal FieldCodec<TValue> ValueCodec => valueCodec;

[SecuritySafeCritical]
public void InternalMergeFrom(ref ParseContext ctx)
{
uint tag;
while ((tag = ctx.ReadTag()) != 0)
{
if (tag == codec.keyCodec.Tag)
{
Key = codec.keyCodec.Read(ref ctx);
}
else if (tag == codec.valueCodec.Tag)
{
Value = codec.valueCodec.Read(ref ctx);
}
else
{
ParsingPrimitivesMessages.SkipLastField(ref ctx.buffer, ref ctx.state);
}
}

// Corner case: a map entry with a key but no value, where the value type is a message.
// Read it as if we'd seen input with no data (i.e. create a "default" message).
if (Value == null)
{
if (ctx.state.CodedInputStream != null)
{
// the decoded message might not support parsing from ParseContext, so
// we need to allow fallback to the legacy MergeFrom(CodedInputStream) parsing.
Value = codec.valueCodec.Read(new CodedInputStream(ZeroLengthMessageStreamData));
}
else
{
ParseContext.Initialize(new ReadOnlySequence<byte>(ZeroLengthMessageStreamData), out ParseContext zeroLengthCtx);
Value = codec.valueCodec.Read(ref zeroLengthCtx);
}
}
}

public void WriteTo(CodedOutputStream output)
{
// Message adapter is an internal class and we know that all the writing will happen via InternalWriteTo.
throw new NotImplementedException();
}

[SecuritySafeCritical]
public void InternalWriteTo(ref WriteContext ctx)
{
codec.keyCodec.WriteTagAndValue(ref ctx, Key);
codec.valueCodec.WriteTagAndValue(ref ctx, Value);
}

public int CalculateSize()
{
return codec.keyCodec.CalculateSizeWithTag(Key) + codec.valueCodec.CalculateSizeWithTag(Value);
}

MessageDescriptor IMessage.Descriptor { get { return null; } }
}
/// <summary>
/// The tag used in the enclosing message to indicate map entries.
/// </summary>
internal uint MapTag => mapTag;
}

private class MapView<T> : ICollection<T>, ICollection
Expand Down
63 changes: 63 additions & 0 deletions csharp/src/Google.Protobuf/ParsingPrimitivesMessages.cs
Expand Up @@ -32,9 +32,11 @@

using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Security;
using Google.Protobuf.Collections;

namespace Google.Protobuf
{
Expand All @@ -44,6 +46,8 @@ namespace Google.Protobuf
[SecuritySafeCritical]
internal static class ParsingPrimitivesMessages
{
private static readonly byte[] ZeroLengthMessageStreamData = new byte[] { 0 };

public static void SkipLastField(ref ReadOnlySpan<byte> buffer, ref ParserInternalState state)
{
if (state.lastTag == 0)
Expand Down Expand Up @@ -134,6 +138,65 @@ public static void ReadMessage(ref ParseContext ctx, IMessage message)
SegmentedBufferHelper.PopLimit(ref ctx.state, oldLimit);
}

public static KeyValuePair<TKey, TValue> ReadMapEntry<TKey, TValue>(ref ParseContext ctx, MapField<TKey, TValue>.Codec codec)
{
int length = ParsingPrimitives.ParseLength(ref ctx.buffer, ref ctx.state);
if (ctx.state.recursionDepth >= ctx.state.recursionLimit)
{
throw InvalidProtocolBufferException.RecursionLimitExceeded();
}
int oldLimit = SegmentedBufferHelper.PushLimit(ref ctx.state, length);
++ctx.state.recursionDepth;

TKey key = codec.KeyCodec.DefaultValue;
TValue value = codec.ValueCodec.DefaultValue;

uint tag;
while ((tag = ctx.ReadTag()) != 0)
{
if (tag == codec.KeyCodec.Tag)
{
key = codec.KeyCodec.Read(ref ctx);
}
else if (tag == codec.ValueCodec.Tag)
{
value = codec.ValueCodec.Read(ref ctx);
}
else
{
SkipLastField(ref ctx.buffer, ref ctx.state);
}
}

// Corner case: a map entry with a key but no value, where the value type is a message.
// Read it as if we'd seen input with no data (i.e. create a "default" message).
if (value == null)
{
if (ctx.state.CodedInputStream != null)
{
// the decoded message might not support parsing from ParseContext, so
// we need to allow fallback to the legacy MergeFrom(CodedInputStream) parsing.
value = codec.ValueCodec.Read(new CodedInputStream(ZeroLengthMessageStreamData));
}
else
{
ParseContext.Initialize(new ReadOnlySequence<byte>(ZeroLengthMessageStreamData), out ParseContext zeroLengthCtx);
value = codec.ValueCodec.Read(ref zeroLengthCtx);
}
}

CheckReadEndOfStreamTag(ref ctx.state);
// Check that we've read exactly as much data as expected.
if (!SegmentedBufferHelper.IsReachedLimit(ref ctx.state))
{
throw InvalidProtocolBufferException.TruncatedMessage();
}
--ctx.state.recursionDepth;
SegmentedBufferHelper.PopLimit(ref ctx.state, oldLimit);

return new KeyValuePair<TKey, TValue>(key, value);
}

public static void ReadGroup(ref ParseContext ctx, IMessage message)
{
if (ctx.state.recursionDepth >= ctx.state.recursionLimit)
Expand Down

0 comments on commit f6da785

Please sign in to comment.