Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Commit

Permalink
Add comments and documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
tonytanger committed Nov 13, 2019
1 parent b2f280b commit 6b3968b
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 46 deletions.
Expand Up @@ -33,6 +33,7 @@
import io.grpc.ManagedChannel;
import java.io.IOException;

/** This interface represents a factory for creating one ManagedChannel */
@InternalApi("For internal use by google-cloud-java clients only")
public interface ChannelFactory {
ManagedChannel createSingleChannel() throws IOException;
Expand Down
29 changes: 26 additions & 3 deletions gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
Expand Up @@ -50,20 +50,43 @@ class ChannelPool extends ManagedChannel {
private final AtomicInteger indexTicker = new AtomicInteger();
private final String authority;

/**
* Factory method to create a non-refreshing channel pool
*
* @see ChannelPool#ChannelPool(int, ChannelFactory, ScheduledExecutorService, boolean)
*/
static ChannelPool create(int poolSize, final ChannelFactory channelFactory) throws IOException {
return new ChannelPool(poolSize, channelFactory, null, false);
}

/**
* Factory method to create a refreshing channel pool
*
* @see ChannelPool#ChannelPool(int, ChannelFactory, ScheduledExecutorService, boolean)
*/
static ChannelPool createRefreshing(
int poolSize, final ChannelFactory channelFactory, ScheduledExecutorService executorService)
throws IOException {
return new ChannelPool(poolSize, channelFactory, executorService, true);
}

/**
* Initializes the channel pool. Assumes that all channels have the same authority.
*
* @param poolSize number of channels in the pool
* @param channelFactory method to create the channels
* @param executorService if set, schedule periodically refresh the channels
*/
ChannelPool(
int poolSize, final ChannelFactory channelFactory, ScheduledExecutorService executorService)
private ChannelPool(
int poolSize,
final ChannelFactory channelFactory,
ScheduledExecutorService executorService,
boolean isRefreshing)
throws IOException {
channels = new ArrayList<>(poolSize);
// if executorService is available, create RefreshingManagedChannel that will get refreshed.
// otherwise create with regular ManagedChannel
if (executorService != null) {
if (isRefreshing) {
for (int i = 0; i < poolSize; i++) {
channels.add(new RefreshingManagedChannel(channelFactory, executorService));
}
Expand Down
Expand Up @@ -32,6 +32,7 @@
import com.google.api.core.InternalApi;
import io.grpc.ManagedChannel;

/** An interface to prepare a ManagedChannel for normal requests by priming the channel */
@InternalApi("For internal use by google-cloud-java clients only")
public interface ChannelPrimer {
void primeChannel(ManagedChannel managedChannel);
Expand Down
Expand Up @@ -31,6 +31,7 @@

import com.google.api.core.ApiFunction;
import com.google.api.core.BetaApi;
import com.google.api.core.InternalApi;
import com.google.api.core.InternalExtensionOnly;
import com.google.api.gax.core.ExecutorProvider;
import com.google.api.gax.core.FixedExecutorProvider;
Expand Down Expand Up @@ -70,6 +71,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
static final String DIRECT_PATH_ENV_VAR = "GOOGLE_CLOUD_ENABLE_DIRECT_PATH";
static final long DIRECT_PATH_KEEP_ALIVE_TIME_SECONDS = 3600;
static final long DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS = 20;
// reduce the thundering herd problem of too many channels trying to (re)connect at the same time
static final int MAX_POOL_SIZE = 1000;

private final int processorCount;
Expand Down Expand Up @@ -206,9 +208,11 @@ public ManagedChannel createSingleChannel() throws IOException {
}
};
if (channelPrimer != null) {
outerChannel = new ChannelPool(realPoolSize, channelFactory, executorProvider.getExecutor());
outerChannel =
ChannelPool.createRefreshing(
realPoolSize, channelFactory, executorProvider.getExecutor());
} else {
outerChannel = new ChannelPool(realPoolSize, channelFactory, null);
outerChannel = ChannelPool.create(realPoolSize, channelFactory);
}

return GrpcTransportChannel.create(outerChannel);
Expand Down Expand Up @@ -553,6 +557,14 @@ public Builder setCredentials(Credentials credentials) {
return this;
}

/**
* By setting a channelPrimer, the ChannelPool created by the provider will be refreshing
* ChannelPool. channelPrimer will be invoked periodically when the channels are refreshed
*
* @param channelPrimer invoked when the channels are refreshed
* @return builder for the provider
*/
@InternalApi("For internal use by google-cloud-java clients only")
public Builder setChannelPrimer(ChannelPrimer channelPrimer) {
this.channelPrimer = channelPrimer;
return this;
Expand Down
Expand Up @@ -90,8 +90,13 @@ private void refreshChannel() {
SafeShutdownManagedChannel oldChannel = delegate;
lock.writeLock().lock();
try {
// This thread can be interrupted by invoking cancel on nextScheduledRefresh
// Interrupt happens when this thread is blocked on acquiring the write lock because shutdown
// was called and that thread holds the read lock.
// When shutdown completes and releases the read lock and this thread acquires the write lock.
// This thread should not continue because the channel has shutdown. This check ensures that
// this thread terminates without swapping the channel and do not schedule the next refresh.
if (Thread.interrupted()) {
// this refresh has been interrupted, do not swap the channels
return;
}
delegate = newChannel;
Expand All @@ -104,15 +109,17 @@ private void refreshChannel() {

/** Schedule the next instance of refreshing this channel */
private ScheduledFuture<?> scheduleNextRefresh() {
long delayPeriod = refreshPeriod.toMillis();
long jitter = (long) ((Math.random() - 0.5) * jitterPercentage * delayPeriod);
long delay = jitter + delayPeriod;
return scheduledExecutorService.schedule(
new Runnable() {
@Override
public void run() {
refreshChannel();
}
},
(long) ((Math.random() - 0.5) * refreshPeriod.toMillis() * jitterPercentage)
+ refreshPeriod.toMillis(),
delay,
TimeUnit.MILLISECONDS);
}

Expand Down Expand Up @@ -150,34 +157,34 @@ public ManagedChannel shutdown() {

/** {@inheritDoc} */
@Override
public boolean isShutdown() {
public ManagedChannel shutdownNow() {
lock.readLock().lock();
try {
return delegate.isShutdown();
nextScheduledRefresh.cancel(true);
delegate.shutdownNow();
return this;
} finally {
lock.readLock().unlock();
}
}

/** {@inheritDoc} */
@Override
public boolean isTerminated() {
public boolean isShutdown() {
lock.readLock().lock();
try {
return delegate.isTerminated();
return delegate.isShutdown();
} finally {
lock.readLock().unlock();
}
}

/** {@inheritDoc} */
@Override
public ManagedChannel shutdownNow() {
public boolean isTerminated() {
lock.readLock().lock();
try {
nextScheduledRefresh.cancel(true);
delegate.shutdownNow();
return this;
return delegate.isTerminated();
} finally {
lock.readLock().unlock();
}
Expand Down
Expand Up @@ -29,8 +29,10 @@
*/
package com.google.api.gax.grpc;

import com.google.common.base.Preconditions;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ClientCall.Listener;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.ManagedChannel;
Expand Down Expand Up @@ -61,6 +63,9 @@ class SafeShutdownManagedChannel extends ManagedChannel {
/**
* Safely shutdown channel by checking that there are no more outstanding calls. If there are
* outstanding calls, the last call will invoke this method again when it complete
*
* <p>Caller should take care to synchronize with newCall so no new calls are started after
* shutdownSafely is called
*/
void shutdownSafely() {
isShutdownSafely = true;
Expand All @@ -84,15 +89,15 @@ public boolean isShutdown() {

/** {@inheritDoc} */
@Override
public boolean isTerminated() {
return delegate.isTerminated();
public ManagedChannel shutdownNow() {
delegate.shutdownNow();
return this;
}

/** {@inheritDoc} */
@Override
public ManagedChannel shutdownNow() {
delegate.shutdownNow();
return this;
public boolean isTerminated() {
return delegate.isTerminated();
}

/** {@inheritDoc} */
Expand All @@ -101,6 +106,25 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE
return delegate.awaitTermination(timeout, unit);
}

/** Listener that's responsible for decrementing outstandingCalls when the call closes */
private class DecrementOutstandingCalls<RespT> extends SimpleForwardingClientCallListener<RespT> {
DecrementOutstandingCalls(Listener<RespT> delegate) {
super(delegate);
}

@Override
public void onClose(Status status, Metadata trailers) {
// decrement in finally block in case onClose throws an exception
try {
super.onClose(status, trailers);
} finally {
if (outstandingCalls.decrementAndGet() == 0 && isShutdownSafely) {
shutdownSafely();
}
}
}
}

/**
* Wrap client call to decrement outstandingCalls and shutdown channel if necessary when it
* completes
Expand All @@ -111,31 +135,22 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE
private <ReqT, RespT> ClientCall<ReqT, RespT> clientCallWrapper(ClientCall<ReqT, RespT> call) {
return new SimpleForwardingClientCall<ReqT, RespT>(call) {
public void start(Listener<RespT> responseListener, Metadata headers) {
Listener<RespT> forwardingResponseListener =
new SimpleForwardingClientCallListener<RespT>(responseListener) {
@Override
public void onClose(Status status, Metadata trailers) {
// decrement in finally block in case onClose throws an exception
try {
super.onClose(status, trailers);
} finally {
if (outstandingCalls.decrementAndGet() == 0 && isShutdownSafely) {
shutdownSafely();
}
}
}
};
super.start(forwardingResponseListener, headers);
super.start(new DecrementOutstandingCalls<>(responseListener), headers);
}
};
}

/** {@inheritDoc} */
/**
* Caller must take care to synchronize newCall and shutdownSafely in order to avoid race
* conditions of starting new calls after shutdownSafely is called
*
* @see io.grpc.ManagedChannel#newCall(MethodDescriptor, CallOptions)
*/
@Override
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
// increment after client call in case newCall throws an exception
assert (!isShutdownSafely);
Preconditions.checkState(!isShutdownSafely);
ClientCall<RequestT, ResponseT> clientCall =
clientCallWrapper(delegate.newCall(methodDescriptor, callOptions));
outstandingCalls.incrementAndGet();
Expand Down
Expand Up @@ -65,7 +65,7 @@ public void testAuthority() throws IOException {

Mockito.when(sub1.authority()).thenReturn("myAuth");

ChannelPool pool = new ChannelPool(2, new FakeChannelFactory(Arrays.asList(sub1, sub2)), null);
ChannelPool pool = ChannelPool.create(2, new FakeChannelFactory(Arrays.asList(sub1, sub2)));
Truth.assertThat(pool.authority()).isEqualTo("myAuth");
}

Expand All @@ -77,7 +77,7 @@ public void testRoundRobin() throws IOException {
Mockito.when(sub1.authority()).thenReturn("myAuth");

ArrayList<ManagedChannel> channels = Lists.newArrayList(sub1, sub2);
ChannelPool pool = new ChannelPool(2, new FakeChannelFactory(Arrays.asList(sub1, sub2)), null);
ChannelPool pool = ChannelPool.create(2, new FakeChannelFactory(Arrays.asList(sub1, sub2)));

verifyTargetChannel(pool, channels, sub1);
verifyTargetChannel(pool, channels, sub2);
Expand Down Expand Up @@ -138,7 +138,7 @@ public ClientCall<Color, Money> answer(InvocationOnMock invocationOnMock)
}

final ChannelPool pool =
new ChannelPool(numChannels, new FakeChannelFactory(Arrays.asList(channels)), null);
ChannelPool.create(numChannels, new FakeChannelFactory(Arrays.asList(channels)));

int numThreads = 20;
final int numPerThread = 1000;
Expand Down Expand Up @@ -172,8 +172,8 @@ public void channelPrimerShouldBeCalledOnce() throws IOException {
ManagedChannel channel1 = Mockito.mock(ManagedChannel.class);
ManagedChannel channel2 = Mockito.mock(ManagedChannel.class);

new ChannelPool(
2, new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer), null);
ChannelPool.create(
2, new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer));
Mockito.verify(mockChannelPrimer, Mockito.times(2))
.primeChannel(Mockito.any(ManagedChannel.class));
}
Expand Down Expand Up @@ -201,7 +201,7 @@ public Object answer(InvocationOnMock invocation) {
}
});

new ChannelPool(
ChannelPool.createRefreshing(
1,
new FakeChannelFactory(Arrays.asList(channel1, channel2, channel3), mockChannelPrimer),
scheduledExecutorService);
Expand Down
Expand Up @@ -75,8 +75,7 @@ public void testAffinity() throws IOException {
.thenReturn(clientCall0);
Mockito.when(channel1.newCall(Mockito.eq(descriptor), Mockito.<CallOptions>any()))
.thenReturn(clientCall1);
Channel pool =
new ChannelPool(2, new FakeChannelFactory(Arrays.asList(channel0, channel1)), null);
Channel pool = ChannelPool.create(2, new FakeChannelFactory(Arrays.asList(channel0, channel1)));
GrpcCallContext context = GrpcCallContext.createDefault().withChannel(pool);

ClientCall<Color, Money> gotCallA =
Expand Down

0 comments on commit 6b3968b

Please sign in to comment.