diff --git a/src/MongoDB.Driver.Core/Core/Operations/RetryabilityHelper.cs b/src/MongoDB.Driver.Core/Core/Operations/RetryabilityHelper.cs index bc5a683acc7..8e19e3465d4 100644 --- a/src/MongoDB.Driver.Core/Core/Operations/RetryabilityHelper.cs +++ b/src/MongoDB.Driver.Core/Core/Operations/RetryabilityHelper.cs @@ -16,7 +16,9 @@ using System; using System.Collections.Generic; using MongoDB.Bson; +using MongoDB.Driver.Core.Connections; using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Core.Servers; namespace MongoDB.Driver.Core.Operations { @@ -90,9 +92,9 @@ static RetryabilityHelper() } // public static methods - public static void AddRetryableWriteErrorLabelIfRequired(MongoException exception, int maxWireVersion) + public static void AddRetryableWriteErrorLabelIfRequired(MongoException exception, ConnectionDescription connectionDescription) { - if (ShouldRetryableWriteExceptionLabelBeAdded(exception, maxWireVersion)) + if (ShouldRetryableWriteExceptionLabelBeAdded(exception, connectionDescription)) { exception.AddErrorLabel(RetryableWriteErrorLabel); } @@ -172,18 +174,22 @@ private static bool IsNetworkException(Exception exception) return exception is MongoConnectionException mongoConnectionException && mongoConnectionException.IsNetworkException; } - private static bool ShouldRetryableWriteExceptionLabelBeAdded(Exception exception, int maxWireVersion) + private static bool ShouldRetryableWriteExceptionLabelBeAdded(Exception exception, ConnectionDescription connectionDescription) { if (IsNetworkException(exception)) { return true; } + var maxWireVersion = connectionDescription.MaxWireVersion; if (Feature.ServerReturnsRetryableWriteErrorLabel.IsSupported(maxWireVersion)) { return false; } + // on all servers from 4.4 on we would have returned false in the previous if statement + // so from this point on we know we are connected to a pre 4.4 server + if (__retryableWriteExceptions.Contains(exception.GetType())) { return true; @@ -199,29 +205,33 @@ private static bool ShouldRetryableWriteExceptionLabelBeAdded(Exception exceptio } } - var writeConcernException = exception as MongoWriteConcernException; - if (writeConcernException != null) + var serverType = connectionDescription.HelloResult.ServerType; + if (serverType != ServerType.ShardRouter) { - var writeConcernError = writeConcernException.WriteConcernResult.Response.GetValue("writeConcernError", null)?.AsBsonDocument; - if (writeConcernError != null) + var writeConcernException = exception as MongoWriteConcernException; + if (writeConcernException != null) { - var code = (ServerErrorCode)writeConcernError.GetValue("code", -1).AsInt32; - switch (code) + var writeConcernError = writeConcernException.WriteConcernResult.Response.GetValue("writeConcernError", null)?.AsBsonDocument; + if (writeConcernError != null) { - case ServerErrorCode.InterruptedAtShutdown: - case ServerErrorCode.InterruptedDueToReplStateChange: - case ServerErrorCode.LegacyNotPrimary: - case ServerErrorCode.NotWritablePrimary: - case ServerErrorCode.NotPrimaryNoSecondaryOk: - case ServerErrorCode.NotPrimaryOrSecondary: - case ServerErrorCode.PrimarySteppedDown: - case ServerErrorCode.ShutdownInProgress: - case ServerErrorCode.HostNotFound: - case ServerErrorCode.HostUnreachable: - case ServerErrorCode.NetworkTimeout: - case ServerErrorCode.SocketException: - case ServerErrorCode.ExceededTimeLimit: - return true; + var code = (ServerErrorCode)writeConcernError.GetValue("code", -1).AsInt32; + switch (code) + { + case ServerErrorCode.InterruptedAtShutdown: + case ServerErrorCode.InterruptedDueToReplStateChange: + case ServerErrorCode.LegacyNotPrimary: + case ServerErrorCode.NotWritablePrimary: + case ServerErrorCode.NotPrimaryNoSecondaryOk: + case ServerErrorCode.NotPrimaryOrSecondary: + case ServerErrorCode.PrimarySteppedDown: + case ServerErrorCode.ShutdownInProgress: + case ServerErrorCode.HostNotFound: + case ServerErrorCode.HostUnreachable: + case ServerErrorCode.NetworkTimeout: + case ServerErrorCode.SocketException: + case ServerErrorCode.ExceededTimeLimit: + return true; + } } } } diff --git a/src/MongoDB.Driver.Core/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs b/src/MongoDB.Driver.Core/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs index 330533ad427..eb02dc57126 100644 --- a/src/MongoDB.Driver.Core/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs +++ b/src/MongoDB.Driver.Core/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs @@ -147,7 +147,7 @@ public TCommandResult Execute(IConnection connection, CancellationToken cancella } catch (Exception exception) { - AddErrorLabelIfRequired(exception, connection.Description?.MaxWireVersion); + AddErrorLabelIfRequired(exception, connection.Description); TransactionHelper.UnpinServerIfNeededOnCommandException(_session, exception); throw; @@ -201,7 +201,7 @@ public async Task ExecuteAsync(IConnection connection, Cancellat } catch (Exception exception) { - AddErrorLabelIfRequired(exception, connection.Description?.MaxWireVersion); + AddErrorLabelIfRequired(exception, connection.Description); TransactionHelper.UnpinServerIfNeededOnCommandException(_session, exception); throw; @@ -209,7 +209,7 @@ public async Task ExecuteAsync(IConnection connection, Cancellat } // private methods - private void AddErrorLabelIfRequired(Exception exception, int? maxWireVersion) + private void AddErrorLabelIfRequired(Exception exception, ConnectionDescription connectionDescription) { if (exception is MongoException mongoException) { @@ -218,9 +218,9 @@ private void AddErrorLabelIfRequired(Exception exception, int? maxWireVersion) mongoException.AddErrorLabel("TransientTransactionError"); } - if (RetryabilityHelper.IsCommandRetryable(_command) && maxWireVersion.HasValue) + if (RetryabilityHelper.IsCommandRetryable(_command) && connectionDescription != null) { - RetryabilityHelper.AddRetryableWriteErrorLabelIfRequired(mongoException, maxWireVersion.Value); + RetryabilityHelper.AddRetryableWriteErrorLabelIfRequired(mongoException, connectionDescription); } } } diff --git a/tests/MongoDB.Driver.Core.Tests/Core/Operations/RetryabilityHelperTests.cs b/tests/MongoDB.Driver.Core.Tests/Core/Operations/RetryabilityHelperTests.cs index 779c694e354..f4e05118945 100644 --- a/tests/MongoDB.Driver.Core.Tests/Core/Operations/RetryabilityHelperTests.cs +++ b/tests/MongoDB.Driver.Core.Tests/Core/Operations/RetryabilityHelperTests.cs @@ -21,6 +21,7 @@ using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Core.TestHelpers; using Xunit; +using MongoDB.Driver.Core.Connections; namespace MongoDB.Driver.Core.Operations { @@ -45,8 +46,10 @@ public class RetryabilityHelperTests public void AddRetryableWriteErrorLabelIfRequired_should_add_RetryableWriteError_for_MongoWriteConcernException_when_required(int errorCode, bool shouldAddErrorLabel) { var exception = CoreExceptionHelper.CreateMongoWriteConcernException(BsonDocument.Parse($"{{ writeConcernError : {{ code : {errorCode} }} }}")); + var maxWireVersion = Feature.ServerReturnsRetryableWriteErrorLabel.LastNotSupportedWireVersion; + var connectionDescription = OperationTestHelper.CreateConnectionDescription(maxWireVersion); - RetryabilityHelper.AddRetryableWriteErrorLabelIfRequired(exception, Feature.ServerReturnsRetryableWriteErrorLabel.LastNotSupportedWireVersion); + RetryabilityHelper.AddRetryableWriteErrorLabelIfRequired(exception, connectionDescription); var hasRetryableWriteErrorLabel = exception.HasErrorLabel("RetryableWriteError"); hasRetryableWriteErrorLabel.Should().Be(shouldAddErrorLabel); @@ -59,8 +62,9 @@ public void AddRetryableWriteErrorLabelIfRequired_should_add_RetryableWriteError var exception = (MongoException)CoreExceptionHelper.CreateException(typeof(MongoConnectionException)); var feature = Feature.ServerReturnsRetryableWriteErrorLabel; var wireVersion = serverReturnsRetryableWriteErrorLabel ? feature.FirstSupportedWireVersion : feature.LastNotSupportedWireVersion; + var connectionDescription = OperationTestHelper.CreateConnectionDescription(wireVersion); - RetryabilityHelper.AddRetryableWriteErrorLabelIfRequired(exception, wireVersion); + RetryabilityHelper.AddRetryableWriteErrorLabelIfRequired(exception, connectionDescription); var hasRetryableWriteErrorLabel = exception.HasErrorLabel("RetryableWriteError"); hasRetryableWriteErrorLabel.Should().BeTrue(); @@ -89,8 +93,10 @@ public void AddRetryableWriteErrorLabelIfRequired_should_add_RetryableWriteError { exception = CoreExceptionHelper.CreateMongoCommandException((int)exceptionDescription); } + var maxWireVersion = Feature.ServerReturnsRetryableWriteErrorLabel.LastNotSupportedWireVersion; + var connectionDescription = OperationTestHelper.CreateConnectionDescription(maxWireVersion); - RetryabilityHelper.AddRetryableWriteErrorLabelIfRequired(exception, Feature.ServerReturnsRetryableWriteErrorLabel.LastNotSupportedWireVersion); + RetryabilityHelper.AddRetryableWriteErrorLabelIfRequired(exception, connectionDescription); var hasRetryableWriteErrorLabel = exception.HasErrorLabel("RetryableWriteError"); hasRetryableWriteErrorLabel.Should().Be(shouldAddErrorLabel);