diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index da4dfac69..d6b85275b 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -29,6 +29,7 @@ */ package com.google.api.gax.grpc; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import io.grpc.CallOptions; import io.grpc.ClientCall; @@ -37,9 +38,11 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; /** * A {@link ManagedChannel} that will send requests round robin via a set of channels. @@ -47,9 +50,14 @@ *

Package-private for internal use. */ class ChannelPool extends ManagedChannel { + // size greater than 1 to allow multiple channel to refresh at the same time + // size not too large so refreshing channels doesn't use too many threads + private static final int CHANNEL_REFRESH_EXECUTOR_SIZE = 2; private final ImmutableList channels; private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; + // if set, ChannelPool will manage the life cycle of channelRefreshExecutorService + @Nullable private ScheduledExecutorService channelRefreshExecutorService; /** * Factory method to create a non-refreshing channel pool @@ -63,35 +71,58 @@ static ChannelPool create(int poolSize, final ChannelFactory channelFactory) thr for (int i = 0; i < poolSize; i++) { channels.add(channelFactory.createSingleChannel()); } - return new ChannelPool(channels); + return new ChannelPool(channels, null); } /** * Factory method to create a refreshing channel pool * + *

Package-private for testing purposes only + * * @param poolSize number of channels in the pool * @param channelFactory method to create the channels - * @param executorService periodically refreshes the channels + * @param channelRefreshExecutorService periodically refreshes the channels; its life cycle will + * be managed by ChannelPool * @return ChannelPool of refreshing channels */ + @VisibleForTesting static ChannelPool createRefreshing( - int poolSize, final ChannelFactory channelFactory, ScheduledExecutorService executorService) + int poolSize, + final ChannelFactory channelFactory, + ScheduledExecutorService channelRefreshExecutorService) throws IOException { List channels = new ArrayList<>(); for (int i = 0; i < poolSize; i++) { - channels.add(new RefreshingManagedChannel(channelFactory, executorService)); + channels.add(new RefreshingManagedChannel(channelFactory, channelRefreshExecutorService)); } - return new ChannelPool(channels); + return new ChannelPool(channels, channelRefreshExecutorService); + } + + /** + * Factory method to create a refreshing channel pool + * + * @param poolSize number of channels in the pool + * @param channelFactory method to create the channels + * @return ChannelPool of refreshing channels + */ + static ChannelPool createRefreshing(int poolSize, final ChannelFactory channelFactory) + throws IOException { + return createRefreshing( + poolSize, channelFactory, Executors.newScheduledThreadPool(CHANNEL_REFRESH_EXECUTOR_SIZE)); } /** * Initializes the channel pool. Assumes that all channels have the same authority. * * @param channels a List of channels to pool. + * @param channelRefreshExecutorService periodically refreshes the channels */ - private ChannelPool(List channels) { + private ChannelPool( + List channels, + @Nullable ScheduledExecutorService channelRefreshExecutorService) { this.channels = ImmutableList.copyOf(channels); authority = channels.get(0).authority(); + this.channelRefreshExecutorService = channelRefreshExecutorService; } /** {@inheritDoc} */ @@ -118,7 +149,9 @@ public ManagedChannel shutdown() { for (ManagedChannel channelWrapper : channels) { channelWrapper.shutdown(); } - + if (channelRefreshExecutorService != null) { + channelRefreshExecutorService.shutdown(); + } return this; } @@ -130,6 +163,9 @@ public boolean isShutdown() { return false; } } + if (channelRefreshExecutorService != null && !channelRefreshExecutorService.isShutdown()) { + return false; + } return true; } @@ -141,6 +177,9 @@ public boolean isTerminated() { return false; } } + if (channelRefreshExecutorService != null && !channelRefreshExecutorService.isTerminated()) { + return false; + } return true; } @@ -150,6 +189,9 @@ public ManagedChannel shutdownNow() { for (ManagedChannel channel : channels) { channel.shutdownNow(); } + if (channelRefreshExecutorService != null) { + channelRefreshExecutorService.shutdownNow(); + } return this; } @@ -164,7 +206,10 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE } channel.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS); } - + if (channelRefreshExecutorService != null) { + long awaitTimeNanos = endTimeNanos - System.nanoTime(); + channelRefreshExecutorService.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS); + } return isTerminated(); } diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java index 211da3683..afe1f322a 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java @@ -204,9 +204,7 @@ public ManagedChannel createSingleChannel() throws IOException { }; ManagedChannel outerChannel; if (channelPrimer != null) { - outerChannel = - ChannelPool.createRefreshing( - realPoolSize, channelFactory, executorProvider.getExecutor()); + outerChannel = ChannelPool.createRefreshing(realPoolSize, channelFactory); } else { outerChannel = ChannelPool.create(realPoolSize, channelFactory); } diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java index aec7df9b7..9d2df5756 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java @@ -36,7 +36,6 @@ import com.google.api.core.ApiFunction; import com.google.api.gax.core.ExecutorProvider; -import com.google.api.gax.core.FixedExecutorProvider; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.EnvironmentProvider; import com.google.api.gax.rpc.HeaderProvider; @@ -47,20 +46,14 @@ import io.grpc.ManagedChannelBuilder; import io.grpc.alts.ComputeEngineChannelBuilder; import java.io.IOException; -import java.util.ArrayList; import java.util.Collections; -import java.util.List; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledThreadPoolExecutor; -import java.util.concurrent.TimeUnit; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import org.threeten.bp.Duration; @RunWith(JUnit4.class) @@ -368,52 +361,28 @@ public void testWithIPv6Address() throws IOException { provider.getTransportChannel().shutdownNow(); } - // Test that if ChannelPrimer is provided, it is called during creation and periodically + // Test that if ChannelPrimer is provided, it is called during creation @Test public void testWithPrimeChannel() throws IOException { - - final ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class); - final List channelRefreshers = new ArrayList<>(); - - ScheduledExecutorService scheduledExecutorService = - Mockito.mock(ScheduledExecutorService.class); - - Answer extractChannelRefresher = - new Answer() { - public Object answer(InvocationOnMock invocation) { - channelRefreshers.add((Runnable) invocation.getArgument(0)); - return Mockito.mock(ScheduledFuture.class); - } - }; - - Mockito.doAnswer(extractChannelRefresher) - .when(scheduledExecutorService) - .schedule( - Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); - - InstantiatingGrpcChannelProvider provider = - InstantiatingGrpcChannelProvider.newBuilder() - .setEndpoint("localhost:8080") - .setPoolSize(2) - .setHeaderProvider(Mockito.mock(HeaderProvider.class)) - .setExecutorProvider(FixedExecutorProvider.create(scheduledExecutorService)) - .setChannelPrimer(mockChannelPrimer) - .build(); - - provider.getTransportChannel().shutdownNow(); - - // 2 calls during the creation, 2 more calls when they get scheduled - Mockito.verify(mockChannelPrimer, Mockito.times(2)) - .primeChannel(Mockito.any(ManagedChannel.class)); - assertThat(channelRefreshers).hasSize(2); - Mockito.verify(scheduledExecutorService, Mockito.times(2)) - .schedule( - Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); - channelRefreshers.get(0).run(); - Mockito.verify(mockChannelPrimer, Mockito.times(3)) - .primeChannel(Mockito.any(ManagedChannel.class)); - channelRefreshers.get(1).run(); - Mockito.verify(mockChannelPrimer, Mockito.times(4)) - .primeChannel(Mockito.any(ManagedChannel.class)); + // create channelProvider with different pool sizes to verify ChannelPrimer is called the + // correct number of times + for (int poolSize = 1; poolSize < 5; poolSize++) { + final ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class); + + InstantiatingGrpcChannelProvider provider = + InstantiatingGrpcChannelProvider.newBuilder() + .setEndpoint("localhost:8080") + .setPoolSize(poolSize) + .setHeaderProvider(Mockito.mock(HeaderProvider.class)) + .setExecutorProvider(Mockito.mock(ExecutorProvider.class)) + .setChannelPrimer(mockChannelPrimer) + .build(); + + provider.getTransportChannel().shutdownNow(); + + // every channel in the pool should call primeChannel during creation. + Mockito.verify(mockChannelPrimer, Mockito.times(poolSize)) + .primeChannel(Mockito.any(ManagedChannel.class)); + } } }