diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 2edb32fa216..6dccd7aca8b 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -152,7 +152,9 @@ public Result selectConfig(PickSubchannelArgs args) { private final NameResolver.Factory nameResolverFactory; private final NameResolver.Args nameResolverArgs; private final AutoConfiguredLoadBalancerFactory loadBalancerFactory; + private final ClientTransportFactory originalTransportFactory; private final ClientTransportFactory transportFactory; + private final ClientTransportFactory oobTransportFactory; private final RestrictedScheduledExecutor scheduledExecutor; private final Executor executor; private final ObjectPool executorPool; @@ -591,8 +593,11 @@ ClientStream newSubstream(ClientStreamTracer.Factory tracerFactory, Metadata new this.timeProvider = checkNotNull(timeProvider, "timeProvider"); this.executorPool = checkNotNull(builder.executorPool, "executorPool"); this.executor = checkNotNull(executorPool.getObject(), "executor"); + this.originalTransportFactory = clientTransportFactory; this.transportFactory = new CallCredentialsApplyingTransportFactory( clientTransportFactory, builder.callCredentials, this.executor); + this.oobTransportFactory = new CallCredentialsApplyingTransportFactory( + clientTransportFactory, null, this.executor); this.scheduledExecutor = new RestrictedScheduledExecutor(transportFactory.getScheduledExecutorService()); maxTraceEvents = builder.maxTraceEvents; @@ -1487,7 +1492,7 @@ public ManagedChannel createOobChannel(EquivalentAddressGroup addressGroup, Stri oobLogId, maxTraceEvents, oobChannelCreationTime, "OobChannel for " + addressGroup); final OobChannel oobChannel = new OobChannel( - authority, balancerRpcExecutorPool, transportFactory.getScheduledExecutorService(), + authority, balancerRpcExecutorPool, oobTransportFactory.getScheduledExecutorService(), syncContext, callTracerFactory.create(), oobChannelTracer, channelz, timeProvider); channelTracer.reportEvent(new ChannelTrace.Event.Builder() .setDescription("Child OobChannel created") @@ -1517,8 +1522,8 @@ void onStateChange(InternalSubchannel is, ConnectivityStateInfo newState) { final InternalSubchannel internalSubchannel = new InternalSubchannel( Collections.singletonList(addressGroup), - authority, userAgent, backoffPolicyProvider, transportFactory, - transportFactory.getScheduledExecutorService(), stopwatchSupplier, syncContext, + authority, userAgent, backoffPolicyProvider, oobTransportFactory, + oobTransportFactory.getScheduledExecutorService(), stopwatchSupplier, syncContext, // All callback methods are run from syncContext new ManagedOobChannelCallback(), channelz, @@ -1577,7 +1582,7 @@ public ManagedChannel build() { // TODO(creamsoup) prevent main channel to shutdown if oob channel is not terminated return new ManagedChannelImpl( managedChannelImplBuilder, - transportFactory, + originalTransportFactory, backoffPolicyProvider, balancerRpcExecutorPool, stopwatchSupplier, diff --git a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java index 52ebb8b8a24..6949ab7c310 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java @@ -312,22 +312,4 @@ public void callOptionAndChanelCreds() { assertEquals(creds2Value, origHeaders.get(creds2Key)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); } - - private abstract static class BaseCallCredentials extends CallCredentials { - @Override public void thisUsesUnstableApi() {} - } - - private static class FakeCallCredentials extends BaseCallCredentials { - private final Metadata headers; - - public FakeCallCredentials(Metadata.Key key, T value) { - headers = new Metadata(); - headers.put(key, value); - } - - @Override public void applyRequestMetadata( - RequestInfo requestInfo, Executor appExecutor, CallCredentials.MetadataApplier applier) { - applier.apply(headers); - } - } } diff --git a/core/src/test/java/io/grpc/internal/FakeCallCredentials.java b/core/src/test/java/io/grpc/internal/FakeCallCredentials.java new file mode 100644 index 00000000000..e307495697c --- /dev/null +++ b/core/src/test/java/io/grpc/internal/FakeCallCredentials.java @@ -0,0 +1,43 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import io.grpc.CallCredentials; +import io.grpc.Metadata; +import java.util.concurrent.Executor; + +/** + * CallCredentials that provides a single, fixed header. + */ +final class FakeCallCredentials extends CallCredentials { + private final Metadata headers; + + public FakeCallCredentials(Metadata.Key key, T value) { + headers = new Metadata(); + headers.put(key, value); + } + + @Override + public void applyRequestMetadata( + CallCredentials.RequestInfo requestInfo, + Executor appExecutor, + CallCredentials.MetadataApplier applier) { + applier.apply(headers); + } + + @Override public void thisUsesUnstableApi() {} +} diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index f8671681096..86f682bad83 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -330,7 +330,10 @@ public void setUp() throws Exception { channelBuilder = new ManagedChannelImplBuilder(TARGET, new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT)); + configureBuilder(channelBuilder); + } + private void configureBuilder(ManagedChannelImplBuilder channelBuilder) { channelBuilder .nameResolverFactory(new FakeNameResolverFactory.Builder(expectedUri).build()) .defaultLoadBalancingPolicy(MOCK_POLICY_NAME) @@ -1741,6 +1744,87 @@ public void oobchannels() { .returnObject(balancerRpcExecutor.getScheduledExecutorService()); } + @Test + public void oobchannelsHaveNoChannelCallCredentials() { + Metadata.Key metadataKey = + Metadata.Key.of("token", Metadata.ASCII_STRING_MARSHALLER); + String channelCredValue = "channel-provided call cred"; + channelBuilder = new ManagedChannelImplBuilder(TARGET, + new FakeCallCredentials(metadataKey, channelCredValue), + new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT)); + configureBuilder(channelBuilder); + createChannel(); + + // Verify that the normal channel has call creds, to validate configuration + Subchannel subchannel = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); + requestConnectionSafely(helper, subchannel); + MockClientTransportInfo transportInfo = transports.poll(); + assertNotNull(transportInfo); + transportInfo.listener.transportReady(); + when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( + PickResult.withSubchannel(subchannel)); + updateBalancingStateSafely(helper, READY, mockPicker); + + String callCredValue = "per-RPC call cred"; + CallOptions callOptions = CallOptions.DEFAULT + .withCallCredentials(new FakeCallCredentials(metadataKey, callCredValue)); + Metadata headers = new Metadata(); + ClientCall call = channel.newCall(method, callOptions); + call.start(mockCallListener, headers); + + verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + assertThat(headers.getAll(metadataKey)) + .containsExactly(channelCredValue, callCredValue).inOrder(); + + // Verify that the oob channel does not + ManagedChannel oob = helper.createOobChannel(addressGroup, "oobauthority"); + + headers = new Metadata(); + call = oob.newCall(method, callOptions); + call.start(mockCallListener2, headers); + + transportInfo = transports.poll(); + assertNotNull(transportInfo); + transportInfo.listener.transportReady(); + balancerRpcExecutor.runDueTasks(); + + verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue); + oob.shutdownNow(); + + // Verify that resolving oob channel does not + oob = helper.createResolvingOobChannelBuilder("oobauthority") + .nameResolverFactory( + new FakeNameResolverFactory.Builder(URI.create("oobauthority")).build()) + .defaultLoadBalancingPolicy(MOCK_POLICY_NAME) + .idleTimeout(ManagedChannelImplBuilder.IDLE_MODE_MAX_TIMEOUT_DAYS, TimeUnit.DAYS) + .build(); + oob.getState(true); + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); + verify(mockLoadBalancerProvider, times(2)).newLoadBalancer(helperCaptor.capture()); + Helper oobHelper = helperCaptor.getValue(); + + subchannel = + createSubchannelSafely(oobHelper, addressGroup, Attributes.EMPTY, subchannelStateListener); + requestConnectionSafely(oobHelper, subchannel); + transportInfo = transports.poll(); + assertNotNull(transportInfo); + transportInfo.listener.transportReady(); + SubchannelPicker mockPicker2 = mock(SubchannelPicker.class); + when(mockPicker2.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( + PickResult.withSubchannel(subchannel)); + updateBalancingStateSafely(oobHelper, READY, mockPicker2); + + headers = new Metadata(); + call = oob.newCall(method, callOptions); + call.start(mockCallListener2, headers); + + verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue); + oob.shutdownNow(); + } + @Test public void oobChannelsWhenChannelShutdownNow() { createChannel();