Skip to content

Commit

Permalink
xds: support for updating upstreamTlsContext and SslContextProvider, …
Browse files Browse the repository at this point in the history
…also release object in SdsProtocolNegotiators (grpc#6599)
  • Loading branch information
sanjaypujare committed Jan 14, 2020
1 parent 066e72d commit 04cf90a
Show file tree
Hide file tree
Showing 9 changed files with 278 additions and 119 deletions.
48 changes: 38 additions & 10 deletions xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java
Expand Up @@ -40,6 +40,9 @@
import io.grpc.xds.XdsClient.ClusterWatcher;
import io.grpc.xds.XdsLoadBalancerProvider.XdsConfig;
import io.grpc.xds.XdsSubchannelPickers.ErrorPicker;
import io.grpc.xds.sds.SslContextProvider;
import io.grpc.xds.sds.TlsContextManager;
import io.grpc.xds.sds.TlsContextManagerImpl;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand All @@ -55,6 +58,7 @@ public final class CdsLoadBalancer extends LoadBalancer {
private final LoadBalancerRegistry lbRegistry;
private final GracefulSwitchLoadBalancer switchingLoadBalancer;
private final Helper helper;
private final TlsContextManager tlsContextManager;

// The following fields become non-null once handleResolvedAddresses() successfully.

Expand All @@ -70,15 +74,17 @@ public final class CdsLoadBalancer extends LoadBalancer {
private XdsClient xdsClient;

CdsLoadBalancer(Helper helper) {
this(helper, LoadBalancerRegistry.getDefaultRegistry());
this(helper, LoadBalancerRegistry.getDefaultRegistry(), TlsContextManagerImpl.getInstance());
}

@VisibleForTesting
CdsLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry) {
CdsLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry,
TlsContextManager tlsContextManager) {
this.helper = helper;
this.channelLogger = helper.getChannelLogger();
this.lbRegistry = lbRegistry;
this.switchingLoadBalancer = new GracefulSwitchLoadBalancer(helper);
this.tlsContextManager = tlsContextManager;
}

@Override
Expand Down Expand Up @@ -236,22 +242,23 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {

private static final class EdsLoadBalancingHelper extends ForwardingLoadBalancerHelper {
private final Helper delegate;
private final AtomicReference<UpstreamTlsContext> upstreamTlsContext;
private final AtomicReference<SslContextProvider<UpstreamTlsContext>> sslContextProvider;

EdsLoadBalancingHelper(Helper helper, AtomicReference<UpstreamTlsContext> upstreamTlsContext) {
EdsLoadBalancingHelper(Helper helper,
AtomicReference<SslContextProvider<UpstreamTlsContext>> sslContextProvider) {
this.delegate = helper;
this.upstreamTlsContext = upstreamTlsContext;
this.sslContextProvider = sslContextProvider;
}

@Override
public Subchannel createSubchannel(CreateSubchannelArgs createSubchannelArgs) {
if (upstreamTlsContext.get() != null) {
if (sslContextProvider.get() != null) {
createSubchannelArgs =
createSubchannelArgs
.toBuilder()
.setAddresses(
addUpstreamTlsContext(createSubchannelArgs.getAddresses(),
upstreamTlsContext.get()))
sslContextProvider.get().getSource()))
.build();
}
return delegate.createSubchannel(createSubchannelArgs);
Expand Down Expand Up @@ -295,7 +302,8 @@ private final class ClusterWatcherImpl implements ClusterWatcher {
LoadBalancer edsBalancer;

ClusterWatcherImpl(Helper helper, ResolvedAddresses resolvedAddresses) {
this.helper = new EdsLoadBalancingHelper(helper, new AtomicReference<UpstreamTlsContext>());
this.helper = new EdsLoadBalancingHelper(helper,
new AtomicReference<SslContextProvider<UpstreamTlsContext>>());
this.resolvedAddresses = resolvedAddresses;
}

Expand All @@ -312,8 +320,7 @@ public void onClusterChanged(ClusterUpdate newUpdate) {
/* fallbackPolicy = */ null,
/* edsServiceName = */ newUpdate.getEdsServiceName(),
/* lrsServerName = */ newUpdate.getLrsServerName());
UpstreamTlsContext upstreamTlsContext = newUpdate.getUpstreamTlsContext();
helper.upstreamTlsContext.set(upstreamTlsContext);
updateSslContextProvider(newUpdate.getUpstreamTlsContext());
if (edsBalancer == null) {
edsBalancer = lbRegistry.getProvider(XDS_POLICY_NAME).newLoadBalancer(helper);
}
Expand All @@ -327,6 +334,27 @@ public void onClusterChanged(ClusterUpdate newUpdate) {
.build());
}

/** For new UpstreamTlsContext value, release old SslContextProvider. */
private void updateSslContextProvider(UpstreamTlsContext newUpstreamTlsContext) {
SslContextProvider<UpstreamTlsContext> oldSslContextProvider =
helper.sslContextProvider.get();
if (oldSslContextProvider != null) {
UpstreamTlsContext oldUpstreamTlsContext = oldSslContextProvider.getSource();

if (oldUpstreamTlsContext.equals(newUpstreamTlsContext)) {
return;
}
tlsContextManager.releaseClientSslContextProvider(oldSslContextProvider);
}
if (newUpstreamTlsContext != null) {
SslContextProvider<UpstreamTlsContext> newSslContextProvider =
tlsContextManager.findOrCreateClientSslContextProvider(newUpstreamTlsContext);
helper.sslContextProvider.set(newSslContextProvider);
} else {
helper.sslContextProvider.set(null);
}
}

@Override
public void onError(Status error) {
channelLogger.log(ChannelLogLevel.ERROR, "CDS load balancer received an error: {0}", error);
Expand Down
2 changes: 1 addition & 1 deletion xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java
Expand Up @@ -55,7 +55,7 @@ protected SslContextProvider(K source, boolean server) {
this.server = server;
}

K getSource() {
public K getSource() {
return source;
}

Expand Down
69 changes: 10 additions & 59 deletions xds/src/main/java/io/grpc/xds/sds/TlsContextManager.java
@@ -1,5 +1,5 @@
/*
* Copyright 2019 The gRPC Authors
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,63 +16,20 @@

package io.grpc.xds.sds;

import static com.google.common.base.Preconditions.checkNotNull;

import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.Internal;
import io.grpc.xds.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory;

/**
* Class to manage {@link SslContextProvider} objects created from inputs we get from xDS. Used by
* gRPC-xds to access the SslContext's and is not public API. This manager manages the life-cycle of
* {@link SslContextProvider} objects as shared resources via ref-counting as described in {@link
* ReferenceCountingSslContextProviderMap}.
*/
@Internal
public final class TlsContextManager {

private static TlsContextManager instance;

private final ReferenceCountingSslContextProviderMap<UpstreamTlsContext> mapForClients;
private final ReferenceCountingSslContextProviderMap<DownstreamTlsContext> mapForServers;

private TlsContextManager() {
this(new ClientSslContextProviderFactory(), new ServerSslContextProviderFactory());
}

@VisibleForTesting
TlsContextManager(
SslContextProviderFactory<UpstreamTlsContext> clientFactory,
SslContextProviderFactory<DownstreamTlsContext> serverFactory) {
checkNotNull(clientFactory, "clientFactory");
checkNotNull(serverFactory, "serverFactory");
mapForClients = new ReferenceCountingSslContextProviderMap<>(clientFactory);
mapForServers = new ReferenceCountingSslContextProviderMap<>(serverFactory);
}

/** Gets the TlsContextManager singleton. */
public static synchronized TlsContextManager getInstance() {
if (instance == null) {
instance = new TlsContextManager();
}
return instance;
}
public interface TlsContextManager {

/** Creates a SslContextProvider. Used for retrieving a server-side SslContext. */
public SslContextProvider<DownstreamTlsContext> findOrCreateServerSslContextProvider(
DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
return mapForServers.get(downstreamTlsContext);
}
SslContextProvider<DownstreamTlsContext> findOrCreateServerSslContextProvider(
DownstreamTlsContext downstreamTlsContext);

/** Creates a SslContextProvider. Used for retrieving a client-side SslContext. */
public SslContextProvider<UpstreamTlsContext> findOrCreateClientSslContextProvider(
UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
return mapForClients.get(upstreamTlsContext);
}
SslContextProvider<UpstreamTlsContext> findOrCreateClientSslContextProvider(
UpstreamTlsContext upstreamTlsContext);

/**
* Releases an instance of the given client-side {@link SslContextProvider}.
Expand All @@ -83,11 +40,8 @@ public SslContextProvider<UpstreamTlsContext> findOrCreateClientSslContextProvid
* <p>Caller must not release a reference more than once. It's advised that you clear the
* reference to the instance with the null returned by this method.
*/
public SslContextProvider<UpstreamTlsContext> releaseClientSslContextProvider(
SslContextProvider<UpstreamTlsContext> sslContextProvider) {
checkNotNull(sslContextProvider, "sslContextProvider");
return mapForClients.release(sslContextProvider);
}
SslContextProvider<UpstreamTlsContext> releaseClientSslContextProvider(
SslContextProvider<UpstreamTlsContext> sslContextProvider);

/**
* Releases an instance of the given server-side {@link SslContextProvider}.
Expand All @@ -98,9 +52,6 @@ public SslContextProvider<UpstreamTlsContext> releaseClientSslContextProvider(
* <p>Caller must not release a reference more than once. It's advised that you clear the
* reference to the instance with the null returned by this method.
*/
public SslContextProvider<DownstreamTlsContext> releaseServerSslContextProvider(
SslContextProvider<DownstreamTlsContext> sslContextProvider) {
checkNotNull(sslContextProvider, "sslContextProvider");
return mapForServers.release(sslContextProvider);
}
SslContextProvider<DownstreamTlsContext> releaseServerSslContextProvider(
SslContextProvider<DownstreamTlsContext> sslContextProvider);
}
90 changes: 90 additions & 0 deletions xds/src/main/java/io/grpc/xds/sds/TlsContextManagerImpl.java
@@ -0,0 +1,90 @@
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc.xds.sds;

import static com.google.common.base.Preconditions.checkNotNull;

import com.google.common.annotations.VisibleForTesting;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.Internal;
import io.grpc.xds.sds.ReferenceCountingSslContextProviderMap.SslContextProviderFactory;

/**
* Class to manage {@link SslContextProvider} objects created from inputs we get from xDS. Used by
* gRPC-xds to access the SslContext's and is not public API. This manager manages the life-cycle of
* {@link SslContextProvider} objects as shared resources via ref-counting as described in {@link
* ReferenceCountingSslContextProviderMap}.
*/
@Internal
public final class TlsContextManagerImpl implements TlsContextManager {

private static TlsContextManagerImpl instance;

private final ReferenceCountingSslContextProviderMap<UpstreamTlsContext> mapForClients;
private final ReferenceCountingSslContextProviderMap<DownstreamTlsContext> mapForServers;

private TlsContextManagerImpl() {
this(new ClientSslContextProviderFactory(), new ServerSslContextProviderFactory());
}

@VisibleForTesting
TlsContextManagerImpl(
SslContextProviderFactory<UpstreamTlsContext> clientFactory,
SslContextProviderFactory<DownstreamTlsContext> serverFactory) {
checkNotNull(clientFactory, "clientFactory");
checkNotNull(serverFactory, "serverFactory");
mapForClients = new ReferenceCountingSslContextProviderMap<>(clientFactory);
mapForServers = new ReferenceCountingSslContextProviderMap<>(serverFactory);
}

/** Gets the TlsContextManagerImpl singleton. */
public static synchronized TlsContextManagerImpl getInstance() {
if (instance == null) {
instance = new TlsContextManagerImpl();
}
return instance;
}

@Override
public SslContextProvider<DownstreamTlsContext> findOrCreateServerSslContextProvider(
DownstreamTlsContext downstreamTlsContext) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
return mapForServers.get(downstreamTlsContext);
}

@Override
public SslContextProvider<UpstreamTlsContext> findOrCreateClientSslContextProvider(
UpstreamTlsContext upstreamTlsContext) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
return mapForClients.get(upstreamTlsContext);
}

@Override
public SslContextProvider<UpstreamTlsContext> releaseClientSslContextProvider(
SslContextProvider<UpstreamTlsContext> sslContextProvider) {
checkNotNull(sslContextProvider, "sslContextProvider");
return mapForClients.release(sslContextProvider);
}

@Override
public SslContextProvider<DownstreamTlsContext> releaseServerSslContextProvider(
SslContextProvider<DownstreamTlsContext> sslContextProvider) {
checkNotNull(sslContextProvider, "sslContextProvider");
return mapForServers.release(sslContextProvider);
}
}
Expand Up @@ -31,7 +31,7 @@
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.xds.XdsAttributes;
import io.grpc.xds.sds.SslContextProvider;
import io.grpc.xds.sds.TlsContextManager;
import io.grpc.xds.sds.TlsContextManagerImpl;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
Expand Down Expand Up @@ -207,8 +207,9 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) {
final BufferReadsHandler bufferReads = new BufferReadsHandler();
ctx.pipeline().addBefore(ctx.name(), null, bufferReads);

SslContextProvider<UpstreamTlsContext> sslContextProvider =
TlsContextManager.getInstance().findOrCreateClientSslContextProvider(upstreamTlsContext);
final SslContextProvider<UpstreamTlsContext> sslContextProvider =
TlsContextManagerImpl.getInstance()
.findOrCreateClientSslContextProvider(upstreamTlsContext);

sslContextProvider.addCallback(
new SslContextProvider.Callback() {
Expand All @@ -226,6 +227,8 @@ public void updateSecret(SslContext sslContext) {
ctx.pipeline().addAfter(ctx.name(), null, handler);
fireProtocolNegotiationEvent(ctx);
ctx.pipeline().remove(bufferReads);
TlsContextManagerImpl.getInstance()
.releaseClientSslContextProvider(sslContextProvider);
}

@Override
Expand Down Expand Up @@ -303,8 +306,8 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) {
final BufferReadsHandler bufferReads = new BufferReadsHandler();
ctx.pipeline().addBefore(ctx.name(), null, bufferReads);

SslContextProvider<DownstreamTlsContext> sslContextProvider =
TlsContextManager.getInstance()
final SslContextProvider<DownstreamTlsContext> sslContextProvider =
TlsContextManagerImpl.getInstance()
.findOrCreateServerSslContextProvider(downstreamTlsContext);

sslContextProvider.addCallback(
Expand All @@ -319,6 +322,8 @@ public void updateSecret(SslContext sslContext) {
ctx.pipeline().addAfter(ctx.name(), null, handler);
fireProtocolNegotiationEvent(ctx);
ctx.pipeline().remove(bufferReads);
TlsContextManagerImpl.getInstance()
.releaseServerSslContextProvider(sslContextProvider);
}

@Override
Expand Down
Expand Up @@ -41,8 +41,8 @@
import javax.net.ssl.X509ExtendedTrustManager;

/**
* Factory class used by providers of {@link io.grpc.xds.sds.TlsContextManager} to provide a {@link
* SdsX509TrustManager} for trust and SAN checks.
* Factory class used by providers of {@link io.grpc.xds.sds.TlsContextManagerImpl} to provide a
* {@link SdsX509TrustManager} for trust and SAN checks.
*/
@Internal
public final class SdsTrustManagerFactory extends SimpleTrustManagerFactory {
Expand Down

0 comments on commit 04cf90a

Please sign in to comment.