diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelFactory.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelFactory.java index 9ae6a0f5ee..13956009f4 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelFactory.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelFactory.java @@ -33,6 +33,7 @@ import io.grpc.ManagedChannel; import java.io.IOException; +/** This interface represents a factory for creating one ManagedChannel */ @InternalApi("For internal use by google-cloud-java clients only") public interface ChannelFactory { ManagedChannel createSingleChannel() throws IOException; diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index de5dc30d8c..1aadb7da02 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -50,6 +50,26 @@ class ChannelPool extends ManagedChannel { private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; + /** + * Factory method to create a non-refreshing channel pool + * + * @see ChannelPool#ChannelPool(int, ChannelFactory, ScheduledExecutorService, boolean) + */ + static ChannelPool create(int poolSize, final ChannelFactory channelFactory) throws IOException { + return new ChannelPool(poolSize, channelFactory, null, false); + } + + /** + * Factory method to create a refreshing channel pool + * + * @see ChannelPool#ChannelPool(int, ChannelFactory, ScheduledExecutorService, boolean) + */ + static ChannelPool createRefreshing( + int poolSize, final ChannelFactory channelFactory, ScheduledExecutorService executorService) + throws IOException { + return new ChannelPool(poolSize, channelFactory, executorService, true); + } + /** * Initializes the channel pool. Assumes that all channels have the same authority. * @@ -57,13 +77,16 @@ class ChannelPool extends ManagedChannel { * @param channelFactory method to create the channels * @param executorService if set, schedule periodically refresh the channels */ - ChannelPool( - int poolSize, final ChannelFactory channelFactory, ScheduledExecutorService executorService) + private ChannelPool( + int poolSize, + final ChannelFactory channelFactory, + ScheduledExecutorService executorService, + boolean isRefreshing) throws IOException { channels = new ArrayList<>(poolSize); // if executorService is available, create RefreshingManagedChannel that will get refreshed. // otherwise create with regular ManagedChannel - if (executorService != null) { + if (isRefreshing) { for (int i = 0; i < poolSize; i++) { channels.add(new RefreshingManagedChannel(channelFactory, executorService)); } diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPrimer.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPrimer.java index ef87b2c34c..683859a659 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPrimer.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPrimer.java @@ -32,6 +32,7 @@ import com.google.api.core.InternalApi; import io.grpc.ManagedChannel; +/** An interface to prepare a ManagedChannel for normal requests by priming the channel */ @InternalApi("For internal use by google-cloud-java clients only") public interface ChannelPrimer { void primeChannel(ManagedChannel managedChannel); diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java index 1098691132..1021821b17 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java @@ -31,6 +31,7 @@ import com.google.api.core.ApiFunction; import com.google.api.core.BetaApi; +import com.google.api.core.InternalApi; import com.google.api.core.InternalExtensionOnly; import com.google.api.gax.core.ExecutorProvider; import com.google.api.gax.core.FixedExecutorProvider; @@ -70,6 +71,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP static final String DIRECT_PATH_ENV_VAR = "GOOGLE_CLOUD_ENABLE_DIRECT_PATH"; static final long DIRECT_PATH_KEEP_ALIVE_TIME_SECONDS = 3600; static final long DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS = 20; + // reduce the thundering herd problem of too many channels trying to (re)connect at the same time static final int MAX_POOL_SIZE = 1000; private final int processorCount; @@ -206,9 +208,11 @@ public ManagedChannel createSingleChannel() throws IOException { } }; if (channelPrimer != null) { - outerChannel = new ChannelPool(realPoolSize, channelFactory, executorProvider.getExecutor()); + outerChannel = + ChannelPool.createRefreshing( + realPoolSize, channelFactory, executorProvider.getExecutor()); } else { - outerChannel = new ChannelPool(realPoolSize, channelFactory, null); + outerChannel = ChannelPool.create(realPoolSize, channelFactory); } return GrpcTransportChannel.create(outerChannel); @@ -553,6 +557,14 @@ public Builder setCredentials(Credentials credentials) { return this; } + /** + * By setting a channelPrimer, the ChannelPool created by the provider will be refreshing + * ChannelPool. channelPrimer will be invoked periodically when the channels are refreshed + * + * @param channelPrimer invoked when the channels are refreshed + * @return builder for the provider + */ + @InternalApi("For internal use by google-cloud-java clients only") public Builder setChannelPrimer(ChannelPrimer channelPrimer) { this.channelPrimer = channelPrimer; return this; diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/RefreshingManagedChannel.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/RefreshingManagedChannel.java index 84b65d6376..7ad068afdb 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/RefreshingManagedChannel.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/RefreshingManagedChannel.java @@ -90,8 +90,13 @@ private void refreshChannel() { SafeShutdownManagedChannel oldChannel = delegate; lock.writeLock().lock(); try { + // This thread can be interrupted by invoking cancel on nextScheduledRefresh + // Interrupt happens when this thread is blocked on acquiring the write lock because shutdown + // was called and that thread holds the read lock. + // When shutdown completes and releases the read lock and this thread acquires the write lock. + // This thread should not continue because the channel has shutdown. This check ensures that + // this thread terminates without swapping the channel and do not schedule the next refresh. if (Thread.interrupted()) { - // this refresh has been interrupted, do not swap the channels return; } delegate = newChannel; @@ -104,6 +109,9 @@ private void refreshChannel() { /** Schedule the next instance of refreshing this channel */ private ScheduledFuture scheduleNextRefresh() { + long delayPeriod = refreshPeriod.toMillis(); + long jitter = (long) ((Math.random() - 0.5) * jitterPercentage * delayPeriod); + long delay = jitter + delayPeriod; return scheduledExecutorService.schedule( new Runnable() { @Override @@ -111,8 +119,7 @@ public void run() { refreshChannel(); } }, - (long) ((Math.random() - 0.5) * refreshPeriod.toMillis() * jitterPercentage) - + refreshPeriod.toMillis(), + delay, TimeUnit.MILLISECONDS); } @@ -150,10 +157,12 @@ public ManagedChannel shutdown() { /** {@inheritDoc} */ @Override - public boolean isShutdown() { + public ManagedChannel shutdownNow() { lock.readLock().lock(); try { - return delegate.isShutdown(); + nextScheduledRefresh.cancel(true); + delegate.shutdownNow(); + return this; } finally { lock.readLock().unlock(); } @@ -161,10 +170,10 @@ public boolean isShutdown() { /** {@inheritDoc} */ @Override - public boolean isTerminated() { + public boolean isShutdown() { lock.readLock().lock(); try { - return delegate.isTerminated(); + return delegate.isShutdown(); } finally { lock.readLock().unlock(); } @@ -172,12 +181,10 @@ public boolean isTerminated() { /** {@inheritDoc} */ @Override - public ManagedChannel shutdownNow() { + public boolean isTerminated() { lock.readLock().lock(); try { - nextScheduledRefresh.cancel(true); - delegate.shutdownNow(); - return this; + return delegate.isTerminated(); } finally { lock.readLock().unlock(); } diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/SafeShutdownManagedChannel.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/SafeShutdownManagedChannel.java index 859c88aefe..06d8f622c1 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/SafeShutdownManagedChannel.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/SafeShutdownManagedChannel.java @@ -29,8 +29,10 @@ */ package com.google.api.gax.grpc; +import com.google.common.base.Preconditions; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.ClientCall.Listener; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; import io.grpc.ManagedChannel; @@ -61,6 +63,9 @@ class SafeShutdownManagedChannel extends ManagedChannel { /** * Safely shutdown channel by checking that there are no more outstanding calls. If there are * outstanding calls, the last call will invoke this method again when it complete + * + *

Caller should take care to synchronize with newCall so no new calls are started after + * shutdownSafely is called */ void shutdownSafely() { isShutdownSafely = true; @@ -84,15 +89,15 @@ public boolean isShutdown() { /** {@inheritDoc} */ @Override - public boolean isTerminated() { - return delegate.isTerminated(); + public ManagedChannel shutdownNow() { + delegate.shutdownNow(); + return this; } /** {@inheritDoc} */ @Override - public ManagedChannel shutdownNow() { - delegate.shutdownNow(); - return this; + public boolean isTerminated() { + return delegate.isTerminated(); } /** {@inheritDoc} */ @@ -101,6 +106,25 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE return delegate.awaitTermination(timeout, unit); } + /** Listener that's responsible for decrementing outstandingCalls when the call closes */ + private class DecrementOutstandingCalls extends SimpleForwardingClientCallListener { + DecrementOutstandingCalls(Listener delegate) { + super(delegate); + } + + @Override + public void onClose(Status status, Metadata trailers) { + // decrement in finally block in case onClose throws an exception + try { + super.onClose(status, trailers); + } finally { + if (outstandingCalls.decrementAndGet() == 0 && isShutdownSafely) { + shutdownSafely(); + } + } + } + } + /** * Wrap client call to decrement outstandingCalls and shutdown channel if necessary when it * completes @@ -111,31 +135,22 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE private ClientCall clientCallWrapper(ClientCall call) { return new SimpleForwardingClientCall(call) { public void start(Listener responseListener, Metadata headers) { - Listener forwardingResponseListener = - new SimpleForwardingClientCallListener(responseListener) { - @Override - public void onClose(Status status, Metadata trailers) { - // decrement in finally block in case onClose throws an exception - try { - super.onClose(status, trailers); - } finally { - if (outstandingCalls.decrementAndGet() == 0 && isShutdownSafely) { - shutdownSafely(); - } - } - } - }; - super.start(forwardingResponseListener, headers); + super.start(new DecrementOutstandingCalls<>(responseListener), headers); } }; } - /** {@inheritDoc} */ + /** + * Caller must take care to synchronize newCall and shutdownSafely in order to avoid race + * conditions of starting new calls after shutdownSafely is called + * + * @see io.grpc.ManagedChannel#newCall(MethodDescriptor, CallOptions) + */ @Override public ClientCall newCall( MethodDescriptor methodDescriptor, CallOptions callOptions) { // increment after client call in case newCall throws an exception - assert (!isShutdownSafely); + Preconditions.checkState(!isShutdownSafely); ClientCall clientCall = clientCallWrapper(delegate.newCall(methodDescriptor, callOptions)); outstandingCalls.incrementAndGet(); diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index 1ac664a0be..0218d85d8c 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -65,7 +65,7 @@ public void testAuthority() throws IOException { Mockito.when(sub1.authority()).thenReturn("myAuth"); - ChannelPool pool = new ChannelPool(2, new FakeChannelFactory(Arrays.asList(sub1, sub2)), null); + ChannelPool pool = ChannelPool.create(2, new FakeChannelFactory(Arrays.asList(sub1, sub2))); Truth.assertThat(pool.authority()).isEqualTo("myAuth"); } @@ -77,7 +77,7 @@ public void testRoundRobin() throws IOException { Mockito.when(sub1.authority()).thenReturn("myAuth"); ArrayList channels = Lists.newArrayList(sub1, sub2); - ChannelPool pool = new ChannelPool(2, new FakeChannelFactory(Arrays.asList(sub1, sub2)), null); + ChannelPool pool = ChannelPool.create(2, new FakeChannelFactory(Arrays.asList(sub1, sub2))); verifyTargetChannel(pool, channels, sub1); verifyTargetChannel(pool, channels, sub2); @@ -138,7 +138,7 @@ public ClientCall answer(InvocationOnMock invocationOnMock) } final ChannelPool pool = - new ChannelPool(numChannels, new FakeChannelFactory(Arrays.asList(channels)), null); + ChannelPool.create(numChannels, new FakeChannelFactory(Arrays.asList(channels))); int numThreads = 20; final int numPerThread = 1000; @@ -172,8 +172,8 @@ public void channelPrimerShouldBeCalledOnce() throws IOException { ManagedChannel channel1 = Mockito.mock(ManagedChannel.class); ManagedChannel channel2 = Mockito.mock(ManagedChannel.class); - new ChannelPool( - 2, new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer), null); + ChannelPool.create( + 2, new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer)); Mockito.verify(mockChannelPrimer, Mockito.times(2)) .primeChannel(Mockito.any(ManagedChannel.class)); } @@ -201,7 +201,7 @@ public Object answer(InvocationOnMock invocation) { } }); - new ChannelPool( + ChannelPool.createRefreshing( 1, new FakeChannelFactory(Arrays.asList(channel1, channel2, channel3), mockChannelPrimer), scheduledExecutorService); diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java index 6f0465c650..9722f8634b 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java @@ -75,8 +75,7 @@ public void testAffinity() throws IOException { .thenReturn(clientCall0); Mockito.when(channel1.newCall(Mockito.eq(descriptor), Mockito.any())) .thenReturn(clientCall1); - Channel pool = - new ChannelPool(2, new FakeChannelFactory(Arrays.asList(channel0, channel1)), null); + Channel pool = ChannelPool.create(2, new FakeChannelFactory(Arrays.asList(channel0, channel1))); GrpcCallContext context = GrpcCallContext.createDefault().withChannel(pool); ClientCall gotCallA =