Skip to content

Commit

Permalink
xds: implement XdsChannelCredentials (grpc#7497)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanjaypujare authored and dfawley committed Jan 15, 2021
1 parent e3ab6f5 commit d780573
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 5 deletions.
Expand Up @@ -18,6 +18,8 @@

import io.grpc.ChannelCredentials;
import io.grpc.Internal;
import io.netty.channel.ChannelHandler;
import io.netty.util.AsciiString;

/**
* Internal {@link NettyChannelCredentials} accessor. This is intended for usage internal to the
Expand All @@ -31,4 +33,52 @@ private InternalNettyChannelCredentials() {}
public static ChannelCredentials create(InternalProtocolNegotiator.ClientFactory negotiator) {
return NettyChannelCredentials.create(negotiator);
}

/**
* Converts a {@link ChannelCredentials} to a negotiator, in similar fashion as for a new channel.
*
* @throws IllegalArgumentException if unable to convert
*/
public static InternalProtocolNegotiator.ClientFactory toNegotiator(
ChannelCredentials channelCredentials) {
final ProtocolNegotiators.FromChannelCredentialsResult result =
ProtocolNegotiators.from(channelCredentials);
if (result.error != null) {
throw new IllegalArgumentException(result.error);
}
final class ClientFactory implements InternalProtocolNegotiator.ClientFactory {

@Override
public InternalProtocolNegotiator.ProtocolNegotiator newNegotiator() {
final ProtocolNegotiator pn = result.negotiator.newNegotiator();
final class LocalProtocolNegotiator
implements InternalProtocolNegotiator.ProtocolNegotiator {

@Override
public AsciiString scheme() {
return pn.scheme();
}

@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
return pn.newHandler(grpcHandler);
}

@Override
public void close() {
pn.close();
}
}

return new LocalProtocolNegotiator();
}

@Override
public int getDefaultPort() {
return result.negotiator.getDefaultPort();
}
}

return new ClientFactory();
}
}
45 changes: 45 additions & 0 deletions xds/src/main/java/io/grpc/xds/XdsChannelCredentials.java
@@ -0,0 +1,45 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc.xds;

import static com.google.common.base.Preconditions.checkNotNull;

import io.grpc.ChannelCredentials;
import io.grpc.ExperimentalApi;
import io.grpc.netty.InternalNettyChannelCredentials;
import io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.xds.internal.sds.SdsProtocolNegotiators;

@ExperimentalApi("https://github.com/grpc/grpc-java/issues/7479")
public class XdsChannelCredentials {
private XdsChannelCredentials() {} // prevent instantiation

/**
* Creates credentials to be configured by xDS, falling back to other credentials if no
* TLS configuration is provided by xDS.
*
* @param fallback Credentials to fall back to.
*
* @throws IllegalArgumentException if fallback is unable to be used
*/
public static ChannelCredentials create(ChannelCredentials fallback) {
InternalProtocolNegotiator.ClientFactory fallbackNegotiator =
InternalNettyChannelCredentials.toNegotiator(checkNotNull(fallback, "fallback"));
return InternalNettyChannelCredentials.create(
SdsProtocolNegotiators.clientProtocolNegotiatorFactory(fallbackNegotiator));
}
}
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.common.annotations.VisibleForTesting;
import io.grpc.internal.GrpcUtil;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory;
Expand Down Expand Up @@ -69,6 +70,17 @@ public static ProtocolNegotiatorFactory clientProtocolNegotiatorFactory(
return new ClientSdsProtocolNegotiatorFactory(fallbackNegotiator);
}

/**
* Returns a {@link InternalProtocolNegotiator.ClientFactory} to be used on {@link
* NettyChannelBuilder}.
*
* @param fallbackNegotiator protocol negotiator to use as fallback.
*/
public static InternalProtocolNegotiator.ClientFactory clientProtocolNegotiatorFactory(
@Nullable InternalProtocolNegotiator.ClientFactory fallbackNegotiator) {
return new ClientFactory(fallbackNegotiator);
}

/**
* Creates an SDS based {@link ProtocolNegotiator} for a {@link io.grpc.netty.NettyServerBuilder}.
* If xDS returns no DownstreamTlsContext, it will fall back to plaintext.
Expand All @@ -82,6 +94,25 @@ public static ServerSdsProtocolNegotiator serverProtocolNegotiator(int port,
fallbackProtocolNegotiator);
}

private static final class ClientFactory implements InternalProtocolNegotiator.ClientFactory {

private final InternalProtocolNegotiator.ClientFactory fallbackProtocolNegotiator;

private ClientFactory(InternalProtocolNegotiator.ClientFactory fallbackNegotiator) {
this.fallbackProtocolNegotiator = fallbackNegotiator;
}

@Override
public ProtocolNegotiator newNegotiator() {
return new ClientSdsProtocolNegotiator(fallbackProtocolNegotiator.newNegotiator());
}

@Override
public int getDefaultPort() {
return GrpcUtil.DEFAULT_PORT_SSL;
}
}

private static final class ClientSdsProtocolNegotiatorFactory
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {

Expand Down
64 changes: 59 additions & 5 deletions xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java
Expand Up @@ -33,6 +33,9 @@
import com.google.common.collect.ImmutableList;
import io.grpc.Attributes;
import io.grpc.EquivalentAddressGroup;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannelBuilder;
import io.grpc.NameResolver;
import io.grpc.NameResolverProvider;
import io.grpc.NameResolverRegistry;
Expand Down Expand Up @@ -104,6 +107,15 @@ public void plaintextClientServer() throws IOException, URISyntaxException {
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
}

@Test
public void plaintextClientServer_withXdsChannelCreds() throws IOException, URISyntaxException {
buildServerWithTlsContext(/* downstreamTlsContext= */ null);

SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStubNewApi(/* upstreamTlsContext= */ null, /* overrideAuthority= */ null);
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
}

@Test
public void plaintextClientServer_withDefaultTlsContext() throws IOException, URISyntaxException {
DownstreamTlsContext defaultTlsContext =
Expand Down Expand Up @@ -197,7 +209,8 @@ public void mtls_badClientCert_expectException() throws IOException, URISyntaxEx
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE);
try {
XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext);
XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext,
false);
fail("exception expected");
} catch (StatusRuntimeException sre) {
assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class);
Expand All @@ -211,7 +224,17 @@ public void mtlsClientServer_withClientAuthentication() throws IOException, URIS
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext);
performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false);
}

/** mTLS - client auth enabled - using {@link XdsChannelCredentials} API. */
@Test
public void mtlsClientServer_withClientAuthentication_withXdsChannelCreds()
throws IOException, URISyntaxException {
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
performMtlsTestAndGetListenerWatcher(upstreamTlsContext, true);
}

@Test
Expand Down Expand Up @@ -260,7 +283,7 @@ public void mtlsClientServer_changeServerContext_expectException()
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
XdsClient.ListenerWatcher listenerWatcher =
performMtlsTestAndGetListenerWatcher(upstreamTlsContext);
performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE);
Expand All @@ -278,7 +301,8 @@ public void mtlsClientServer_changeServerContext_expectException()
}

private XdsClient.ListenerWatcher performMtlsTestAndGetListenerWatcher(
UpstreamTlsContext upstreamTlsContext) throws IOException, URISyntaxException {
UpstreamTlsContext upstreamTlsContext, boolean newApi)
throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenamesWithClientCertRequired(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
Expand All @@ -291,7 +315,8 @@ private XdsClient.ListenerWatcher performMtlsTestAndGetListenerWatcher(

XdsClient.ListenerWatcher listenerWatcher = xdsClientWrapperForServerSds.getListenerWatcher();

SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = newApi
? getBlockingStubNewApi(upstreamTlsContext, "foo.test.google.fr") :
getBlockingStub(upstreamTlsContext, "foo.test.google.fr");
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
return listenerWatcher;
Expand Down Expand Up @@ -378,6 +403,35 @@ upstreamTlsContext, new TlsContextManagerImpl(mockBootstrapper)))
return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build()));
}

private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStubNewApi(
final UpstreamTlsContext upstreamTlsContext, String overrideAuthority)
throws URISyntaxException {
URI expectedUri = new URI("sdstest://localhost:" + port);
fakeNameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri).build();
NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverFactory);
ManagedChannelBuilder<?> channelBuilder =
Grpc.newChannelBuilder(
"sdstest://localhost:" + port,
XdsChannelCredentials.create(InsecureChannelCredentials.create()));

if (overrideAuthority != null) {
channelBuilder = channelBuilder.overrideAuthority(overrideAuthority);
}
InetSocketAddress socketAddress =
new InetSocketAddress(Inet4Address.getLoopbackAddress(), port);
Attributes attrs =
(upstreamTlsContext != null)
? Attributes.newBuilder()
.set(XdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER,
new SslContextProviderSupplier(
upstreamTlsContext, new TlsContextManagerImpl(mockBootstrapper)))
.build()
: Attributes.EMPTY;
fakeNameResolverFactory.setServers(
ImmutableList.of(new EquivalentAddressGroup(socketAddress, attrs)));
return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build()));
}

/** Say hello to server. */
private static String unaryRpc(
String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) {
Expand Down

0 comments on commit d780573

Please sign in to comment.