From d78057364036e8ebd3d94e1863aa6501eab822d0 Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Fri, 9 Oct 2020 09:21:39 -0700 Subject: [PATCH] xds: implement XdsChannelCredentials (#7497) --- .../InternalNettyChannelCredentials.java | 50 +++++++++++++++ .../io/grpc/xds/XdsChannelCredentials.java | 45 +++++++++++++ .../internal/sds/SdsProtocolNegotiators.java | 31 +++++++++ .../io/grpc/xds/XdsSdsClientServerTest.java | 64 +++++++++++++++++-- 4 files changed, 185 insertions(+), 5 deletions(-) create mode 100644 xds/src/main/java/io/grpc/xds/XdsChannelCredentials.java diff --git a/netty/src/main/java/io/grpc/netty/InternalNettyChannelCredentials.java b/netty/src/main/java/io/grpc/netty/InternalNettyChannelCredentials.java index 81051a1833d0..d121c563009a 100644 --- a/netty/src/main/java/io/grpc/netty/InternalNettyChannelCredentials.java +++ b/netty/src/main/java/io/grpc/netty/InternalNettyChannelCredentials.java @@ -18,6 +18,8 @@ import io.grpc.ChannelCredentials; import io.grpc.Internal; +import io.netty.channel.ChannelHandler; +import io.netty.util.AsciiString; /** * Internal {@link NettyChannelCredentials} accessor. This is intended for usage internal to the @@ -31,4 +33,52 @@ private InternalNettyChannelCredentials() {} public static ChannelCredentials create(InternalProtocolNegotiator.ClientFactory negotiator) { return NettyChannelCredentials.create(negotiator); } + + /** + * Converts a {@link ChannelCredentials} to a negotiator, in similar fashion as for a new channel. + * + * @throws IllegalArgumentException if unable to convert + */ + public static InternalProtocolNegotiator.ClientFactory toNegotiator( + ChannelCredentials channelCredentials) { + final ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(channelCredentials); + if (result.error != null) { + throw new IllegalArgumentException(result.error); + } + final class ClientFactory implements InternalProtocolNegotiator.ClientFactory { + + @Override + public InternalProtocolNegotiator.ProtocolNegotiator newNegotiator() { + final ProtocolNegotiator pn = result.negotiator.newNegotiator(); + final class LocalProtocolNegotiator + implements InternalProtocolNegotiator.ProtocolNegotiator { + + @Override + public AsciiString scheme() { + return pn.scheme(); + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + return pn.newHandler(grpcHandler); + } + + @Override + public void close() { + pn.close(); + } + } + + return new LocalProtocolNegotiator(); + } + + @Override + public int getDefaultPort() { + return result.negotiator.getDefaultPort(); + } + } + + return new ClientFactory(); + } } diff --git a/xds/src/main/java/io/grpc/xds/XdsChannelCredentials.java b/xds/src/main/java/io/grpc/xds/XdsChannelCredentials.java new file mode 100644 index 000000000000..6ac464ebe5c2 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsChannelCredentials.java @@ -0,0 +1,45 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.grpc.ChannelCredentials; +import io.grpc.ExperimentalApi; +import io.grpc.netty.InternalNettyChannelCredentials; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.xds.internal.sds.SdsProtocolNegotiators; + +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/7479") +public class XdsChannelCredentials { + private XdsChannelCredentials() {} // prevent instantiation + + /** + * Creates credentials to be configured by xDS, falling back to other credentials if no + * TLS configuration is provided by xDS. + * + * @param fallback Credentials to fall back to. + * + * @throws IllegalArgumentException if fallback is unable to be used + */ + public static ChannelCredentials create(ChannelCredentials fallback) { + InternalProtocolNegotiator.ClientFactory fallbackNegotiator = + InternalNettyChannelCredentials.toNegotiator(checkNotNull(fallback, "fallback")); + return InternalNettyChannelCredentials.create( + SdsProtocolNegotiators.clientProtocolNegotiatorFactory(fallbackNegotiator)); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java index ea5a1435d865..556508fb8c51 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import io.grpc.internal.GrpcUtil; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalNettyChannelBuilder; import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory; @@ -69,6 +70,17 @@ public static ProtocolNegotiatorFactory clientProtocolNegotiatorFactory( return new ClientSdsProtocolNegotiatorFactory(fallbackNegotiator); } + /** + * Returns a {@link InternalProtocolNegotiator.ClientFactory} to be used on {@link + * NettyChannelBuilder}. + * + * @param fallbackNegotiator protocol negotiator to use as fallback. + */ + public static InternalProtocolNegotiator.ClientFactory clientProtocolNegotiatorFactory( + @Nullable InternalProtocolNegotiator.ClientFactory fallbackNegotiator) { + return new ClientFactory(fallbackNegotiator); + } + /** * Creates an SDS based {@link ProtocolNegotiator} for a {@link io.grpc.netty.NettyServerBuilder}. * If xDS returns no DownstreamTlsContext, it will fall back to plaintext. @@ -82,6 +94,25 @@ public static ServerSdsProtocolNegotiator serverProtocolNegotiator(int port, fallbackProtocolNegotiator); } + private static final class ClientFactory implements InternalProtocolNegotiator.ClientFactory { + + private final InternalProtocolNegotiator.ClientFactory fallbackProtocolNegotiator; + + private ClientFactory(InternalProtocolNegotiator.ClientFactory fallbackNegotiator) { + this.fallbackProtocolNegotiator = fallbackNegotiator; + } + + @Override + public ProtocolNegotiator newNegotiator() { + return new ClientSdsProtocolNegotiator(fallbackProtocolNegotiator.newNegotiator()); + } + + @Override + public int getDefaultPort() { + return GrpcUtil.DEFAULT_PORT_SSL; + } + } + private static final class ClientSdsProtocolNegotiatorFactory implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory { diff --git a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java index 98932b978222..948b35847845 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java @@ -33,6 +33,9 @@ import com.google.common.collect.ImmutableList; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannelBuilder; import io.grpc.NameResolver; import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; @@ -104,6 +107,15 @@ public void plaintextClientServer() throws IOException, URISyntaxException { assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); } + @Test + public void plaintextClientServer_withXdsChannelCreds() throws IOException, URISyntaxException { + buildServerWithTlsContext(/* downstreamTlsContext= */ null); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStubNewApi(/* upstreamTlsContext= */ null, /* overrideAuthority= */ null); + assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); + } + @Test public void plaintextClientServer_withDefaultTlsContext() throws IOException, URISyntaxException { DownstreamTlsContext defaultTlsContext = @@ -197,7 +209,8 @@ public void mtls_badClientCert_expectException() throws IOException, URISyntaxEx CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE); try { - XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext); + XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext, + false); fail("exception expected"); } catch (StatusRuntimeException sre) { assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class); @@ -211,7 +224,17 @@ public void mtlsClientServer_withClientAuthentication() throws IOException, URIS UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); - XdsClient.ListenerWatcher unused = performMtlsTestAndGetListenerWatcher(upstreamTlsContext); + performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false); + } + + /** mTLS - client auth enabled - using {@link XdsChannelCredentials} API. */ + @Test + public void mtlsClientServer_withClientAuthentication_withXdsChannelCreds() + throws IOException, URISyntaxException { + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); + performMtlsTestAndGetListenerWatcher(upstreamTlsContext, true); } @Test @@ -260,7 +283,7 @@ public void mtlsClientServer_changeServerContext_expectException() CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); XdsClient.ListenerWatcher listenerWatcher = - performMtlsTestAndGetListenerWatcher(upstreamTlsContext); + performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE); @@ -278,7 +301,8 @@ public void mtlsClientServer_changeServerContext_expectException() } private XdsClient.ListenerWatcher performMtlsTestAndGetListenerWatcher( - UpstreamTlsContext upstreamTlsContext) throws IOException, URISyntaxException { + UpstreamTlsContext upstreamTlsContext, boolean newApi) + throws IOException, URISyntaxException { DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenamesWithClientCertRequired( SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); @@ -291,7 +315,8 @@ private XdsClient.ListenerWatcher performMtlsTestAndGetListenerWatcher( XdsClient.ListenerWatcher listenerWatcher = xdsClientWrapperForServerSds.getListenerWatcher(); - SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = newApi + ? getBlockingStubNewApi(upstreamTlsContext, "foo.test.google.fr") : getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); return listenerWatcher; @@ -378,6 +403,35 @@ upstreamTlsContext, new TlsContextManagerImpl(mockBootstrapper))) return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build())); } + private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStubNewApi( + final UpstreamTlsContext upstreamTlsContext, String overrideAuthority) + throws URISyntaxException { + URI expectedUri = new URI("sdstest://localhost:" + port); + fakeNameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri).build(); + NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverFactory); + ManagedChannelBuilder channelBuilder = + Grpc.newChannelBuilder( + "sdstest://localhost:" + port, + XdsChannelCredentials.create(InsecureChannelCredentials.create())); + + if (overrideAuthority != null) { + channelBuilder = channelBuilder.overrideAuthority(overrideAuthority); + } + InetSocketAddress socketAddress = + new InetSocketAddress(Inet4Address.getLoopbackAddress(), port); + Attributes attrs = + (upstreamTlsContext != null) + ? Attributes.newBuilder() + .set(XdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + new SslContextProviderSupplier( + upstreamTlsContext, new TlsContextManagerImpl(mockBootstrapper))) + .build() + : Attributes.EMPTY; + fakeNameResolverFactory.setServers( + ImmutableList.of(new EquivalentAddressGroup(socketAddress, attrs))); + return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build())); + } + /** Say hello to server. */ private static String unaryRpc( String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) {