Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xds: implement XdsChannelCredentials #7497

Merged
merged 3 commits into from Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -18,6 +18,8 @@

import io.grpc.ChannelCredentials;
import io.grpc.Internal;
import io.netty.channel.ChannelHandler;
import io.netty.util.AsciiString;

/**
* Internal {@link NettyChannelCredentials} accessor. This is intended for usage internal to the
Expand All @@ -31,4 +33,42 @@ private InternalNettyChannelCredentials() {}
public static ChannelCredentials create(InternalProtocolNegotiator.ClientFactory negotiator) {
return NettyChannelCredentials.create(negotiator);
}

public static InternalProtocolNegotiator.ClientFactory toNegotiator(ChannelCredentials channelCredentials) {
final ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(channelCredentials);
if (result.error != null) {
throw new IllegalArgumentException(result.error);
}
final class ClientFactory implements InternalProtocolNegotiator.ClientFactory {

@Override
public InternalProtocolNegotiator.ProtocolNegotiator newNegotiator() {
final ProtocolNegotiator pn = result.negotiator.newNegotiator();
final class LocalProtocolNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {

@Override
public AsciiString scheme() {
return pn.scheme();
}

@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
return pn.newHandler(grpcHandler);
}

@Override
public void close() {
pn.close();
}
}
return new LocalProtocolNegotiator();
}

@Override
public int getDefaultPort() {
return result.negotiator.getDefaultPort();
}
}
return new ClientFactory();
}
}
43 changes: 43 additions & 0 deletions xds/src/main/java/io/grpc/xds/XdsChannelCredentials.java
@@ -0,0 +1,43 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc.xds;

import static com.google.common.base.Preconditions.checkNotNull;

import io.grpc.ChannelCredentials;
import io.grpc.netty.InternalNettyChannelCredentials;
import io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.xds.internal.sds.SdsProtocolNegotiators;

public class XdsChannelCredentials {
sanjaypujare marked this conversation as resolved.
Show resolved Hide resolved
private XdsChannelCredentials() {} // prevent instantiation

/**
* Creates credentials to be configured by xDS, falling back to other credentials if no
* TLS configuration is provided by xDS.
*
* @param fallback Credentials to fall back to.
*
* @throws IllegalArgumentException if fallback is unable to be used
*/
public static ChannelCredentials create(ChannelCredentials fallback) {
InternalProtocolNegotiator.ClientFactory fallbackNegotiator =
InternalNettyChannelCredentials.toNegotiator(checkNotNull(fallback, "fallback"));
return InternalNettyChannelCredentials.create(
SdsProtocolNegotiators.clientProtocolNegotiatorFactory(fallbackNegotiator));
}
}
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.common.annotations.VisibleForTesting;
import io.grpc.internal.GrpcUtil;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory;
Expand Down Expand Up @@ -69,6 +70,17 @@ public static ProtocolNegotiatorFactory clientProtocolNegotiatorFactory(
return new ClientSdsProtocolNegotiatorFactory(fallbackNegotiator);
}

/**
* Returns a {@link InternalProtocolNegotiator.ClientFactory} to be used on {@link
* NettyChannelBuilder}.
*
* @param fallbackNegotiator protocol negotiator to use as fallback.
*/
public static InternalProtocolNegotiator.ClientFactory clientProtocolNegotiatorFactory(
@Nullable InternalProtocolNegotiator.ClientFactory fallbackNegotiator) {
return new ClientFactory(fallbackNegotiator);
}

/**
* Creates an SDS based {@link ProtocolNegotiator} for a {@link io.grpc.netty.NettyServerBuilder}.
* If xDS returns no DownstreamTlsContext, it will fall back to plaintext.
Expand All @@ -82,6 +94,25 @@ public static ServerSdsProtocolNegotiator serverProtocolNegotiator(int port,
fallbackProtocolNegotiator);
}

private static final class ClientFactory implements InternalProtocolNegotiator.ClientFactory {

private final InternalProtocolNegotiator.ClientFactory fallbackProtocolNegotiator;

private ClientFactory(InternalProtocolNegotiator.ClientFactory fallbackNegotiator) {
this.fallbackProtocolNegotiator = fallbackNegotiator;
}

@Override
public ProtocolNegotiator newNegotiator() {
return new ClientSdsProtocolNegotiator(fallbackProtocolNegotiator.newNegotiator());
}

@Override
public int getDefaultPort() {
return GrpcUtil.DEFAULT_PORT_SSL;
}
}

private static final class ClientSdsProtocolNegotiatorFactory
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {

Expand Down
64 changes: 59 additions & 5 deletions xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java
Expand Up @@ -33,6 +33,9 @@
import com.google.common.collect.ImmutableList;
import io.grpc.Attributes;
import io.grpc.EquivalentAddressGroup;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannelBuilder;
import io.grpc.NameResolver;
import io.grpc.NameResolverProvider;
import io.grpc.NameResolverRegistry;
Expand Down Expand Up @@ -104,6 +107,15 @@ public void plaintextClientServer() throws IOException, URISyntaxException {
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
}

@Test
public void plaintextClientServer_withXdsChannelCreds() throws IOException, URISyntaxException {
buildServerWithTlsContext(/* downstreamTlsContext= */ null);

SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
getBlockingStubNewApi(/* upstreamTlsContext= */ null, /* overrideAuthority= */ null);
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
}

@Test
public void plaintextClientServer_withDefaultTlsContext() throws IOException, URISyntaxException {
DownstreamTlsContext defaultTlsContext =
Expand Down Expand Up @@ -197,7 +209,8 @@ public void mtls_badClientCert_expectException() throws IOException, URISyntaxEx
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE);
try {
XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext);
XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext,
false);
fail("exception expected");
} catch (StatusRuntimeException sre) {
assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class);
Expand All @@ -211,7 +224,17 @@ public void mtlsClientServer_withClientAuthentication() throws IOException, URIS
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext);
performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false);
}

/** mTLS - client auth enabled - using {@link XdsChannelCredentials} API. */
@Test
public void mtlsClientServer_withClientAuthentication_withXdsChannelCreds()
throws IOException, URISyntaxException {
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
performMtlsTestAndGetListenerWatcher(upstreamTlsContext, true);
}

@Test
Expand Down Expand Up @@ -260,7 +283,7 @@ public void mtlsClientServer_changeServerContext_expectException()
CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(
CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
XdsClient.ListenerWatcher listenerWatcher =
performMtlsTestAndGetListenerWatcher(upstreamTlsContext);
performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(
BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE);
Expand All @@ -278,7 +301,8 @@ public void mtlsClientServer_changeServerContext_expectException()
}

private XdsClient.ListenerWatcher performMtlsTestAndGetListenerWatcher(
UpstreamTlsContext upstreamTlsContext) throws IOException, URISyntaxException {
UpstreamTlsContext upstreamTlsContext, boolean newApi)
throws IOException, URISyntaxException {
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenamesWithClientCertRequired(
SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
Expand All @@ -291,7 +315,8 @@ private XdsClient.ListenerWatcher performMtlsTestAndGetListenerWatcher(

XdsClient.ListenerWatcher listenerWatcher = xdsClientWrapperForServerSds.getListenerWatcher();

SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = newApi
? getBlockingStubNewApi(upstreamTlsContext, "foo.test.google.fr") :
getBlockingStub(upstreamTlsContext, "foo.test.google.fr");
assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy");
return listenerWatcher;
Expand Down Expand Up @@ -378,6 +403,35 @@ upstreamTlsContext, new TlsContextManagerImpl(mockBootstrapper)))
return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build()));
}

private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStubNewApi(
final UpstreamTlsContext upstreamTlsContext, String overrideAuthority)
throws URISyntaxException {
URI expectedUri = new URI("sdstest://localhost:" + port);
fakeNameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri).build();
NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverFactory);
ManagedChannelBuilder<?> channelBuilder =
Grpc.newChannelBuilder(
"sdstest://localhost:" + port,
XdsChannelCredentials.create(InsecureChannelCredentials.create()));

if (overrideAuthority != null) {
channelBuilder = channelBuilder.overrideAuthority(overrideAuthority);
}
InetSocketAddress socketAddress =
new InetSocketAddress(Inet4Address.getLoopbackAddress(), port);
Attributes attrs =
(upstreamTlsContext != null)
? Attributes.newBuilder()
.set(XdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER,
new SslContextProviderSupplier(
upstreamTlsContext, new TlsContextManagerImpl(mockBootstrapper)))
.build()
: Attributes.EMPTY;
fakeNameResolverFactory.setServers(
ImmutableList.of(new EquivalentAddressGroup(socketAddress, attrs)));
return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build()));
}

/** Say hello to server. */
private static String unaryRpc(
String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) {
Expand Down