Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid C# GC when deserializing large messages #13440

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/compiler/csharp_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ void GenerateMarshallerFields(Printer* out, const ServiceDescriptor* service) {
out->Print(
"static readonly grpc::Marshaller<$type$> $fieldname$ = "
"grpc::Marshallers.Create((arg) => "
"global::Google.Protobuf.MessageExtensions.ToByteArray(arg), "
"$type$.Parser.ParseFrom);\n",
"global::Google.Protobuf.MessageExtensions.ToByteArray(arg), (arg) => "
"$type$.Parser.ParseFrom(arg.Array, arg.Offset, arg.Count));\n",
"fieldname", GetMarshallerFieldName(message), "type",
GetClassName(message));
}
Expand Down
2 changes: 1 addition & 1 deletion src/csharp/Grpc.Core.Tests/Internal/AsyncCallServerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public void Init()

fakeCall = new FakeNativeCall();
asyncCallServer = new AsyncCallServer<string, string>(
Marshallers.StringMarshaller.Serializer, Marshallers.StringMarshaller.Deserializer,
Marshallers.StringMarshaller.Serializer, Marshallers.StringMarshaller.ArraySegmentDeserializer,
server);
asyncCallServer.InitializeForTesting(fakeCall);
}
Expand Down
4 changes: 2 additions & 2 deletions src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,9 @@ ClientSideStatus CreateClientSideStatus(StatusCode statusCode)
return new ClientSideStatus(new Status(statusCode, ""), new Metadata());
}

byte[] CreateResponsePayload()
INativePayloadReader CreateResponsePayload()
{
return Marshallers.StringMarshaller.Serializer("response1");
return new FakeNativePayloadReader(Marshallers.StringMarshaller.Serializer("response1"));
}

static void AssertUnaryResponseSuccess(AsyncCall<string, string> asyncCall, FakeNativeCall fakeCall, Task<string> resultTask)
Expand Down
56 changes: 56 additions & 0 deletions src/csharp/Grpc.Core.Tests/Internal/FakeNativePayloadReader.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#region Copyright notice and license

// Copyright 2015 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Threading.Tasks;

using Grpc.Core.Internal;
using NUnit.Framework;

namespace Grpc.Core.Internal.Tests
{
/// <summary>
/// For testing purposes.
/// </summary>
internal class FakeNativePayloadReader : INativePayloadReader
{
readonly byte[] payload;

public FakeNativePayloadReader(byte[] payload)
{
this.payload = payload;
}

public byte[] ReadPayload()
{
return payload;
}

public int? GetPayloadLength()
{
return payload?.Length;
}

public void ReadPayloadToBuffer(byte[] buffer, int offset, int length)
{
Buffer.BlockCopy(payload, 0, buffer, offset, length);
}
}
}
8 changes: 4 additions & 4 deletions src/csharp/Grpc.Core/Internal/AsyncCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ internal class AsyncCall<TRequest, TResponse> : AsyncCallBase<TRequest, TRespons
ClientSideStatus? finishedStatus;

public AsyncCall(CallInvocationDetails<TRequest, TResponse> callDetails)
: base(callDetails.RequestMarshaller.Serializer, callDetails.ResponseMarshaller.Deserializer)
: base(callDetails.RequestMarshaller.Serializer, callDetails.ResponseMarshaller.ArraySegmentDeserializer)
{
this.details = callDetails.WithOptions(callDetails.Options.Normalize());
this.initialMetadataSent = true; // we always send metadata at the very beginning of the call.
Expand Down Expand Up @@ -103,7 +103,7 @@ public TResponse UnaryCall(TRequest msg)
{
using (profiler.NewScope("AsyncCall.UnaryCall.HandleBatch"))
{
HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage(), ctx.GetReceivedInitialMetadata());
HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.AsNativePayloadReader(), ctx.GetReceivedInitialMetadata());
}
}
catch (Exception e)
Expand Down Expand Up @@ -439,7 +439,7 @@ private void HandleReceivedResponseHeaders(bool success, Metadata responseHeader
/// <summary>
/// Handler for unary response completion.
/// </summary>
private void HandleUnaryResponse(bool success, ClientSideStatus receivedStatus, byte[] receivedMessage, Metadata responseHeaders)
private void HandleUnaryResponse(bool success, ClientSideStatus receivedStatus, INativePayloadReader receivedMessage, Metadata responseHeaders)
{
// NOTE: because this event is a result of batch containing GRPC_OP_RECV_STATUS_ON_CLIENT,
// success will be always set to true.
Expand Down Expand Up @@ -524,7 +524,7 @@ private void HandleFinished(bool success, ClientSideStatus receivedStatus)

IUnaryResponseClientCallback UnaryResponseClientCallback => this;

void IUnaryResponseClientCallback.OnUnaryResponseClient(bool success, ClientSideStatus receivedStatus, byte[] receivedMessage, Metadata responseHeaders)
void IUnaryResponseClientCallback.OnUnaryResponseClient(bool success, ClientSideStatus receivedStatus, INativePayloadReader receivedMessage, Metadata responseHeaders)
{
HandleUnaryResponse(success, receivedStatus, receivedMessage, responseHeaders);
}
Expand Down
22 changes: 16 additions & 6 deletions src/csharp/Grpc.Core/Internal/AsyncCallBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ namespace Grpc.Core.Internal
internal abstract class AsyncCallBase<TWrite, TRead> : IReceivedMessageCallback, ISendCompletionCallback
{
static readonly ILogger Logger = GrpcEnvironment.Logger.ForType<AsyncCallBase<TWrite, TRead>>();
static readonly ThreadLocalBufferCache DeserializerBufferCache = new ThreadLocalBufferCache();

protected static readonly Status DeserializeResponseFailureStatus = new Status(StatusCode.Internal, "Failed to deserialize response message.");

readonly Func<TWrite, byte[]> serializer;
readonly Func<byte[], TRead> deserializer;
readonly Func<ArraySegment<byte>, TRead> deserializer;

protected readonly object myLock = new object();

Expand All @@ -63,7 +65,7 @@ internal abstract class AsyncCallBase<TWrite, TRead> : IReceivedMessageCallback,
protected bool initialMetadataSent;
protected long streamingWritesCounter; // Number of streaming send operations started so far.

public AsyncCallBase(Func<TWrite, byte[]> serializer, Func<byte[], TRead> deserializer)
public AsyncCallBase(Func<TWrite, byte[]> serializer, Func<ArraySegment<byte>, TRead> deserializer)
{
this.serializer = GrpcPreconditions.CheckNotNull(serializer);
this.deserializer = GrpcPreconditions.CheckNotNull(deserializer);
Expand Down Expand Up @@ -214,11 +216,19 @@ protected byte[] UnsafeSerialize(TWrite msg)
return serializer(msg);
}

protected Exception TryDeserialize(byte[] payload, out TRead msg)
protected Exception TryDeserialize(INativePayloadReader payloadReader, out TRead msg)
{
try
{
msg = deserializer(payload);
int? len = payloadReader.GetPayloadLength();
if (!len.HasValue)
{
msg = default(TRead);
return null;
}
var buffer = DeserializerBufferCache.RentForCurrentScope(len.Value);
payloadReader.ReadPayloadToBuffer(buffer, 0, len.Value);
msg = deserializer(new ArraySegment<byte>(buffer, 0, len.Value));
return null;
}
catch (Exception e)
Expand Down Expand Up @@ -300,7 +310,7 @@ protected void HandleSendStatusFromServerFinished(bool success)
/// <summary>
/// Handles streaming read completion.
/// </summary>
protected void HandleReadFinished(bool success, byte[] receivedMessage)
protected void HandleReadFinished(bool success, INativePayloadReader receivedMessage)
{
// if success == false, received message will be null. It that case we will
// treat this completion as the last read an rely on C core to handle the failed
Expand Down Expand Up @@ -352,7 +362,7 @@ void ISendCompletionCallback.OnSendCompletion(bool success)

IReceivedMessageCallback ReceivedMessageCallback => this;

void IReceivedMessageCallback.OnReceivedMessage(bool success, byte[] receivedMessage)
void IReceivedMessageCallback.OnReceivedMessage(bool success, INativePayloadReader receivedMessage)
{
HandleReadFinished(success, receivedMessage);
}
Expand Down
2 changes: 1 addition & 1 deletion src/csharp/Grpc.Core/Internal/AsyncCallServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ internal class AsyncCallServer<TRequest, TResponse> : AsyncCallBase<TResponse, T
readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource();
readonly Server server;

public AsyncCallServer(Func<TResponse, byte[]> serializer, Func<byte[], TRequest> deserializer, Server server) : base(serializer, deserializer)
public AsyncCallServer(Func<TResponse, byte[]> serializer, Func<ArraySegment<byte>, TRequest> deserializer, Server server) : base(serializer, deserializer)
{
this.server = GrpcPreconditions.CheckNotNull(server);
}
Expand Down
46 changes: 44 additions & 2 deletions src/csharp/Grpc.Core/Internal/BatchContextSafeHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,22 @@ internal interface IOpCompletionCallback
void OnComplete(bool success);
}

/// <summary>
/// Provides a way to read the received message (a.k.a payload) from grpcsharp_batch_context
/// </summary>
internal interface INativePayloadReader
{
byte[] ReadPayload();

int? GetPayloadLength();

void ReadPayloadToBuffer(byte[] buffer, int offset, int length);
}

/// <summary>
/// grpcsharp_batch_context
/// </summary>
internal class BatchContextSafeHandle : SafeHandleZeroIsInvalid, IOpCompletionCallback
internal class BatchContextSafeHandle : SafeHandleZeroIsInvalid, IOpCompletionCallback, INativePayloadReader
{
static readonly NativeMethods Native = NativeMethods.Get();
static readonly ILogger Logger = GrpcEnvironment.Logger.ForType<BatchContextSafeHandle>();
Expand Down Expand Up @@ -97,10 +109,17 @@ public byte[] GetReceivedMessage()
return null;
}
byte[] data = new byte[(int)len];
Native.grpcsharp_batch_context_recv_message_to_buffer(this, data, new UIntPtr((ulong)data.Length));
Native.grpcsharp_batch_context_recv_message_to_buffer(this, data, UIntPtr.Zero, new UIntPtr((ulong)data.Length));
return data;
}

/// Returns payload reader for recv_message completion.
/// The method exists a convenience to improve code readability.
public INativePayloadReader AsNativePayloadReader()
{
return this;
}

// Gets data of receive_close_on_server completion.
public bool GetReceivedCloseOnServerCancelled()
{
Expand All @@ -126,6 +145,29 @@ protected override bool ReleaseHandle()
return true;
}

byte[] INativePayloadReader.ReadPayload()
{
return GetReceivedMessage();
}

int? INativePayloadReader.GetPayloadLength()
{
IntPtr len = Native.grpcsharp_batch_context_recv_message_length(this);
if (len == new IntPtr(-1))
{
return null;
}
return (int)len;
}

void INativePayloadReader.ReadPayloadToBuffer(byte[] buffer, int offset, int length)
{
GrpcPreconditions.CheckArgument(offset >= 0);
GrpcPreconditions.CheckArgument(length >= 0);
GrpcPreconditions.CheckArgument(buffer.Length >= offset + length);
Native.grpcsharp_batch_context_recv_message_to_buffer(this, buffer, new UIntPtr((ulong)offset), new UIntPtr((ulong)length));
}

void IOpCompletionCallback.OnComplete(bool success)
{
try
Expand Down
4 changes: 2 additions & 2 deletions src/csharp/Grpc.Core/Internal/CallSafeHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ internal class CallSafeHandle : SafeHandleZeroIsInvalid, INativeCall
// Completion handlers are pre-allocated to avoid unneccessary delegate allocations.
// The "state" field is used to store the actual callback to invoke.
static readonly BatchCompletionDelegate CompletionHandler_IUnaryResponseClientCallback =
(success, context, state) => ((IUnaryResponseClientCallback)state).OnUnaryResponseClient(success, context.GetReceivedStatusOnClient(), context.GetReceivedMessage(), context.GetReceivedInitialMetadata());
(success, context, state) => ((IUnaryResponseClientCallback)state).OnUnaryResponseClient(success, context.GetReceivedStatusOnClient(), context.AsNativePayloadReader(), context.GetReceivedInitialMetadata());
static readonly BatchCompletionDelegate CompletionHandler_IReceivedStatusOnClientCallback =
(success, context, state) => ((IReceivedStatusOnClientCallback)state).OnReceivedStatusOnClient(success, context.GetReceivedStatusOnClient());
static readonly BatchCompletionDelegate CompletionHandler_IReceivedMessageCallback =
(success, context, state) => ((IReceivedMessageCallback)state).OnReceivedMessage(success, context.GetReceivedMessage());
(success, context, state) => ((IReceivedMessageCallback)state).OnReceivedMessage(success, context.AsNativePayloadReader());
static readonly BatchCompletionDelegate CompletionHandler_IReceivedResponseHeadersCallback =
(success, context, state) => ((IReceivedResponseHeadersCallback)state).OnReceivedResponseHeaders(success, context.GetReceivedInitialMetadata());
static readonly BatchCompletionDelegate CompletionHandler_ISendCompletionCallback =
Expand Down
4 changes: 2 additions & 2 deletions src/csharp/Grpc.Core/Internal/INativeCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace Grpc.Core.Internal
{
internal interface IUnaryResponseClientCallback
{
void OnUnaryResponseClient(bool success, ClientSideStatus receivedStatus, byte[] receivedMessage, Metadata responseHeaders);
void OnUnaryResponseClient(bool success, ClientSideStatus receivedStatus, INativePayloadReader receivedMessage, Metadata responseHeaders);
}

// Received status for streaming response calls.
Expand All @@ -33,7 +33,7 @@ internal interface IReceivedStatusOnClientCallback

internal interface IReceivedMessageCallback
{
void OnReceivedMessage(bool success, byte[] receivedMessage);
void OnReceivedMessage(bool success, INativePayloadReader receivedMessage);
}

internal interface IReceivedResponseHeadersCallback
Expand Down
2 changes: 1 addition & 1 deletion src/csharp/Grpc.Core/Internal/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ public class Delegates
public delegate BatchContextSafeHandle grpcsharp_batch_context_create_delegate();
public delegate IntPtr grpcsharp_batch_context_recv_initial_metadata_delegate(BatchContextSafeHandle ctx);
public delegate IntPtr grpcsharp_batch_context_recv_message_length_delegate(BatchContextSafeHandle ctx);
public delegate void grpcsharp_batch_context_recv_message_to_buffer_delegate(BatchContextSafeHandle ctx, byte[] buffer, UIntPtr bufferLen);
public delegate void grpcsharp_batch_context_recv_message_to_buffer_delegate(BatchContextSafeHandle ctx, byte[] buffer, UIntPtr offset, UIntPtr maxLen);
public delegate StatusCode grpcsharp_batch_context_recv_status_on_client_status_delegate(BatchContextSafeHandle ctx);
public delegate IntPtr grpcsharp_batch_context_recv_status_on_client_details_delegate(BatchContextSafeHandle ctx, out UIntPtr detailsLength);
public delegate IntPtr grpcsharp_batch_context_recv_status_on_client_trailing_metadata_delegate(BatchContextSafeHandle ctx);
Expand Down
8 changes: 4 additions & 4 deletions src/csharp/Grpc.Core/Internal/ServerCallHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public async Task HandleCall(ServerRpcNew newRpc, CompletionQueueSafeHandle cq)
{
var asyncCall = new AsyncCallServer<TRequest, TResponse>(
method.ResponseMarshaller.Serializer,
method.RequestMarshaller.Deserializer,
method.RequestMarshaller.ArraySegmentDeserializer,
newRpc.Server);

asyncCall.Initialize(newRpc.Call, cq);
Expand Down Expand Up @@ -110,7 +110,7 @@ public async Task HandleCall(ServerRpcNew newRpc, CompletionQueueSafeHandle cq)
{
var asyncCall = new AsyncCallServer<TRequest, TResponse>(
method.ResponseMarshaller.Serializer,
method.RequestMarshaller.Deserializer,
method.RequestMarshaller.ArraySegmentDeserializer,
newRpc.Server);

asyncCall.Initialize(newRpc.Call, cq);
Expand Down Expand Up @@ -168,7 +168,7 @@ public async Task HandleCall(ServerRpcNew newRpc, CompletionQueueSafeHandle cq)
{
var asyncCall = new AsyncCallServer<TRequest, TResponse>(
method.ResponseMarshaller.Serializer,
method.RequestMarshaller.Deserializer,
method.RequestMarshaller.ArraySegmentDeserializer,
newRpc.Server);

asyncCall.Initialize(newRpc.Call, cq);
Expand Down Expand Up @@ -226,7 +226,7 @@ public async Task HandleCall(ServerRpcNew newRpc, CompletionQueueSafeHandle cq)
{
var asyncCall = new AsyncCallServer<TRequest, TResponse>(
method.ResponseMarshaller.Serializer,
method.RequestMarshaller.Deserializer,
method.RequestMarshaller.ArraySegmentDeserializer,
newRpc.Server);

asyncCall.Initialize(newRpc.Call, cq);
Expand Down