diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java index 5c8cb78e20a..c9dc443495f 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java @@ -40,6 +40,9 @@ import io.grpc.xds.XdsClient.ClusterWatcher; import io.grpc.xds.XdsLoadBalancerProvider.XdsConfig; import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; +import io.grpc.xds.sds.SslContextProvider; +import io.grpc.xds.sds.TlsContextManager; +import io.grpc.xds.sds.TlsContextManagerImpl; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -55,6 +58,7 @@ public final class CdsLoadBalancer extends LoadBalancer { private final LoadBalancerRegistry lbRegistry; private final GracefulSwitchLoadBalancer switchingLoadBalancer; private final Helper helper; + private final TlsContextManager tlsContextManager; // The following fields become non-null once handleResolvedAddresses() successfully. @@ -70,15 +74,17 @@ public final class CdsLoadBalancer extends LoadBalancer { private XdsClient xdsClient; CdsLoadBalancer(Helper helper) { - this(helper, LoadBalancerRegistry.getDefaultRegistry()); + this(helper, LoadBalancerRegistry.getDefaultRegistry(), TlsContextManagerImpl.getInstance()); } @VisibleForTesting - CdsLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry) { + CdsLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry, + TlsContextManager tlsContextManager) { this.helper = helper; this.channelLogger = helper.getChannelLogger(); this.lbRegistry = lbRegistry; this.switchingLoadBalancer = new GracefulSwitchLoadBalancer(helper); + this.tlsContextManager = tlsContextManager; } @Override @@ -236,22 +242,23 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { private static final class EdsLoadBalancingHelper extends ForwardingLoadBalancerHelper { private final Helper delegate; - private final AtomicReference upstreamTlsContext; + private final AtomicReference> sslContextProvider; - EdsLoadBalancingHelper(Helper helper, AtomicReference upstreamTlsContext) { + EdsLoadBalancingHelper(Helper helper, + AtomicReference> sslContextProvider) { this.delegate = helper; - this.upstreamTlsContext = upstreamTlsContext; + this.sslContextProvider = sslContextProvider; } @Override public Subchannel createSubchannel(CreateSubchannelArgs createSubchannelArgs) { - if (upstreamTlsContext.get() != null) { + if (sslContextProvider.get() != null) { createSubchannelArgs = createSubchannelArgs .toBuilder() .setAddresses( addUpstreamTlsContext(createSubchannelArgs.getAddresses(), - upstreamTlsContext.get())) + sslContextProvider.get().getSource())) .build(); } return delegate.createSubchannel(createSubchannelArgs); @@ -295,7 +302,8 @@ private final class ClusterWatcherImpl implements ClusterWatcher { LoadBalancer edsBalancer; ClusterWatcherImpl(Helper helper, ResolvedAddresses resolvedAddresses) { - this.helper = new EdsLoadBalancingHelper(helper, new AtomicReference()); + this.helper = new EdsLoadBalancingHelper(helper, + new AtomicReference>()); this.resolvedAddresses = resolvedAddresses; } @@ -312,8 +320,7 @@ public void onClusterChanged(ClusterUpdate newUpdate) { /* fallbackPolicy = */ null, /* edsServiceName = */ newUpdate.getEdsServiceName(), /* lrsServerName = */ newUpdate.getLrsServerName()); - UpstreamTlsContext upstreamTlsContext = newUpdate.getUpstreamTlsContext(); - helper.upstreamTlsContext.set(upstreamTlsContext); + updateSslContextProvider(newUpdate.getUpstreamTlsContext()); if (edsBalancer == null) { edsBalancer = lbRegistry.getProvider(XDS_POLICY_NAME).newLoadBalancer(helper); } @@ -327,6 +334,27 @@ public void onClusterChanged(ClusterUpdate newUpdate) { .build()); } + /** For new UpstreamTlsContext value, release old SslContextProvider. */ + private void updateSslContextProvider(UpstreamTlsContext newUpstreamTlsContext) { + SslContextProvider oldSslContextProvider = + helper.sslContextProvider.get(); + if (oldSslContextProvider != null) { + UpstreamTlsContext oldUpstreamTlsContext = oldSslContextProvider.getSource(); + + if (oldUpstreamTlsContext.equals(newUpstreamTlsContext)) { + return; + } + tlsContextManager.releaseClientSslContextProvider(oldSslContextProvider); + } + if (newUpstreamTlsContext != null) { + SslContextProvider newSslContextProvider = + tlsContextManager.findOrCreateClientSslContextProvider(newUpstreamTlsContext); + helper.sslContextProvider.set(newSslContextProvider); + } else { + helper.sslContextProvider.set(null); + } + } + @Override public void onError(Status error) { channelLogger.log(ChannelLogLevel.ERROR, "CDS load balancer received an error: {0}", error); diff --git a/xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java index d3e63686d28..a00afbd9dfb 100644 --- a/xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java @@ -55,7 +55,7 @@ protected SslContextProvider(K source, boolean server) { this.server = server; } - K getSource() { + public K getSource() { return source; } diff --git a/xds/src/main/java/io/grpc/xds/sds/TlsContextManager.java b/xds/src/main/java/io/grpc/xds/sds/TlsContextManager.java index ee3a6c6f108..eeabc91a689 100644 --- a/xds/src/main/java/io/grpc/xds/sds/TlsContextManager.java +++ b/xds/src/main/java/io/grpc/xds/sds/TlsContextManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 The gRPC Authors + * Copyright 2020 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,63 +16,20 @@ package io.grpc.xds.sds; -import static com.google.common.base.Preconditions.checkNotNull; - -import com.google.common.annotations.VisibleForTesting; import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; import io.grpc.Internal; -import io.grpc.xds.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory; -/** - * Class to manage {@link SslContextProvider} objects created from inputs we get from xDS. Used by - * gRPC-xds to access the SslContext's and is not public API. This manager manages the life-cycle of - * {@link SslContextProvider} objects as shared resources via ref-counting as described in {@link - * ReferenceCountingSslContextProviderMap}. - */ @Internal -public final class TlsContextManager { - - private static TlsContextManager instance; - - private final ReferenceCountingSslContextProviderMap mapForClients; - private final ReferenceCountingSslContextProviderMap mapForServers; - - private TlsContextManager() { - this(new ClientSslContextProviderFactory(), new ServerSslContextProviderFactory()); - } - - @VisibleForTesting - TlsContextManager( - SslContextProviderFactory clientFactory, - SslContextProviderFactory serverFactory) { - checkNotNull(clientFactory, "clientFactory"); - checkNotNull(serverFactory, "serverFactory"); - mapForClients = new ReferenceCountingSslContextProviderMap<>(clientFactory); - mapForServers = new ReferenceCountingSslContextProviderMap<>(serverFactory); - } - - /** Gets the TlsContextManager singleton. */ - public static synchronized TlsContextManager getInstance() { - if (instance == null) { - instance = new TlsContextManager(); - } - return instance; - } +public interface TlsContextManager { /** Creates a SslContextProvider. Used for retrieving a server-side SslContext. */ - public SslContextProvider findOrCreateServerSslContextProvider( - DownstreamTlsContext downstreamTlsContext) { - checkNotNull(downstreamTlsContext, "downstreamTlsContext"); - return mapForServers.get(downstreamTlsContext); - } + SslContextProvider findOrCreateServerSslContextProvider( + DownstreamTlsContext downstreamTlsContext); /** Creates a SslContextProvider. Used for retrieving a client-side SslContext. */ - public SslContextProvider findOrCreateClientSslContextProvider( - UpstreamTlsContext upstreamTlsContext) { - checkNotNull(upstreamTlsContext, "upstreamTlsContext"); - return mapForClients.get(upstreamTlsContext); - } + SslContextProvider findOrCreateClientSslContextProvider( + UpstreamTlsContext upstreamTlsContext); /** * Releases an instance of the given client-side {@link SslContextProvider}. @@ -83,11 +40,8 @@ public SslContextProvider findOrCreateClientSslContextProvid *

Caller must not release a reference more than once. It's advised that you clear the * reference to the instance with the null returned by this method. */ - public SslContextProvider releaseClientSslContextProvider( - SslContextProvider sslContextProvider) { - checkNotNull(sslContextProvider, "sslContextProvider"); - return mapForClients.release(sslContextProvider); - } + SslContextProvider releaseClientSslContextProvider( + SslContextProvider sslContextProvider); /** * Releases an instance of the given server-side {@link SslContextProvider}. @@ -98,9 +52,6 @@ public SslContextProvider releaseClientSslContextProvider( *

Caller must not release a reference more than once. It's advised that you clear the * reference to the instance with the null returned by this method. */ - public SslContextProvider releaseServerSslContextProvider( - SslContextProvider sslContextProvider) { - checkNotNull(sslContextProvider, "sslContextProvider"); - return mapForServers.release(sslContextProvider); - } + SslContextProvider releaseServerSslContextProvider( + SslContextProvider sslContextProvider); } diff --git a/xds/src/main/java/io/grpc/xds/sds/TlsContextManagerImpl.java b/xds/src/main/java/io/grpc/xds/sds/TlsContextManagerImpl.java new file mode 100644 index 00000000000..934172d0529 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/sds/TlsContextManagerImpl.java @@ -0,0 +1,90 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.sds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext; +import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; +import io.grpc.Internal; +import io.grpc.xds.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory; + +/** + * Class to manage {@link SslContextProvider} objects created from inputs we get from xDS. Used by + * gRPC-xds to access the SslContext's and is not public API. This manager manages the life-cycle of + * {@link SslContextProvider} objects as shared resources via ref-counting as described in {@link + * ReferenceCountingSslContextProviderMap}. + */ +@Internal +public final class TlsContextManagerImpl implements TlsContextManager { + + private static TlsContextManagerImpl instance; + + private final ReferenceCountingSslContextProviderMap mapForClients; + private final ReferenceCountingSslContextProviderMap mapForServers; + + private TlsContextManagerImpl() { + this(new ClientSslContextProviderFactory(), new ServerSslContextProviderFactory()); + } + + @VisibleForTesting + TlsContextManagerImpl( + SslContextProviderFactory clientFactory, + SslContextProviderFactory serverFactory) { + checkNotNull(clientFactory, "clientFactory"); + checkNotNull(serverFactory, "serverFactory"); + mapForClients = new ReferenceCountingSslContextProviderMap<>(clientFactory); + mapForServers = new ReferenceCountingSslContextProviderMap<>(serverFactory); + } + + /** Gets the TlsContextManagerImpl singleton. */ + public static synchronized TlsContextManagerImpl getInstance() { + if (instance == null) { + instance = new TlsContextManagerImpl(); + } + return instance; + } + + @Override + public SslContextProvider findOrCreateServerSslContextProvider( + DownstreamTlsContext downstreamTlsContext) { + checkNotNull(downstreamTlsContext, "downstreamTlsContext"); + return mapForServers.get(downstreamTlsContext); + } + + @Override + public SslContextProvider findOrCreateClientSslContextProvider( + UpstreamTlsContext upstreamTlsContext) { + checkNotNull(upstreamTlsContext, "upstreamTlsContext"); + return mapForClients.get(upstreamTlsContext); + } + + @Override + public SslContextProvider releaseClientSslContextProvider( + SslContextProvider sslContextProvider) { + checkNotNull(sslContextProvider, "sslContextProvider"); + return mapForClients.release(sslContextProvider); + } + + @Override + public SslContextProvider releaseServerSslContextProvider( + SslContextProvider sslContextProvider) { + checkNotNull(sslContextProvider, "sslContextProvider"); + return mapForServers.release(sslContextProvider); + } +} diff --git a/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java index 7cf4e937efc..9419591c80a 100644 --- a/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java @@ -31,7 +31,7 @@ import io.grpc.netty.NettyChannelBuilder; import io.grpc.xds.XdsAttributes; import io.grpc.xds.sds.SslContextProvider; -import io.grpc.xds.sds.TlsContextManager; +import io.grpc.xds.sds.TlsContextManagerImpl; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; @@ -207,8 +207,9 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { final BufferReadsHandler bufferReads = new BufferReadsHandler(); ctx.pipeline().addBefore(ctx.name(), null, bufferReads); - SslContextProvider sslContextProvider = - TlsContextManager.getInstance().findOrCreateClientSslContextProvider(upstreamTlsContext); + final SslContextProvider sslContextProvider = + TlsContextManagerImpl.getInstance() + .findOrCreateClientSslContextProvider(upstreamTlsContext); sslContextProvider.addCallback( new SslContextProvider.Callback() { @@ -226,6 +227,8 @@ public void updateSecret(SslContext sslContext) { ctx.pipeline().addAfter(ctx.name(), null, handler); fireProtocolNegotiationEvent(ctx); ctx.pipeline().remove(bufferReads); + TlsContextManagerImpl.getInstance() + .releaseClientSslContextProvider(sslContextProvider); } @Override @@ -303,8 +306,8 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { final BufferReadsHandler bufferReads = new BufferReadsHandler(); ctx.pipeline().addBefore(ctx.name(), null, bufferReads); - SslContextProvider sslContextProvider = - TlsContextManager.getInstance() + final SslContextProvider sslContextProvider = + TlsContextManagerImpl.getInstance() .findOrCreateServerSslContextProvider(downstreamTlsContext); sslContextProvider.addCallback( @@ -319,6 +322,8 @@ public void updateSecret(SslContext sslContext) { ctx.pipeline().addAfter(ctx.name(), null, handler); fireProtocolNegotiationEvent(ctx); ctx.pipeline().remove(bufferReads); + TlsContextManagerImpl.getInstance() + .releaseServerSslContextProvider(sslContextProvider); } @Override diff --git a/xds/src/main/java/io/grpc/xds/sds/trust/SdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/sds/trust/SdsTrustManagerFactory.java index e3a2274b269..3f2b9941122 100644 --- a/xds/src/main/java/io/grpc/xds/sds/trust/SdsTrustManagerFactory.java +++ b/xds/src/main/java/io/grpc/xds/sds/trust/SdsTrustManagerFactory.java @@ -41,8 +41,8 @@ import javax.net.ssl.X509ExtendedTrustManager; /** - * Factory class used by providers of {@link io.grpc.xds.sds.TlsContextManager} to provide a {@link - * SdsX509TrustManager} for trust and SAN checks. + * Factory class used by providers of {@link io.grpc.xds.sds.TlsContextManagerImpl} to provide a + * {@link SdsX509TrustManager} for trust and SAN checks. */ @Internal public final class SdsTrustManagerFactory extends SimpleTrustManagerFactory { diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java index 52931129262..52cfe0a9720 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java @@ -23,23 +23,24 @@ import static io.grpc.xds.XdsLoadBalancerProvider.XDS_POLICY_NAME; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext; -import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig; import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext; import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.SubchannelPicker; @@ -57,6 +58,9 @@ import io.grpc.xds.XdsClient.RefCountedXdsClientObjectPool; import io.grpc.xds.XdsClient.XdsClientFactory; import io.grpc.xds.XdsLoadBalancerProvider.XdsConfig; +import io.grpc.xds.sds.SecretVolumeSslContextProviderTest; +import io.grpc.xds.sds.SslContextProvider; +import io.grpc.xds.sds.TlsContextManager; import java.net.InetSocketAddress; import java.util.ArrayDeque; import java.util.ArrayList; @@ -78,6 +82,11 @@ // TODO(creamsoup) use parsed service config @SuppressWarnings("deprecation") public class CdsLoadBalancerTest { + private static final String CLIENT_PEM_FILE = "client.pem"; + private static final String CLIENT_KEY_FILE = "client.key"; + private static final String BADCLIENT_PEM_FILE = "badclient.pem"; + private static final String BADCLIENT_KEY_FILE = "badclient.key"; + private static final String CA_PEM_FILE = "ca.pem"; private final RefCountedXdsClientObjectPool xdsClientPool = new RefCountedXdsClientObjectPool( new XdsClientFactory() { @@ -135,6 +144,8 @@ public void uncaughtException(Thread t, Throwable e) { private LoadBalancer cdsLoadBalancer; private XdsClient xdsClient; + @Mock + private TlsContextManager mockTlsContextManager; @Before public void setUp() { @@ -144,7 +155,7 @@ public void setUp() { doReturn(syncContext).when(helper).getSynchronizationContext(); doReturn(fakeClock.getScheduledExecutorService()).when(helper).getScheduledExecutorService(); lbRegistry.register(fakeXdsLoadBlancerProvider); - cdsLoadBalancer = new CdsLoadBalancer(helper, lbRegistry); + cdsLoadBalancer = new CdsLoadBalancer(helper, lbRegistry, mockTlsContextManager); } @Test @@ -331,6 +342,7 @@ public void handleCdsConfigs() throws Exception { } @Test + @SuppressWarnings({"unchecked"}) public void handleCdsConfigs_withUpstreamTlsContext() throws Exception { assertThat(xdsClient).isNull(); @@ -352,14 +364,14 @@ public void handleCdsConfigs_withUpstreamTlsContext() throws Exception { verify(xdsClient).watchClusterData(eq("foo.googleapis.com"), clusterWatcherCaptor1.capture()); UpstreamTlsContext upstreamTlsContext = - UpstreamTlsContext.newBuilder() - .setCommonTlsContext( - CommonTlsContext.newBuilder() - .addTlsCertificateSdsSecretConfigs( - SdsSecretConfig.newBuilder().setName("cert-sds-name")) - .setValidationContextSdsSecretConfig( - SdsSecretConfig.newBuilder().setName("valid-sds-name"))) - .build(); + SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames( + CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); + + SslContextProvider mockSslContextProvider = + (SslContextProvider) mock(SslContextProvider.class); + doReturn(upstreamTlsContext).when(mockSslContextProvider).getSource(); + doReturn(mockSslContextProvider).when(mockTlsContextManager) + .findOrCreateClientSslContextProvider(same(upstreamTlsContext)); ClusterWatcher clusterWatcher1 = clusterWatcherCaptor1.getValue(); clusterWatcher1.onClusterChanged( @@ -373,6 +385,8 @@ public void handleCdsConfigs_withUpstreamTlsContext() throws Exception { assertThat(edsLbHelpers).hasSize(1); assertThat(edsLoadBalancers).hasSize(1); + verify(mockTlsContextManager, never()).releaseClientSslContextProvider( + (SslContextProvider) any(SslContextProvider.class)); Helper edsLbHelper1 = edsLbHelpers.poll(); ArrayList eagList = new ArrayList<>(); @@ -388,8 +402,86 @@ public void handleCdsConfigs_withUpstreamTlsContext() throws Exception { verify(helper, never()) .createSubchannel(any(LoadBalancer.CreateSubchannelArgs.class)); edsLbHelper1.createSubchannel(createSubchannelArgs); + verifyUpstreamTlsContextAttribute(upstreamTlsContext, + createSubchannelArgsCaptor1); + + // update with same upstreamTlsContext + reset(mockTlsContextManager); + clusterWatcher1.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("bar.googleapis.com") + .setEdsServiceName("eds1ServiceFoo.googleapis.com") + .setLbPolicy("round_robin") + .setEnableLrs(false) + .setUpstreamTlsContext(upstreamTlsContext) + .build()); + + verify(mockTlsContextManager, never()).releaseClientSslContextProvider( + (SslContextProvider) any(SslContextProvider.class)); + verify(mockTlsContextManager, never()).findOrCreateClientSslContextProvider( + any(UpstreamTlsContext.class)); + + // update with different upstreamTlsContext + reset(mockTlsContextManager); + reset(helper); + UpstreamTlsContext upstreamTlsContext1 = + SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames( + BADCLIENT_KEY_FILE, BADCLIENT_PEM_FILE, CA_PEM_FILE); + SslContextProvider mockSslContextProvider1 = + (SslContextProvider) mock(SslContextProvider.class); + doReturn(upstreamTlsContext1).when(mockSslContextProvider1).getSource(); + doReturn(mockSslContextProvider1).when(mockTlsContextManager) + .findOrCreateClientSslContextProvider(same(upstreamTlsContext1)); + clusterWatcher1.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("bar.googleapis.com") + .setEdsServiceName("eds1ServiceFoo.googleapis.com") + .setLbPolicy("round_robin") + .setEnableLrs(false) + .setUpstreamTlsContext(upstreamTlsContext1) + .build()); + + verify(mockTlsContextManager).releaseClientSslContextProvider(same(mockSslContextProvider)); + verify(mockTlsContextManager).findOrCreateClientSslContextProvider(same(upstreamTlsContext1)); + ArgumentCaptor createSubchannelArgsCaptor2 = + ArgumentCaptor.forClass(null); + edsLbHelper1.createSubchannel(createSubchannelArgs); + verifyUpstreamTlsContextAttribute(upstreamTlsContext1, + createSubchannelArgsCaptor2); + + // update with null + reset(mockTlsContextManager); + reset(helper); + clusterWatcher1.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("bar.googleapis.com") + .setEdsServiceName("eds1ServiceFoo.googleapis.com") + .setLbPolicy("round_robin") + .setEnableLrs(false) + .setUpstreamTlsContext(null) + .build()); + verify(mockTlsContextManager).releaseClientSslContextProvider(same(mockSslContextProvider1)); + verify(mockTlsContextManager, never()).findOrCreateClientSslContextProvider( + any(UpstreamTlsContext.class)); + ArgumentCaptor createSubchannelArgsCaptor3 = + ArgumentCaptor.forClass(null); + edsLbHelper1.createSubchannel(createSubchannelArgs); + verifyUpstreamTlsContextAttribute(null, + createSubchannelArgsCaptor3); + + LoadBalancer edsLoadBalancer1 = edsLoadBalancers.poll(); + + cdsLoadBalancer.shutdown(); + verify(edsLoadBalancer1).shutdown(); + verify(xdsClient).cancelClusterDataWatch("foo.googleapis.com", clusterWatcher1); + assertThat(xdsClientPool.xdsClient).isNull(); + } + + private void verifyUpstreamTlsContextAttribute( + UpstreamTlsContext upstreamTlsContext, + ArgumentCaptor createSubchannelArgsCaptor1) { verify(helper, times(1)).createSubchannel(createSubchannelArgsCaptor1.capture()); - LoadBalancer.CreateSubchannelArgs capturedValue = createSubchannelArgsCaptor1.getValue(); + CreateSubchannelArgs capturedValue = createSubchannelArgsCaptor1.getValue(); List capturedEagList = capturedValue.getAddresses(); assertThat(capturedEagList.size()).isEqualTo(2); EquivalentAddressGroup capturedEag = capturedEagList.get(0); @@ -402,13 +494,6 @@ public void handleCdsConfigs_withUpstreamTlsContext() throws Exception { assertThat(capturedUpstreamTlsContext).isSameInstanceAs(upstreamTlsContext); assertThat(capturedEag.getAttributes().get(XdsAttributes.XDS_CLIENT_POOL)) .isSameInstanceAs(xdsClientPool); - - LoadBalancer edsLoadBalancer1 = edsLoadBalancers.poll(); - - cdsLoadBalancer.shutdown(); - verify(edsLoadBalancer1).shutdown(); - verify(xdsClient).cancelClusterDataWatch("foo.googleapis.com", clusterWatcher1); - assertThat(xdsClientPool.xdsClient).isNull(); } @Test diff --git a/xds/src/test/java/io/grpc/xds/sds/SecretVolumeSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/sds/SecretVolumeSslContextProviderTest.java index 826f05a954e..c2df8e62d4a 100644 --- a/xds/src/test/java/io/grpc/xds/sds/SecretVolumeSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/sds/SecretVolumeSslContextProviderTest.java @@ -426,7 +426,7 @@ static DownstreamTlsContext buildDownstreamTlsContextFromFilenames( /** * Helper method to build UpstreamTlsContext for above tests. Called from other classes as well. */ - static UpstreamTlsContext buildUpstreamTlsContextFromFilenames( + public static UpstreamTlsContext buildUpstreamTlsContextFromFilenames( String privateKey, String certChain, String trustCa) { return buildUpstreamTlsContext( buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa)); diff --git a/xds/src/test/java/io/grpc/xds/sds/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/sds/TlsContextManagerTest.java index c6f190cdceb..d736f77c956 100644 --- a/xds/src/test/java/io/grpc/xds/sds/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/sds/TlsContextManagerTest.java @@ -36,7 +36,7 @@ import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; -/** Unit tests for {@link TlsContextManager}. */ +/** Unit tests for {@link TlsContextManagerImpl}. */ @RunWith(JUnit4.class) public class TlsContextManagerTest { @@ -58,7 +58,7 @@ public class TlsContextManagerTest { @Before public void clearInstance() throws NoSuchFieldException, IllegalAccessException { - Field field = TlsContextManager.class.getDeclaredField("instance"); + Field field = TlsContextManagerImpl.class.getDeclaredField("instance"); field.setAccessible(true); field.set(null, null); } @@ -69,13 +69,13 @@ public void createServerSslContextProvider() { SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames( SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); - TlsContextManager tlsContextManager = TlsContextManager.getInstance(); + TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); SslContextProvider serverSecretProvider = - tlsContextManager.findOrCreateServerSslContextProvider(downstreamTlsContext); + tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); assertThat(serverSecretProvider).isNotNull(); SslContextProvider serverSecretProvider1 = - tlsContextManager.findOrCreateServerSslContextProvider(downstreamTlsContext); + tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); assertThat(serverSecretProvider1).isSameInstanceAs(serverSecretProvider); } @@ -85,13 +85,13 @@ public void createClientSslContextProvider() { SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames( /* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); - TlsContextManager tlsContextManager = TlsContextManager.getInstance(); + TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); SslContextProvider clientSecretProvider = - tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); assertThat(clientSecretProvider).isNotNull(); SslContextProvider clientSecretProvider1 = - tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); assertThat(clientSecretProvider1).isSameInstanceAs(clientSecretProvider); } @@ -101,16 +101,16 @@ public void createServerSslContextProvider_differentInstance() { SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames( SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); - TlsContextManager tlsContextManager = TlsContextManager.getInstance(); + TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); SslContextProvider serverSecretProvider = - tlsContextManager.findOrCreateServerSslContextProvider(downstreamTlsContext); + tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); assertThat(serverSecretProvider).isNotNull(); DownstreamTlsContext downstreamTlsContext1 = SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames( SERVER_0_KEY_FILE, SERVER_0_PEM_FILE, CA_PEM_FILE); SslContextProvider serverSecretProvider1 = - tlsContextManager.findOrCreateServerSslContextProvider(downstreamTlsContext1); + tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext1); assertThat(serverSecretProvider1).isNotNull(); assertThat(serverSecretProvider1).isNotSameInstanceAs(serverSecretProvider); } @@ -121,9 +121,9 @@ public void createClientSslContextProvider_differentInstance() { SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames( /* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); - TlsContextManager tlsContextManager = TlsContextManager.getInstance(); + TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); SslContextProvider clientSecretProvider = - tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); assertThat(clientSecretProvider).isNotNull(); UpstreamTlsContext upstreamTlsContext1 = @@ -131,7 +131,7 @@ public void createClientSslContextProvider_differentInstance() { CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); SslContextProvider clientSecretProvider1 = - tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext1); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext1); assertThat(clientSecretProvider1).isNotSameInstanceAs(clientSecretProvider); } @@ -141,17 +141,17 @@ public void createServerSslContextProvider_releaseInstance() { SecretVolumeSslContextProviderTest.buildDownstreamTlsContextFromFilenames( SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); - TlsContextManager tlsContextManager = - new TlsContextManager(mockClientFactory, mockServerFactory); + TlsContextManagerImpl tlsContextManagerImpl = + new TlsContextManagerImpl(mockClientFactory, mockServerFactory); @SuppressWarnings("unchecked") SslContextProvider mockProvider = mock(SslContextProvider.class); when(mockServerFactory.createSslContextProvider(downstreamTlsContext)).thenReturn(mockProvider); SslContextProvider serverSecretProvider = - tlsContextManager.findOrCreateServerSslContextProvider(downstreamTlsContext); + tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); assertThat(serverSecretProvider).isSameInstanceAs(mockProvider); verify(mockProvider, never()).close(); when(mockProvider.getSource()).thenReturn(downstreamTlsContext); - tlsContextManager.releaseServerSslContextProvider(mockProvider); + tlsContextManagerImpl.releaseServerSslContextProvider(mockProvider); verify(mockProvider, times(1)).close(); } @@ -161,17 +161,17 @@ public void createClientSslContextProvider_releaseInstance() { SecretVolumeSslContextProviderTest.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); - TlsContextManager tlsContextManager = - new TlsContextManager(mockClientFactory, mockServerFactory); + TlsContextManagerImpl tlsContextManagerImpl = + new TlsContextManagerImpl(mockClientFactory, mockServerFactory); @SuppressWarnings("unchecked") SslContextProvider mockProvider = mock(SslContextProvider.class); when(mockClientFactory.createSslContextProvider(upstreamTlsContext)).thenReturn(mockProvider); SslContextProvider clientSecretProvider = - tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); assertThat(clientSecretProvider).isSameInstanceAs(mockProvider); verify(mockProvider, never()).close(); when(mockProvider.getSource()).thenReturn(upstreamTlsContext); - tlsContextManager.releaseClientSslContextProvider(mockProvider); + tlsContextManagerImpl.releaseClientSslContextProvider(mockProvider); verify(mockProvider, times(1)).close(); } }