Skip to content

Commit

Permalink
xds: add support for static and combined validation context and enhan…
Browse files Browse the repository at this point in the history
…ced loggging (#6586)
  • Loading branch information
sanjaypujare committed Jan 9, 2020
1 parent d03a746 commit 6517ac8
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 126 deletions.
62 changes: 30 additions & 32 deletions xds/src/main/java/io/grpc/xds/sds/SdsClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import com.google.rpc.Code;
import io.envoyproxy.envoy.api.v2.DiscoveryRequest;
import io.envoyproxy.envoy.api.v2.DiscoveryResponse;
import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig;
Expand Down Expand Up @@ -240,6 +239,7 @@ void start() {
}
responseObserver = new ResponseObserver();
requestObserver = secretDiscoveryServiceStub.streamSecrets(responseObserver);
logger.log(Level.FINEST, "Stream created for {0}", sdsSecretConfig);
}
}

Expand All @@ -261,6 +261,7 @@ private final class ResponseObserver implements StreamObserver<DiscoveryResponse

@Override
public void onNext(DiscoveryResponse discoveryResponse) {
logger.log(Level.FINEST, "response={0}", discoveryResponse);
processDiscoveryResponse(discoveryResponse);
}

Expand All @@ -272,6 +273,7 @@ public void onError(Throwable t) {
@Override
public void onCompleted() {
// TODO(sanjaypujare): add retry logic once client implementation is final
logger.warning("Stream unexpectedly completed.");
}
}

Expand All @@ -280,8 +282,10 @@ private void processDiscoveryResponse(final DiscoveryResponse response) {
new Runnable() {
@Override
public void run() {
if (!processSecretsFromDiscoveryResponse(response)) {
sendNack(Code.INTERNAL_VALUE, "Secret not updated");
try {
processSecretsFromDiscoveryResponse(response);
} catch (Throwable exceptionSeen) {
sendNack(exceptionSeen);
return;
}
lastResponse = response;
Expand All @@ -291,14 +295,15 @@ public void run() {
});
}

private void sendNack(int errorCode, String errorMessage) {
private void sendNack(Throwable exceptionSeen) {
String nonce = "";
String versionInfo = "";

if (lastResponse != null) {
nonce = lastResponse.getNonce();
versionInfo = lastResponse.getVersionInfo();
}
Status grpcStatus = Status.fromThrowable(exceptionSeen);
DiscoveryRequest.Builder builder =
DiscoveryRequest.newBuilder()
.setTypeUrl(SECRET_TYPE_URL)
Expand All @@ -307,12 +312,14 @@ private void sendNack(int errorCode, String errorMessage) {
.addResourceNames(sdsSecretConfig.getName())
.setErrorDetail(
com.google.rpc.Status.newBuilder()
.setCode(errorCode)
.setMessage(errorMessage)
.setCode(grpcStatus.getCode().value())
.setMessage(grpcStatus.getDescription() != null ? grpcStatus.getDescription()
: "Secret not updated")
.build())
.setNode(clientNode);

DiscoveryRequest req = builder.build();
logger.log(Level.FINEST, "Sending NACK req={0}", req);
requestObserver.onNext(req);
}

Expand All @@ -333,42 +340,26 @@ public void run() {
}
}

private boolean processSecretsFromDiscoveryResponse(DiscoveryResponse response) {
private void processSecretsFromDiscoveryResponse(DiscoveryResponse response)
throws InvalidProtocolBufferException {
List<Any> resources = response.getResourcesList();
checkState(resources.size() == 1, "exactly one resource expected");
boolean noException = true;
for (Any any : resources) {
final String typeUrl = any.getTypeUrl();
checkState(SECRET_TYPE_URL.equals(typeUrl), "wrong value for typeUrl %s", typeUrl);
Secret secret = null;
try {
secret = Secret.parseFrom(any.getValue());
if (!processSecret(secret)) {
noException = false;
}
} catch (InvalidProtocolBufferException e) {
logger.log(Level.SEVERE, "exception from parseFrom", e);
}
}
return noException;
Any any = resources.get(0);
final String typeUrl = any.getTypeUrl();
checkState(SECRET_TYPE_URL.equals(typeUrl), "wrong value for typeUrl %s", typeUrl);
Secret secret = Secret.parseFrom(any.getValue());
processSecret(secret);
}

private boolean processSecret(Secret secret) {
private void processSecret(Secret secret) {
checkState(
sdsSecretConfig.getName().equals(secret.getName()),
"expected secret name %s",
sdsSecretConfig.getName());
boolean noException = true;
final SecretWatcher localCopy = watcher;
if (localCopy != null) {
try {
localCopy.onSecretChanged(secret);
} catch (Throwable throwable) {
noException = false;
logger.log(Level.SEVERE, "exception from onSecretChanged", throwable);
}
localCopy.onSecretChanged(secret);
}
return noException;
}

/** Registers a secret watcher for this client's SdsSecretConfig. */
Expand All @@ -383,7 +374,11 @@ void watchSecret(SecretWatcher secretWatcher) {
new Runnable() {
@Override
public void run() {
processSecretsFromDiscoveryResponse(lastResponse);
try {
processSecretsFromDiscoveryResponse(lastResponse);
} catch (Throwable throwable) {
logger.log(Level.SEVERE, "from watcherExecutor.execute", throwable);
}
}
});
}
Expand Down Expand Up @@ -432,10 +427,12 @@ public void close(EventLoopGroup instance) {
private void sendDiscoveryRequestOnStream() {
String nonce = "";
String versionInfo = "";
String requestType = "Sending initial req={0}";

if (lastResponse != null) {
nonce = lastResponse.getNonce();
versionInfo = lastResponse.getVersionInfo();
requestType = "Sending ACK req={0}";
}
DiscoveryRequest.Builder builder =
DiscoveryRequest.newBuilder()
Expand All @@ -446,6 +443,7 @@ private void sendDiscoveryRequestOnStream() {
.setNode(clientNode);

DiscoveryRequest req = builder.build();
logger.log(Level.FINEST, requestType, req);
requestObserver.onNext(req);
}
}
53 changes: 47 additions & 6 deletions xds/src/main/java/io/grpc/xds/sds/SdsSslContextProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext.CombinedCertificateValidationContext;
import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.envoyproxy.envoy.api.v2.auth.SdsSecretConfig;
import io.envoyproxy.envoy.api.v2.auth.Secret;
Expand Down Expand Up @@ -55,6 +56,7 @@ final class SdsSslContextProvider<K> extends SslContextProvider<K>
@Nullable private final SdsClient validationContextSdsClient;
@Nullable private final SdsSecretConfig certSdsConfig;
@Nullable private final SdsSecretConfig validationContextSdsConfig;
@Nullable private final CertificateValidationContext staticCertificateValidationContext;
private final List<CallbackPair> pendingCallbacks = new ArrayList<>();
@Nullable private TlsCertificate tlsCertificate;
@Nullable private CertificateValidationContext certificateValidationContext;
Expand All @@ -64,13 +66,15 @@ private SdsSslContextProvider(
Node node,
SdsSecretConfig certSdsConfig,
SdsSecretConfig validationContextSdsConfig,
CertificateValidationContext staticCertValidationContext,
Executor watcherExecutor,
Executor channelExecutor,
boolean server,
K source) {
super(source, server);
this.certSdsConfig = certSdsConfig;
this.validationContextSdsConfig = validationContextSdsConfig;
this.staticCertificateValidationContext = staticCertValidationContext;
if (certSdsConfig != null && certSdsConfig.isInitialized()) {
certSdsClient =
SdsClient.Factory.createSdsClient(certSdsConfig, node, watcherExecutor, channelExecutor);
Expand All @@ -97,9 +101,23 @@ static SdsSslContextProvider<UpstreamTlsContext> getProviderForClient(
Executor channelExecutor) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
SdsSecretConfig validationContextSdsConfig =
commonTlsContext.getValidationContextSdsSecretConfig();

SdsSecretConfig validationContextSdsConfig = null;
CertificateValidationContext staticCertValidationContext = null;
if (commonTlsContext.hasCombinedValidationContext()) {
CombinedCertificateValidationContext combinedValidationContext =
commonTlsContext.getCombinedValidationContext();
if (combinedValidationContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig =
combinedValidationContext.getValidationContextSdsSecretConfig();
}
if (combinedValidationContext.hasDefaultValidationContext()) {
staticCertValidationContext = combinedValidationContext.getDefaultValidationContext();
}
} else if (commonTlsContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig = commonTlsContext.getValidationContextSdsSecretConfig();
} else if (commonTlsContext.hasValidationContext()) {
staticCertValidationContext = commonTlsContext.getValidationContext();
}
SdsSecretConfig certSdsConfig = null;
if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) {
certSdsConfig = commonTlsContext.getTlsCertificateSdsSecretConfigs(0);
Expand All @@ -108,6 +126,7 @@ static SdsSslContextProvider<UpstreamTlsContext> getProviderForClient(
node,
certSdsConfig,
validationContextSdsConfig,
staticCertValidationContext,
watcherExecutor,
channelExecutor,
false,
Expand Down Expand Up @@ -135,6 +154,7 @@ static SdsSslContextProvider<DownstreamTlsContext> getProviderForServer(
node,
certSdsConfig,
validationContextSdsConfig,
null,
watcherExecutor,
channelExecutor,
true,
Expand Down Expand Up @@ -175,6 +195,7 @@ public synchronized void onSecretChanged(Secret secretUpdate) {
if (secretUpdate.hasTlsCertificate()) {
checkState(
secretUpdate.getName().equals(certSdsConfig.getName()), "tlsCert names don't match");
logger.log(Level.FINEST, "onSecretChanged certSdsConfig.name={0}", certSdsConfig.getName());
tlsCertificate = secretUpdate.getTlsCertificate();
if (certificateValidationContext != null || validationContextSdsConfig == null) {
updateSslContext();
Expand All @@ -183,6 +204,10 @@ public synchronized void onSecretChanged(Secret secretUpdate) {
checkState(
secretUpdate.getName().equals(validationContextSdsConfig.getName()),
"validationContext names don't match");
logger.log(
Level.FINEST,
"onSecretChanged validationContextSdsConfig.name={0}",
validationContextSdsConfig.getName());
certificateValidationContext = secretUpdate.getValidationContext();
if (tlsCertificate != null || certSdsConfig == null) {
updateSslContext();
Expand All @@ -197,21 +222,25 @@ public synchronized void onSecretChanged(Secret secretUpdate) {
private void updateSslContext() {
try {
SslContextBuilder sslContextBuilder;
CertificateValidationContext localCertValidationContext =
mergeStaticAndDynamicCertContexts();
if (server) {
logger.log(Level.FINEST, "for server");
sslContextBuilder =
GrpcSslContexts.forServer(
tlsCertificate.getCertificateChain().getInlineBytes().newInput(),
tlsCertificate.getPrivateKey().getInlineBytes().newInput(),
tlsCertificate.hasPassword()
? tlsCertificate.getPassword().getInlineString()
: null);
if (certificateValidationContext != null) {
sslContextBuilder.trustManager(new SdsTrustManagerFactory(certificateValidationContext));
if (localCertValidationContext != null) {
sslContextBuilder.trustManager(new SdsTrustManagerFactory(localCertValidationContext));
}
} else {
logger.log(Level.FINEST, "for client");
sslContextBuilder =
GrpcSslContexts.forClient()
.trustManager(new SdsTrustManagerFactory(certificateValidationContext));
.trustManager(new SdsTrustManagerFactory(localCertValidationContext));
if (tlsCertificate != null) {
sslContextBuilder.keyManager(
tlsCertificate.getCertificateChain().getInlineBytes().newInput(),
Expand All @@ -227,6 +256,18 @@ private void updateSslContext() {
}
}

private CertificateValidationContext mergeStaticAndDynamicCertContexts() {
if (staticCertificateValidationContext == null) {
return certificateValidationContext;
}
if (certificateValidationContext == null) {
return staticCertificateValidationContext;
}
CertificateValidationContext.Builder localCertContextBuilder =
certificateValidationContext.toBuilder();
return localCertContextBuilder.mergeFrom(staticCertificateValidationContext).build();
}

private void makePendingCallbacks(SslContext sslContextCopy) {
synchronized (pendingCallbacks) {
for (CallbackPair pair : pendingCallbacks) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import io.netty.util.AsciiString;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;

/**
Expand All @@ -49,6 +51,8 @@
@Internal
public final class SdsProtocolNegotiators {

private static final Logger logger = Logger.getLogger(SdsProtocolNegotiators.class.getName());

private static final AsciiString SCHEME = AsciiString.of("https");

/**
Expand Down Expand Up @@ -167,6 +171,12 @@ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
super.channelReadComplete(ctx);
}
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
logger.log(Level.SEVERE, "exceptionCaught", cause);
ctx.fireExceptionCaught(cause);
}
}

@VisibleForTesting
Expand Down Expand Up @@ -205,6 +215,10 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) {

@Override
public void updateSecret(SslContext sslContext) {
logger.log(
Level.FINEST,
"ClientSdsHandler.updateSecret authority={0}, ctx.name={1}",
new Object[]{grpcHandler.getAuthority(), ctx.name()});
ChannelHandler handler =
InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler);

Expand All @@ -221,6 +235,13 @@ public void onException(Throwable throwable) {
},
ctx.executor());
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
throws Exception {
logger.log(Level.SEVERE, "exceptionCaught", cause);
ctx.fireExceptionCaught(cause);
}
}

private static final class ServerSdsProtocolNegotiator implements ProtocolNegotiator {
Expand Down

0 comments on commit 6517ac8

Please sign in to comment.