Skip to content

Commit

Permalink
core: Don't leak CallCredentials into OOB channels
Browse files Browse the repository at this point in the history
This also fixes a bug where resolving OOB channels would have CallCreds
duplicated; that wasn't noticed or important because we don't use
CallCreds in OOB channels.

Fixes grpc#7643
  • Loading branch information
ejona86 committed Dec 9, 2020
1 parent 8ce6355 commit 6b3298b
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 22 deletions.
13 changes: 9 additions & 4 deletions core/src/main/java/io/grpc/internal/ManagedChannelImpl.java
Expand Up @@ -152,7 +152,9 @@ public Result selectConfig(PickSubchannelArgs args) {
private final NameResolver.Factory nameResolverFactory;
private final NameResolver.Args nameResolverArgs;
private final AutoConfiguredLoadBalancerFactory loadBalancerFactory;
private final ClientTransportFactory originalTransportFactory;
private final ClientTransportFactory transportFactory;
private final ClientTransportFactory oobTransportFactory;
private final RestrictedScheduledExecutor scheduledExecutor;
private final Executor executor;
private final ObjectPool<? extends Executor> executorPool;
Expand Down Expand Up @@ -591,8 +593,11 @@ ClientStream newSubstream(ClientStreamTracer.Factory tracerFactory, Metadata new
this.timeProvider = checkNotNull(timeProvider, "timeProvider");
this.executorPool = checkNotNull(builder.executorPool, "executorPool");
this.executor = checkNotNull(executorPool.getObject(), "executor");
this.originalTransportFactory = clientTransportFactory;
this.transportFactory = new CallCredentialsApplyingTransportFactory(
clientTransportFactory, builder.callCredentials, this.executor);
this.oobTransportFactory = new CallCredentialsApplyingTransportFactory(
clientTransportFactory, null, this.executor);
this.scheduledExecutor =
new RestrictedScheduledExecutor(transportFactory.getScheduledExecutorService());
maxTraceEvents = builder.maxTraceEvents;
Expand Down Expand Up @@ -1487,7 +1492,7 @@ public ManagedChannel createOobChannel(EquivalentAddressGroup addressGroup, Stri
oobLogId, maxTraceEvents, oobChannelCreationTime,
"OobChannel for " + addressGroup);
final OobChannel oobChannel = new OobChannel(
authority, balancerRpcExecutorPool, transportFactory.getScheduledExecutorService(),
authority, balancerRpcExecutorPool, oobTransportFactory.getScheduledExecutorService(),
syncContext, callTracerFactory.create(), oobChannelTracer, channelz, timeProvider);
channelTracer.reportEvent(new ChannelTrace.Event.Builder()
.setDescription("Child OobChannel created")
Expand Down Expand Up @@ -1517,8 +1522,8 @@ void onStateChange(InternalSubchannel is, ConnectivityStateInfo newState) {

final InternalSubchannel internalSubchannel = new InternalSubchannel(
Collections.singletonList(addressGroup),
authority, userAgent, backoffPolicyProvider, transportFactory,
transportFactory.getScheduledExecutorService(), stopwatchSupplier, syncContext,
authority, userAgent, backoffPolicyProvider, oobTransportFactory,
oobTransportFactory.getScheduledExecutorService(), stopwatchSupplier, syncContext,
// All callback methods are run from syncContext
new ManagedOobChannelCallback(),
channelz,
Expand Down Expand Up @@ -1577,7 +1582,7 @@ public ManagedChannel build() {
// TODO(creamsoup) prevent main channel to shutdown if oob channel is not terminated
return new ManagedChannelImpl(
managedChannelImplBuilder,
transportFactory,
originalTransportFactory,
backoffPolicyProvider,
balancerRpcExecutorPool,
stopwatchSupplier,
Expand Down
Expand Up @@ -312,22 +312,4 @@ public void callOptionAndChanelCreds() {
assertEquals(creds2Value, origHeaders.get(creds2Key));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
}

private abstract static class BaseCallCredentials extends CallCredentials {
@Override public void thisUsesUnstableApi() {}
}

private static class FakeCallCredentials extends BaseCallCredentials {
private final Metadata headers;

public <T> FakeCallCredentials(Metadata.Key<T> key, T value) {
headers = new Metadata();
headers.put(key, value);
}

@Override public void applyRequestMetadata(
RequestInfo requestInfo, Executor appExecutor, CallCredentials.MetadataApplier applier) {
applier.apply(headers);
}
}
}
43 changes: 43 additions & 0 deletions core/src/test/java/io/grpc/internal/FakeCallCredentials.java
@@ -0,0 +1,43 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc.internal;

import io.grpc.CallCredentials;
import io.grpc.Metadata;
import java.util.concurrent.Executor;

/**
* CallCredentials that provides a single, fixed header.
*/
final class FakeCallCredentials extends CallCredentials {
private final Metadata headers;

public <T> FakeCallCredentials(Metadata.Key<T> key, T value) {
headers = new Metadata();
headers.put(key, value);
}

@Override
public void applyRequestMetadata(
CallCredentials.RequestInfo requestInfo,
Executor appExecutor,
CallCredentials.MetadataApplier applier) {
applier.apply(headers);
}

@Override public void thisUsesUnstableApi() {}
}
84 changes: 84 additions & 0 deletions core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
Expand Up @@ -330,7 +330,10 @@ public void setUp() throws Exception {

channelBuilder = new ManagedChannelImplBuilder(TARGET,
new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT));
configureBuilder(channelBuilder);
}

private void configureBuilder(ManagedChannelImplBuilder channelBuilder) {
channelBuilder
.nameResolverFactory(new FakeNameResolverFactory.Builder(expectedUri).build())
.defaultLoadBalancingPolicy(MOCK_POLICY_NAME)
Expand Down Expand Up @@ -1741,6 +1744,87 @@ public void oobchannels() {
.returnObject(balancerRpcExecutor.getScheduledExecutorService());
}

@Test
public void oobchannelsHaveNoChannelCallCredentials() {
Metadata.Key<String> metadataKey =
Metadata.Key.of("token", Metadata.ASCII_STRING_MARSHALLER);
String channelCredValue = "channel-provided call cred";
channelBuilder = new ManagedChannelImplBuilder(TARGET,
new FakeCallCredentials(metadataKey, channelCredValue),
new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT));
configureBuilder(channelBuilder);
createChannel();

// Verify that the normal channel has call creds, to validate configuration
Subchannel subchannel =
createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener);
requestConnectionSafely(helper, subchannel);
MockClientTransportInfo transportInfo = transports.poll();
assertNotNull(transportInfo);
transportInfo.listener.transportReady();
when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(
PickResult.withSubchannel(subchannel));
updateBalancingStateSafely(helper, READY, mockPicker);

String callCredValue = "per-RPC call cred";
CallOptions callOptions = CallOptions.DEFAULT
.withCallCredentials(new FakeCallCredentials(metadataKey, callCredValue));
Metadata headers = new Metadata();
ClientCall<String, Integer> call = channel.newCall(method, callOptions);
call.start(mockCallListener, headers);

verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions));
assertThat(headers.getAll(metadataKey))
.containsExactly(channelCredValue, callCredValue).inOrder();

// Verify that the oob channel does not
ManagedChannel oob = helper.createOobChannel(addressGroup, "oobauthority");

headers = new Metadata();
call = oob.newCall(method, callOptions);
call.start(mockCallListener2, headers);

transportInfo = transports.poll();
assertNotNull(transportInfo);
transportInfo.listener.transportReady();
balancerRpcExecutor.runDueTasks();

verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions));
assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue);
oob.shutdownNow();

// Verify that resolving oob channel does not
oob = helper.createResolvingOobChannelBuilder("oobauthority")
.nameResolverFactory(
new FakeNameResolverFactory.Builder(URI.create("oobauthority")).build())
.defaultLoadBalancingPolicy(MOCK_POLICY_NAME)
.idleTimeout(ManagedChannelImplBuilder.IDLE_MODE_MAX_TIMEOUT_DAYS, TimeUnit.DAYS)
.build();
oob.getState(true);
ArgumentCaptor<Helper> helperCaptor = ArgumentCaptor.forClass(Helper.class);
verify(mockLoadBalancerProvider, times(2)).newLoadBalancer(helperCaptor.capture());
Helper oobHelper = helperCaptor.getValue();

subchannel =
createSubchannelSafely(oobHelper, addressGroup, Attributes.EMPTY, subchannelStateListener);
requestConnectionSafely(oobHelper, subchannel);
transportInfo = transports.poll();
assertNotNull(transportInfo);
transportInfo.listener.transportReady();
SubchannelPicker mockPicker2 = mock(SubchannelPicker.class);
when(mockPicker2.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(
PickResult.withSubchannel(subchannel));
updateBalancingStateSafely(oobHelper, READY, mockPicker2);

headers = new Metadata();
call = oob.newCall(method, callOptions);
call.start(mockCallListener2, headers);

verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions));
assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue);
oob.shutdownNow();
}

@Test
public void oobChannelsWhenChannelShutdownNow() {
createChannel();
Expand Down

0 comments on commit 6b3298b

Please sign in to comment.