Skip to content

Commit

Permalink
xds: reorder processing of tlsContext to prioritize CertProviderInsta…
Browse files Browse the repository at this point in the history
…nce (#7592)
  • Loading branch information
sanjaypujare committed Nov 4, 2020
1 parent d52b359 commit d7764d7
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,19 @@ public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) {
checkNotNull(
upstreamTlsContext.getCommonTlsContext(),
"upstreamTlsContext should have CommonTlsContext");
if (CommonTlsContextUtil.hasAllSecretsUsingFilename(upstreamTlsContext.getCommonTlsContext())) {
if (CommonTlsContextUtil.hasCertProviderInstance(
upstreamTlsContext.getCommonTlsContext())) {
try {
Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap();
return certProviderClientSslContextProviderFactory.getProvider(
upstreamTlsContext,
bootstrapInfo.getNode().toEnvoyProtoNode(),
bootstrapInfo.getCertProviders());
} catch (XdsInitializationException e) {
throw new RuntimeException(e);
}
} else if (CommonTlsContextUtil.hasAllSecretsUsingFilename(
upstreamTlsContext.getCommonTlsContext())) {
return SecretVolumeClientSslContextProvider.getProvider(upstreamTlsContext);
} else if (CommonTlsContextUtil.hasAllSecretsUsingSds(
upstreamTlsContext.getCommonTlsContext())) {
Expand All @@ -67,17 +79,6 @@ public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) {
} catch (XdsInitializationException e) {
throw new RuntimeException(e);
}
} else if (CommonTlsContextUtil.hasCertProviderInstance(
upstreamTlsContext.getCommonTlsContext())) {
try {
Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap();
return certProviderClientSslContextProviderFactory.getProvider(
upstreamTlsContext,
bootstrapInfo.getNode().toEnvoyProtoNode(),
bootstrapInfo.getCertProviders());
} catch (XdsInitializationException e) {
throw new RuntimeException(e);
}
}
throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,18 @@ public SslContextProvider create(
checkNotNull(
downstreamTlsContext.getCommonTlsContext(),
"downstreamTlsContext should have CommonTlsContext");
if (CommonTlsContextUtil.hasAllSecretsUsingFilename(
if (CommonTlsContextUtil.hasCertProviderInstance(
downstreamTlsContext.getCommonTlsContext())) {
try {
Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap();
return certProviderServerSslContextProviderFactory.getProvider(
downstreamTlsContext,
bootstrapInfo.getNode().toEnvoyProtoNode(),
bootstrapInfo.getCertProviders());
} catch (XdsInitializationException e) {
throw new RuntimeException(e);
}
} else if (CommonTlsContextUtil.hasAllSecretsUsingFilename(
downstreamTlsContext.getCommonTlsContext())) {
return SecretVolumeServerSslContextProvider.getProvider(downstreamTlsContext);
} else if (CommonTlsContextUtil.hasAllSecretsUsingSds(
Expand All @@ -69,17 +80,6 @@ public SslContextProvider create(
} catch (XdsInitializationException e) {
throw new RuntimeException(e);
}
} else if (CommonTlsContextUtil.hasCertProviderInstance(
downstreamTlsContext.getCommonTlsContext())) {
try {
Bootstrapper.BootstrapInfo bootstrapInfo = bootstrapper.readBootstrap();
return certProviderServerSslContextProviderFactory.getProvider(
downstreamTlsContext,
bootstrapInfo.getNode().toEnvoyProtoNode(),
bootstrapInfo.getCertProviders());
} catch (XdsInitializationException e) {
throw new RuntimeException(e);
}
}
throw new UnsupportedOperationException("Unsupported configurations in DownstreamTlsContext!");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableSet;
import io.envoyproxy.envoy.config.core.v3.DataSource;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.CommonBootstrapperTestUtils;
Expand Down Expand Up @@ -139,6 +141,33 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}

@Test
public void bothPresent_expectCertProviderClientSslContextProvider()
throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);

CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder();
builder = addFilenames(builder, "foo.pem", "foo.key", "root.pem");
upstreamTlsContext = new UpstreamTlsContext(builder.build());

Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}

@Test
public void createCertProviderClientSslContextProvider_onlyRootCert()
throws XdsInitializationException {
Expand Down Expand Up @@ -301,4 +330,27 @@ static void verifyWatcher(
assertThat(watcherCaptor.getDownstreamWatchers().iterator().next())
.isSameInstanceAs(sslContextProvider);
}

static CommonTlsContext.Builder addFilenames(
CommonTlsContext.Builder builder, String certChain, String privateKey, String trustCa) {
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename(certChain))
.setPrivateKey(DataSource.newBuilder().setFilename(privateKey))
.build();
CertificateValidationContext certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa))
.build();
CommonTlsContext.CertificateProviderInstance certificateProviderInstance =
builder.getValidationContextCertificateProviderInstance();
CommonTlsContext.CombinedCertificateValidationContext.Builder combinedBuilder =
CommonTlsContext.CombinedCertificateValidationContext.newBuilder();
combinedBuilder
.setDefaultValidationContext(certContext)
.setValidationContextCertificateProviderInstance(certificateProviderInstance);
return builder
.addTlsCertificates(tlsCert)
.setCombinedValidationContext(combinedBuilder.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.CommonBootstrapperTestUtils;
import io.grpc.xds.EnvoyServerProtoData;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.XdsInitializationException;
import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider;
Expand Down Expand Up @@ -136,6 +137,37 @@ public void createCertProviderServerSslContextProvider() throws XdsInitializatio
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}

@Test
public void bothPresent_expectCertProviderServerSslContextProvider()
throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null,
/* requireClientCert= */ true);

CommonTlsContext.Builder builder = downstreamTlsContext.getCommonTlsContext().toBuilder();
builder =
ClientSslContextProviderFactoryTest.addFilenames(builder, "foo.pem", "foo.key", "root.pem");
downstreamTlsContext =
new EnvoyServerProtoData.DownstreamTlsContext(
builder.build(), downstreamTlsContext.isRequireClientCertificate());

Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
when(bootstrapper.readBootstrap()).thenReturn(bootstrapInfo);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}

@Test
public void createCertProviderServerSslContextProvider_onlyCertInstance()
throws XdsInitializationException {
Expand Down

0 comments on commit d7764d7

Please sign in to comment.