From 6ba6bbba555a526383b1c8585e0252390189c438 Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Fri, 9 Jul 2021 10:48:18 -0700 Subject: [PATCH] xds: fix the race condition in SslContextProviderSupplier's updateSslContext and close (#8294) --- .../sds/SslContextProviderSupplier.java | 11 +++-- .../sds/SslContextProviderSupplierTest.java | 47 +++++-------------- 2 files changed, 19 insertions(+), 39 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java index 3902569d873..3300c22b2bf 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java @@ -17,7 +17,6 @@ package io.grpc.xds.internal.sds; import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; @@ -56,13 +55,14 @@ public BaseTlsContext getTlsContext() { public synchronized void updateSslContext(final SslContextProvider.Callback callback) { checkNotNull(callback, "callback"); try { - checkState(!shutdown, "Supplier is shutdown!"); - if (sslContextProvider == null) { - sslContextProvider = getSslContextProvider(); + if (!shutdown) { + if (sslContextProvider == null) { + sslContextProvider = getSslContextProvider(); + } } // we want to increment the ref-count so call findOrCreate again... final SslContextProvider toRelease = getSslContextProvider(); - sslContextProvider.addCallback( + toRelease.addCallback( new SslContextProvider.Callback(callback.getExecutor()) { @Override @@ -115,6 +115,7 @@ public synchronized void close() { tlsContextManager.releaseServerSslContextProvider(sslContextProvider); } } + sslContextProvider = null; shutdown = true; } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java index ec2c85e5b8c..19fd0e189c1 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java @@ -23,16 +23,13 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import com.google.common.util.concurrent.MoreExecutors; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; import java.util.concurrent.Executor; -import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -91,11 +88,11 @@ public void get_updateSecret() { capturedCallback.updateSecret(mockSslContext); verify(mockCallback, times(1)).updateSecret(eq(mockSslContext)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); supplier.updateSslContext(mockCallback); verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); } @Test @@ -106,9 +103,11 @@ public void get_onException() { verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); - capturedCallback.onException(new Exception("test")); + Exception exception = new Exception("test"); + capturedCallback.onException(exception); + verify(mockCallback, times(1)).onException(eq(exception)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider)); } @Test @@ -118,20 +117,11 @@ public void testClose() { supplier.close(); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider)); - SslContextProvider.Callback mockCallback = spy( - new SslContextProvider.Callback(MoreExecutors.directExecutor()) { - @Override - public void updateSecret(SslContext sslContext) { - Assert.fail("unexpected call"); - } - - @Override - protected void onException(Throwable argument) { - assertThat(argument).isInstanceOf(IllegalStateException.class); - assertThat(argument).hasMessageThat().contains("Supplier is shutdown!"); - } - }); supplier.updateSslContext(mockCallback); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(any(SslContextProvider.class)); } @Test @@ -142,19 +132,8 @@ public void testClose_nullSslContextProvider() { supplier.close(); verify(mockTlsContextManager, never()) .releaseClientSslContextProvider(eq(mockSslContextProvider)); - SslContextProvider.Callback mockCallback = spy( - new SslContextProvider.Callback(MoreExecutors.directExecutor()) { - @Override - public void updateSecret(SslContext sslContext) { - Assert.fail("unexpected call"); - } - - @Override - protected void onException(Throwable argument) { - assertThat(argument).isInstanceOf(IllegalStateException.class); - assertThat(argument).hasMessageThat().contains("Supplier is shutdown!"); - } - }); - supplier.updateSslContext(mockCallback); + callUpdateSslContext(); + verify(mockTlsContextManager, times(1)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); } }