diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
index da4dfac69..d6b85275b 100644
--- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
+++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
@@ -29,6 +29,7 @@
*/
package com.google.api.gax.grpc;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
@@ -37,9 +38,11 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.Nullable;
/**
* A {@link ManagedChannel} that will send requests round robin via a set of channels.
@@ -47,9 +50,14 @@
*
Package-private for internal use.
*/
class ChannelPool extends ManagedChannel {
+ // size greater than 1 to allow multiple channel to refresh at the same time
+ // size not too large so refreshing channels doesn't use too many threads
+ private static final int CHANNEL_REFRESH_EXECUTOR_SIZE = 2;
private final ImmutableList channels;
private final AtomicInteger indexTicker = new AtomicInteger();
private final String authority;
+ // if set, ChannelPool will manage the life cycle of channelRefreshExecutorService
+ @Nullable private ScheduledExecutorService channelRefreshExecutorService;
/**
* Factory method to create a non-refreshing channel pool
@@ -63,35 +71,58 @@ static ChannelPool create(int poolSize, final ChannelFactory channelFactory) thr
for (int i = 0; i < poolSize; i++) {
channels.add(channelFactory.createSingleChannel());
}
- return new ChannelPool(channels);
+ return new ChannelPool(channels, null);
}
/**
* Factory method to create a refreshing channel pool
*
+ * Package-private for testing purposes only
+ *
* @param poolSize number of channels in the pool
* @param channelFactory method to create the channels
- * @param executorService periodically refreshes the channels
+ * @param channelRefreshExecutorService periodically refreshes the channels; its life cycle will
+ * be managed by ChannelPool
* @return ChannelPool of refreshing channels
*/
+ @VisibleForTesting
static ChannelPool createRefreshing(
- int poolSize, final ChannelFactory channelFactory, ScheduledExecutorService executorService)
+ int poolSize,
+ final ChannelFactory channelFactory,
+ ScheduledExecutorService channelRefreshExecutorService)
throws IOException {
List channels = new ArrayList<>();
for (int i = 0; i < poolSize; i++) {
- channels.add(new RefreshingManagedChannel(channelFactory, executorService));
+ channels.add(new RefreshingManagedChannel(channelFactory, channelRefreshExecutorService));
}
- return new ChannelPool(channels);
+ return new ChannelPool(channels, channelRefreshExecutorService);
+ }
+
+ /**
+ * Factory method to create a refreshing channel pool
+ *
+ * @param poolSize number of channels in the pool
+ * @param channelFactory method to create the channels
+ * @return ChannelPool of refreshing channels
+ */
+ static ChannelPool createRefreshing(int poolSize, final ChannelFactory channelFactory)
+ throws IOException {
+ return createRefreshing(
+ poolSize, channelFactory, Executors.newScheduledThreadPool(CHANNEL_REFRESH_EXECUTOR_SIZE));
}
/**
* Initializes the channel pool. Assumes that all channels have the same authority.
*
* @param channels a List of channels to pool.
+ * @param channelRefreshExecutorService periodically refreshes the channels
*/
- private ChannelPool(List channels) {
+ private ChannelPool(
+ List channels,
+ @Nullable ScheduledExecutorService channelRefreshExecutorService) {
this.channels = ImmutableList.copyOf(channels);
authority = channels.get(0).authority();
+ this.channelRefreshExecutorService = channelRefreshExecutorService;
}
/** {@inheritDoc} */
@@ -118,7 +149,9 @@ public ManagedChannel shutdown() {
for (ManagedChannel channelWrapper : channels) {
channelWrapper.shutdown();
}
-
+ if (channelRefreshExecutorService != null) {
+ channelRefreshExecutorService.shutdown();
+ }
return this;
}
@@ -130,6 +163,9 @@ public boolean isShutdown() {
return false;
}
}
+ if (channelRefreshExecutorService != null && !channelRefreshExecutorService.isShutdown()) {
+ return false;
+ }
return true;
}
@@ -141,6 +177,9 @@ public boolean isTerminated() {
return false;
}
}
+ if (channelRefreshExecutorService != null && !channelRefreshExecutorService.isTerminated()) {
+ return false;
+ }
return true;
}
@@ -150,6 +189,9 @@ public ManagedChannel shutdownNow() {
for (ManagedChannel channel : channels) {
channel.shutdownNow();
}
+ if (channelRefreshExecutorService != null) {
+ channelRefreshExecutorService.shutdownNow();
+ }
return this;
}
@@ -164,7 +206,10 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE
}
channel.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
}
-
+ if (channelRefreshExecutorService != null) {
+ long awaitTimeNanos = endTimeNanos - System.nanoTime();
+ channelRefreshExecutorService.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
+ }
return isTerminated();
}
diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java
index 211da3683..afe1f322a 100644
--- a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java
+++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java
@@ -204,9 +204,7 @@ public ManagedChannel createSingleChannel() throws IOException {
};
ManagedChannel outerChannel;
if (channelPrimer != null) {
- outerChannel =
- ChannelPool.createRefreshing(
- realPoolSize, channelFactory, executorProvider.getExecutor());
+ outerChannel = ChannelPool.createRefreshing(realPoolSize, channelFactory);
} else {
outerChannel = ChannelPool.create(realPoolSize, channelFactory);
}
diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java
index aec7df9b7..9d2df5756 100644
--- a/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java
+++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java
@@ -36,7 +36,6 @@
import com.google.api.core.ApiFunction;
import com.google.api.gax.core.ExecutorProvider;
-import com.google.api.gax.core.FixedExecutorProvider;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.EnvironmentProvider;
import com.google.api.gax.rpc.HeaderProvider;
@@ -47,20 +46,14 @@
import io.grpc.ManagedChannelBuilder;
import io.grpc.alts.ComputeEngineChannelBuilder;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.Collections;
-import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ScheduledThreadPoolExecutor;
-import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
import org.threeten.bp.Duration;
@RunWith(JUnit4.class)
@@ -368,52 +361,28 @@ public void testWithIPv6Address() throws IOException {
provider.getTransportChannel().shutdownNow();
}
- // Test that if ChannelPrimer is provided, it is called during creation and periodically
+ // Test that if ChannelPrimer is provided, it is called during creation
@Test
public void testWithPrimeChannel() throws IOException {
-
- final ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class);
- final List channelRefreshers = new ArrayList<>();
-
- ScheduledExecutorService scheduledExecutorService =
- Mockito.mock(ScheduledExecutorService.class);
-
- Answer extractChannelRefresher =
- new Answer() {
- public Object answer(InvocationOnMock invocation) {
- channelRefreshers.add((Runnable) invocation.getArgument(0));
- return Mockito.mock(ScheduledFuture.class);
- }
- };
-
- Mockito.doAnswer(extractChannelRefresher)
- .when(scheduledExecutorService)
- .schedule(
- Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS));
-
- InstantiatingGrpcChannelProvider provider =
- InstantiatingGrpcChannelProvider.newBuilder()
- .setEndpoint("localhost:8080")
- .setPoolSize(2)
- .setHeaderProvider(Mockito.mock(HeaderProvider.class))
- .setExecutorProvider(FixedExecutorProvider.create(scheduledExecutorService))
- .setChannelPrimer(mockChannelPrimer)
- .build();
-
- provider.getTransportChannel().shutdownNow();
-
- // 2 calls during the creation, 2 more calls when they get scheduled
- Mockito.verify(mockChannelPrimer, Mockito.times(2))
- .primeChannel(Mockito.any(ManagedChannel.class));
- assertThat(channelRefreshers).hasSize(2);
- Mockito.verify(scheduledExecutorService, Mockito.times(2))
- .schedule(
- Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS));
- channelRefreshers.get(0).run();
- Mockito.verify(mockChannelPrimer, Mockito.times(3))
- .primeChannel(Mockito.any(ManagedChannel.class));
- channelRefreshers.get(1).run();
- Mockito.verify(mockChannelPrimer, Mockito.times(4))
- .primeChannel(Mockito.any(ManagedChannel.class));
+ // create channelProvider with different pool sizes to verify ChannelPrimer is called the
+ // correct number of times
+ for (int poolSize = 1; poolSize < 5; poolSize++) {
+ final ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class);
+
+ InstantiatingGrpcChannelProvider provider =
+ InstantiatingGrpcChannelProvider.newBuilder()
+ .setEndpoint("localhost:8080")
+ .setPoolSize(poolSize)
+ .setHeaderProvider(Mockito.mock(HeaderProvider.class))
+ .setExecutorProvider(Mockito.mock(ExecutorProvider.class))
+ .setChannelPrimer(mockChannelPrimer)
+ .build();
+
+ provider.getTransportChannel().shutdownNow();
+
+ // every channel in the pool should call primeChannel during creation.
+ Mockito.verify(mockChannelPrimer, Mockito.times(poolSize))
+ .primeChannel(Mockito.any(ManagedChannel.class));
+ }
}
}