diff --git a/src/Grpc.AspNetCore.Server.ClientFactory/ContextPropagationInterceptor.cs b/src/Grpc.AspNetCore.Server.ClientFactory/ContextPropagationInterceptor.cs index fa6a2d3ef..d2baa257b 100644 --- a/src/Grpc.AspNetCore.Server.ClientFactory/ContextPropagationInterceptor.cs +++ b/src/Grpc.AspNetCore.Server.ClientFactory/ContextPropagationInterceptor.cs @@ -22,7 +22,6 @@ using Grpc.Core.Interceptors; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; namespace Grpc.AspNetCore.ClientFactory; @@ -53,14 +52,15 @@ public ContextPropagationInterceptor(GrpcContextPropagationOptions options, IHtt } else { + var state = CreateContextState(call, cts); return new AsyncClientStreamingCall( requestStream: call.RequestStream, - responseAsync: call.ResponseAsync, + responseAsync: OnResponseAsync(call.ResponseAsync, state), responseHeadersAsync: ClientStreamingCallbacks.GetResponseHeadersAsync, getStatusFunc: ClientStreamingCallbacks.GetStatus, getTrailersFunc: ClientStreamingCallbacks.GetTrailers, disposeAction: ClientStreamingCallbacks.Dispose, - CreateContextState(call, cts)); + state); } } @@ -73,14 +73,15 @@ public ContextPropagationInterceptor(GrpcContextPropagationOptions options, IHtt } else { + var state = CreateContextState(call, cts); return new AsyncDuplexStreamingCall( requestStream: call.RequestStream, - responseStream: call.ResponseStream, + responseStream: new ResponseStreamWrapper(call.ResponseStream, state), responseHeadersAsync: DuplexStreamingCallbacks.GetResponseHeadersAsync, getStatusFunc: DuplexStreamingCallbacks.GetStatus, getTrailersFunc: DuplexStreamingCallbacks.GetTrailers, disposeAction: DuplexStreamingCallbacks.Dispose, - CreateContextState(call, cts)); + state); } } @@ -93,13 +94,14 @@ public ContextPropagationInterceptor(GrpcContextPropagationOptions options, IHtt } else { + var state = CreateContextState(call, cts); return new AsyncServerStreamingCall( - responseStream: call.ResponseStream, + responseStream: new ResponseStreamWrapper(call.ResponseStream, state), responseHeadersAsync: ServerStreamingCallbacks.GetResponseHeadersAsync, getStatusFunc: ServerStreamingCallbacks.GetStatus, getTrailersFunc: ServerStreamingCallbacks.GetTrailers, disposeAction: ServerStreamingCallbacks.Dispose, - CreateContextState(call, cts)); + state); } } @@ -112,13 +114,14 @@ public ContextPropagationInterceptor(GrpcContextPropagationOptions options, IHtt } else { + var state = CreateContextState(call, cts); return new AsyncUnaryCall( - responseAsync: call.ResponseAsync, + responseAsync: OnResponseAsync(call.ResponseAsync, state), responseHeadersAsync: UnaryCallbacks.GetResponseHeadersAsync, getStatusFunc: UnaryCallbacks.GetStatus, getTrailersFunc: UnaryCallbacks.GetTrailers, disposeAction: UnaryCallbacks.Dispose, - CreateContextState(call, cts)); + state); } } @@ -129,6 +132,19 @@ public ContextPropagationInterceptor(GrpcContextPropagationOptions options, IHtt return response; } + // Automatically dispose state after awaiting the response. + private static async Task OnResponseAsync(Task task, IDisposable state) + { + try + { + return await task.ConfigureAwait(false); + } + finally + { + state.Dispose(); + } + } + private ClientInterceptorContext ConfigureContext(ClientInterceptorContext context, out CancellationTokenSource? linkedCts) where TRequest : class where TResponse : class @@ -197,7 +213,7 @@ private bool TryGetServerCallContext([NotNullWhen(true)] out ServerCallContext? private ContextState CreateContextState(TCall call, CancellationTokenSource cancellationTokenSource) where TCall : IDisposable => new ContextState(call, cancellationTokenSource); - private class ContextState : IDisposable where TCall : IDisposable + private sealed class ContextState : IDisposable where TCall : IDisposable { public ContextState(TCall call, CancellationTokenSource cancellationTokenSource) { @@ -215,6 +231,33 @@ public void Dispose() } } + // Automatically dispose state after reading to the end of the stream. + private sealed class ResponseStreamWrapper : IAsyncStreamReader + { + private readonly IAsyncStreamReader _inner; + private readonly IDisposable _state; + private bool _disposed; + + public ResponseStreamWrapper(IAsyncStreamReader inner, IDisposable state) + { + _inner = inner; + _state = state; + } + + public TResponse Current => _inner.Current; + + public async Task MoveNext(CancellationToken cancellationToken) + { + var result = await _inner.MoveNext(cancellationToken); + if (!result && !_disposed) + { + _state.Dispose(); + _disposed = true; + } + return result; + } + } + private static class Log { private static readonly Action _propagateServerCallContextFailure = diff --git a/test/Grpc.AspNetCore.Server.ClientFactory.Tests/DefaultGrpcClientFactoryTests.cs b/test/Grpc.AspNetCore.Server.ClientFactory.Tests/DefaultGrpcClientFactoryTests.cs index 8ed16eb7a..e6c22b041 100644 --- a/test/Grpc.AspNetCore.Server.ClientFactory.Tests/DefaultGrpcClientFactoryTests.cs +++ b/test/Grpc.AspNetCore.Server.ClientFactory.Tests/DefaultGrpcClientFactoryTests.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -20,6 +20,7 @@ using Greet; using Grpc.AspNetCore.Server.ClientFactory.Tests.TestObjects; using Grpc.Core; +using Grpc.Core.Interceptors; using Grpc.Net.ClientFactory; using Grpc.Net.ClientFactory.Internal; using Grpc.Tests.Shared; @@ -91,6 +92,155 @@ public async Task CreateClient_ServerCallContextHasValues_PropogatedDeadlineAndC Assert.AreEqual(cancellationToken, options.CancellationToken); } + [Test] + public async Task CreateClient_Unary_ServerCallContextHasValues_StateDisposed() + { + // Arrange + var baseAddress = new Uri("http://localhost"); + var deadline = DateTime.UtcNow.AddDays(1); + var cancellationToken = new CancellationTokenSource().Token; + + var interceptor = new OnDisposedInterceptor(); + + var services = new ServiceCollection(); + services.AddOptions(); + services.AddSingleton(CreateHttpContextAccessorWithServerCallContext(deadline: deadline, cancellationToken: cancellationToken)); + services + .AddGrpcClient(o => + { + o.Address = baseAddress; + }) + .EnableCallContextPropagation() + .AddInterceptor(() => interceptor) + .ConfigurePrimaryHttpMessageHandler(() => ClientTestHelpers.CreateTestMessageHandler(new HelloReply())); + + var serviceProvider = services.BuildServiceProvider(validateScopes: true); + + var clientFactory = CreateGrpcClientFactory(serviceProvider); + var client = clientFactory.CreateClient(nameof(Greeter.GreeterClient)); + + // Checking that token register calls don't build up on CTS and create a memory leak. + var cts = new CancellationTokenSource(); + + // Act + // Send calls in a different method so there is no chance that a stack reference + // to a gRPC call is still alive after calls are complete. + var response = await client.SayHelloAsync(new HelloRequest(), cancellationToken: cts.Token); + + // Assert + Assert.IsTrue(interceptor.ContextDisposed); + } + + [Test] + public async Task CreateClient_ServerStreaming_ServerCallContextHasValues_StateDisposed() + { + // Arrange + var baseAddress = new Uri("http://localhost"); + var deadline = DateTime.UtcNow.AddDays(1); + var cancellationToken = new CancellationTokenSource().Token; + + var interceptor = new OnDisposedInterceptor(); + + var services = new ServiceCollection(); + services.AddOptions(); + services.AddSingleton(CreateHttpContextAccessorWithServerCallContext(deadline: deadline, cancellationToken: cancellationToken)); + services + .AddGrpcClient(o => + { + o.Address = baseAddress; + }) + .EnableCallContextPropagation() + .AddInterceptor(() => interceptor) + .ConfigurePrimaryHttpMessageHandler(() => ClientTestHelpers.CreateTestMessageHandler(new HelloReply())); + + var serviceProvider = services.BuildServiceProvider(validateScopes: true); + + var clientFactory = CreateGrpcClientFactory(serviceProvider); + var client = clientFactory.CreateClient(nameof(Greeter.GreeterClient)); + + // Checking that token register calls don't build up on CTS and create a memory leak. + var cts = new CancellationTokenSource(); + + // Act + // Send calls in a different method so there is no chance that a stack reference + // to a gRPC call is still alive after calls are complete. + var call = client.SayHellos(new HelloRequest(), cancellationToken: cts.Token); + + Assert.IsTrue(await call.ResponseStream.MoveNext()); + Assert.IsFalse(await call.ResponseStream.MoveNext()); + + // Assert + Assert.IsTrue(interceptor.ContextDisposed); + } + + private sealed class OnDisposedInterceptor : Interceptor + { + public bool ContextDisposed { get; private set; } + + public override TResponse BlockingUnaryCall(TRequest request, ClientInterceptorContext context, BlockingUnaryCallContinuation continuation) + { + return continuation(request, context); + } + + public override AsyncUnaryCall AsyncUnaryCall(TRequest request, ClientInterceptorContext context, AsyncUnaryCallContinuation continuation) + { + var call = continuation(request, context); + return new AsyncUnaryCall(call.ResponseAsync, + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + () => + { + call.Dispose(); + ContextDisposed = true; + }); + } + + public override AsyncServerStreamingCall AsyncServerStreamingCall(TRequest request, ClientInterceptorContext context, AsyncServerStreamingCallContinuation continuation) + { + var call = continuation(request, context); + return new AsyncServerStreamingCall(call.ResponseStream, + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + () => + { + call.Dispose(); + ContextDisposed = true; + }); + } + + public override AsyncClientStreamingCall AsyncClientStreamingCall(ClientInterceptorContext context, AsyncClientStreamingCallContinuation continuation) + { + var call = continuation(context); + return new AsyncClientStreamingCall(call.RequestStream, + call.ResponseAsync, + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + () => + { + call.Dispose(); + ContextDisposed = true; + }); + } + + public override AsyncDuplexStreamingCall AsyncDuplexStreamingCall(ClientInterceptorContext context, AsyncDuplexStreamingCallContinuation continuation) + { + var call = continuation(context); + return new AsyncDuplexStreamingCall(call.RequestStream, + call.ResponseStream, + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + () => + { + call.Dispose(); + ContextDisposed = true; + }); + } + } + [TestCase(Canceller.Context)] [TestCase(Canceller.User)] public async Task CreateClient_ServerCallContextAndUserCancellationToken_PropogatedDeadlineAndCancellation(Canceller canceller)