Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add semaphore to limit subchannel connect to prevent race conditions #2422

Merged
merged 6 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
109 changes: 83 additions & 26 deletions src/Grpc.Net.Client/Balancer/Subchannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public sealed class Subchannel : IDisposable

internal readonly ConnectionManager _manager;
private readonly ILogger _logger;
private readonly SemaphoreSlim _connectSemaphore;

private ISubchannelTransport _transport = default!;
private ConnectContext? _connectContext;
Expand Down Expand Up @@ -89,6 +90,7 @@ internal Subchannel(ConnectionManager manager, IReadOnlyList<BalancerAddress> ad
{
Lock = new object();
_logger = manager.LoggerFactory.CreateLogger(GetType());
_connectSemaphore = new SemaphoreSlim(1);

Id = manager.GetNextId();
_addresses = addresses.ToList();
Expand Down Expand Up @@ -213,7 +215,10 @@ public void UpdateAddresses(IReadOnlyList<BalancerAddress> addresses)

if (requireReconnect)
{
CancelInProgressConnect();
lock (Lock)
{
CancelInProgressConnectUnsynchronized();
}
_transport.Disconnect();
RequestConnection();
}
Expand Down Expand Up @@ -268,43 +273,76 @@ public void RequestConnection()
}
}

private void CancelInProgressConnect()
private void CancelInProgressConnectUnsynchronized()
{
lock (Lock)
{
if (_connectContext != null && !_connectContext.Disposed)
{
SubchannelLog.CancelingConnect(_logger, Id);
Debug.Assert(Monitor.IsEntered(Lock));

// Cancel connect cancellation token.
_connectContext.CancelConnect();
_connectContext.Dispose();
}
if (_connectContext != null && !_connectContext.Disposed)
{
SubchannelLog.CancelingConnect(_logger, Id);

_delayInterruptTcs?.TrySetResult(null);
// Cancel connect cancellation token.
_connectContext.CancelConnect();
_connectContext.Dispose();
}

_delayInterruptTcs?.TrySetResult(null);
}

private ConnectContext GetConnectContext()
private ConnectContext GetConnectContextUnsynchronized()
{
lock (Lock)
{
// There shouldn't be a previous connect in progress, but cancel the CTS to ensure they're no longer running.
CancelInProgressConnect();
Debug.Assert(Monitor.IsEntered(Lock));

var connectContext = _connectContext = new ConnectContext(_transport.ConnectTimeout ?? Timeout.InfiniteTimeSpan);
return connectContext;
}
// There shouldn't be a previous connect in progress, but cancel the CTS to ensure they're no longer running.
CancelInProgressConnectUnsynchronized();

var connectContext = _connectContext = new ConnectContext(_transport.ConnectTimeout ?? Timeout.InfiniteTimeSpan);
return connectContext;
}

private async Task ConnectTransportAsync()
{
var connectContext = GetConnectContext();
ConnectContext connectContext;
Task? waitSemaporeTask = null;
lock (Lock)
{
// Don't start connecting if the subchannel has been shut down. The transport and semaphore are disposed on shutdown.
if (_state == ConnectivityState.Shutdown)
{
return;
}

connectContext = GetConnectContextUnsynchronized();

// Use a semaphore to limit the subchannel to one connection attempt at a time. This prevents a race condition where a canceled connect
// overwrites the status of a successful connect.
//
// Try to get the semaphore without waiting. If the semaphore is already taken then start a task to wait for it to be released.
// Start this inside the lock to make sure the subchannel isn't shut down before waiting for the semaphore.
if (!_connectSemaphore.Wait(0))
{
SubchannelLog.QueuingConnect(_logger, Id);
waitSemaporeTask = _connectSemaphore.WaitAsync(connectContext.CancellationToken);
}
}

var backoffPolicy = _manager.BackoffPolicyFactory.Create();
if (waitSemaporeTask != null)
{
try
{
await waitSemaporeTask.ConfigureAwait(false);
}
catch (OperationCanceledException)
{
// Canceled while waiting for semaphore.
return;
}
}

try
{
var backoffPolicy = _manager.BackoffPolicyFactory.Create();

SubchannelLog.ConnectingTransport(_logger, Id);

for (var attempt = 0; ; attempt++)
Expand Down Expand Up @@ -384,6 +422,13 @@ private async Task ConnectTransportAsync()
// Dispose context because it might have been created with a connect timeout.
// Want to clean up the connect timeout timer.
connectContext.Dispose();

// Subchannel could have been disposed while connect is running.
// If subchannel is shutting down then don't release semaphore to avoid ObjectDisposedException.
if (_state != ConnectivityState.Shutdown)
{
_connectSemaphore.Release();
}
}
}
}
Expand Down Expand Up @@ -482,8 +527,12 @@ public void Dispose()
}
_stateChangedRegistrations.Clear();

CancelInProgressConnect();
_transport.Dispose();
lock (Lock)
{
CancelInProgressConnectUnsynchronized();
_transport.Dispose();
_connectSemaphore.Dispose();
}
}
}

Expand All @@ -505,7 +554,7 @@ internal static class SubchannelLog
LoggerMessage.Define<string, ConnectivityState>(LogLevel.Debug, new EventId(5, "ConnectionRequestedInNonIdleState"), "Subchannel id '{SubchannelId}' connection requested in non-idle state of {State}.");

private static readonly Action<ILogger, string, Exception?> _connectingTransport =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(6, "ConnectingTransport"), "Subchannel id '{SubchannelId}' connecting to transport.");
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(6, "ConnectingTransport"), "Subchannel id '{SubchannelId}' connecting to transport.");

private static readonly Action<ILogger, string, TimeSpan, Exception?> _startingConnectBackoff =
LoggerMessage.Define<string, TimeSpan>(LogLevel.Trace, new EventId(7, "StartingConnectBackoff"), "Subchannel id '{SubchannelId}' starting connect backoff of {BackoffDuration}.");
Expand All @@ -514,7 +563,7 @@ internal static class SubchannelLog
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(8, "ConnectBackoffInterrupted"), "Subchannel id '{SubchannelId}' connect backoff interrupted.");

private static readonly Action<ILogger, string, Exception?> _connectCanceled =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(9, "ConnectCanceled"), "Subchannel id '{SubchannelId}' connect canceled.");
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(9, "ConnectCanceled"), "Subchannel id '{SubchannelId}' connect canceled.");

private static readonly Action<ILogger, string, Exception?> _connectError =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(10, "ConnectError"), "Subchannel id '{SubchannelId}' unexpected error while connecting to transport.");
Expand Down Expand Up @@ -546,6 +595,9 @@ internal static class SubchannelLog
private static readonly Action<ILogger, string, string, Exception?> _addressesUpdated =
LoggerMessage.Define<string, string>(LogLevel.Trace, new EventId(19, "AddressesUpdated"), "Subchannel id '{SubchannelId}' updated with addresses: {Addresses}");

private static readonly Action<ILogger, string, Exception?> _queuingConnect =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(20, "QueuingConnect"), "Subchannel id '{SubchannelId}' queuing connect because a connect is already in progress.");

public static void SubchannelCreated(ILogger logger, string subchannelId, IReadOnlyList<BalancerAddress> addresses)
{
if (logger.IsEnabled(LogLevel.Debug))
Expand Down Expand Up @@ -648,5 +700,10 @@ public static void AddressesUpdated(ILogger logger, string subchannelId, IReadOn
_addressesUpdated(logger, subchannelId, addressesText, null);
}
}

/// <summary>
/// Logs that a connect for the subchannel is being queued because a connect is already in progress.
/// </summary>
/// <param name="logger">Logger the message is written to.</param>
/// <param name="subchannelId">Id of the subchannel, substituted into the log message.</param>
public static void QueuingConnect(ILogger logger, string subchannelId)
    => _queuingConnect(logger, subchannelId, null);
}
#endif
5 changes: 4 additions & 1 deletion test/FunctionalTests/Balancer/ConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,11 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte

var client = TestClientFactory.Create(channel, endpoint.Method);

// Act
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => client.UnaryCall(new HelloRequest()).ResponseAsync).DefaultTimeout();
Assert.AreEqual("A connection could not be established within the configured ConnectTimeout.", ex.Status.DebugException!.Message);

await ExceptionAssert.ThrowsAsync<OperationCanceledException>(() => connectTcs.Task).DefaultTimeout();
}

[Test]
Expand Down Expand Up @@ -167,7 +170,7 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
connectionIdleTimeout: connectionIdleTimeout).DefaultTimeout();

Logger.LogInformation("Connecting channel.");
await channel.ConnectAsync();
await channel.ConnectAsync().DefaultTimeout();

// Wait for timeout plus a little extra to avoid issues from imprecise timers.
await Task.Delay(connectionIdleTimeout + TimeSpan.FromMilliseconds(50));
Expand Down
39 changes: 39 additions & 0 deletions test/FunctionalTests/Balancer/PickFirstBalancerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,45 @@ private GrpcChannel CreateGrpcWebChannel(TestServerEndpointName endpointName, Ve
return channel;
}

// Verifies that a unary call succeeds after an earlier call failed because the connect timed out.
// The first socket connect callback is delayed for longer than the configured connect timeout;
// later connect callbacks go straight to the real socket connect.
[Test]
public async Task UnaryCall_CallAfterConnectionTimeout_Success()
{
// Ignore errors (the filter returns true for every write context).
SetExpectedErrorsFilter(writeContext =>
{
return true;
});

// Written by the server-side handler below; no assertion in this test reads it.
string? host = null;
Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
{
host = context.Host;
// Echo the request name back as the reply message.
return Task.FromResult(new HelloReply { Message = request.Name });
}

// Arrange
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod));

// Counts socket connect callbacks so that only the first one is delayed.
var connectCount = 0;
// Note: the lambda's 'endpoint' parameter shadows the outer 'endpoint' local inside the callback.
var channel = await BalancerHelpers.CreateChannel(LoggerFactory, new PickFirstConfig(), new[] { endpoint.Address }, connectTimeout: TimeSpan.FromMilliseconds(200), socketConnect:
async (socket, endpoint, cancellationToken) =>
{
if (Interlocked.Increment(ref connectCount) == 1)
{
// First attempt only: wait 1000 ms, which is longer than the 200 ms connect timeout.
await Task.Delay(1000, cancellationToken);
}
await socket.ConnectAsync(endpoint, cancellationToken);
}).DefaultTimeout();
var client = TestClientFactory.Create(channel, endpoint.Method);

// Act & Assert
// The first call is expected to fail with Unavailable and a TimeoutException as the inner exception.
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => client.UnaryCall(new HelloRequest { Name = "Balancer" }).ResponseAsync).DefaultTimeout();
Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode);
Assert.IsInstanceOf(typeof(TimeoutException), ex.InnerException);

// The second call is expected to complete without throwing.
await client.UnaryCall(new HelloRequest { Name = "Balancer" }).ResponseAsync.DefaultTimeout();
}

[Test]
public async Task UnaryCall_CallAfterCancellation_Success()
{
Expand Down