Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Create a executorService with a separate pool of threads for channelpool #836

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 53 additions & 8 deletions gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
Expand Up @@ -29,6 +29,7 @@
*/
package com.google.api.gax.grpc;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
Expand All @@ -37,19 +38,26 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;

/**
* A {@link ManagedChannel} that will send requests round robin via a set of channels.
*
* <p>Package-private for internal use.
*/
class ChannelPool extends ManagedChannel {
// size greater than 1 to allow multiple channel to refresh at the same time
// size not too large so refreshing channels doesn't use too many threads
private static final int CHANNEL_REFRESH_EXECUTOR_SIZE = 2;
private final ImmutableList<ManagedChannel> channels;
private final AtomicInteger indexTicker = new AtomicInteger();
private final String authority;
// if set, ChannelPool will manage the life cycle of channelRefreshExecutorService
@Nullable private ScheduledExecutorService channelRefreshExecutorService;

/**
* Factory method to create a non-refreshing channel pool
Expand All @@ -63,35 +71,58 @@ static ChannelPool create(int poolSize, final ChannelFactory channelFactory) thr
for (int i = 0; i < poolSize; i++) {
channels.add(channelFactory.createSingleChannel());
}
return new ChannelPool(channels);
return new ChannelPool(channels, null);
}

/**
* Factory method to create a refreshing channel pool
*
* <p>Package-private for testing purposes only
*
* @param poolSize number of channels in the pool
* @param channelFactory method to create the channels
* @param executorService periodically refreshes the channels
* @param channelRefreshExecutorService periodically refreshes the channels; its life cycle will
* be managed by ChannelPool
* @return ChannelPool of refreshing channels
*/
@VisibleForTesting
static ChannelPool createRefreshing(
int poolSize, final ChannelFactory channelFactory, ScheduledExecutorService executorService)
int poolSize,
final ChannelFactory channelFactory,
ScheduledExecutorService channelRefreshExecutorService)
throws IOException {
List<ManagedChannel> channels = new ArrayList<>();
for (int i = 0; i < poolSize; i++) {
channels.add(new RefreshingManagedChannel(channelFactory, executorService));
channels.add(new RefreshingManagedChannel(channelFactory, channelRefreshExecutorService));
}
return new ChannelPool(channels);
return new ChannelPool(channels, channelRefreshExecutorService);
}

/**
* Factory method to create a refreshing channel pool
*
* @param poolSize number of channels in the pool
* @param channelFactory method to create the channels
* @return ChannelPool of refreshing channels
*/
static ChannelPool createRefreshing(int poolSize, final ChannelFactory channelFactory)
throws IOException {
return createRefreshing(
poolSize, channelFactory, Executors.newScheduledThreadPool(CHANNEL_REFRESH_EXECUTOR_SIZE));
}

/**
* Initializes the channel pool. Assumes that all channels have the same authority.
*
* @param channels a List of channels to pool.
* @param channelRefreshExecutorService periodically refreshes the channels
*/
private ChannelPool(List<ManagedChannel> channels) {
private ChannelPool(
List<ManagedChannel> channels,
@Nullable ScheduledExecutorService channelRefreshExecutorService) {
this.channels = ImmutableList.copyOf(channels);
authority = channels.get(0).authority();
this.channelRefreshExecutorService = channelRefreshExecutorService;
}

/** {@inheritDoc} */
Expand All @@ -118,7 +149,9 @@ public ManagedChannel shutdown() {
for (ManagedChannel channelWrapper : channels) {
channelWrapper.shutdown();
}

if (channelRefreshExecutorService != null) {
channelRefreshExecutorService.shutdown();
}
return this;
}

Expand All @@ -130,6 +163,9 @@ public boolean isShutdown() {
return false;
}
}
if (channelRefreshExecutorService != null && !channelRefreshExecutorService.isShutdown()) {
return false;
}
return true;
}

Expand All @@ -141,6 +177,9 @@ public boolean isTerminated() {
return false;
}
}
if (channelRefreshExecutorService != null && !channelRefreshExecutorService.isTerminated()) {
return false;
}
return true;
}

Expand All @@ -150,6 +189,9 @@ public ManagedChannel shutdownNow() {
for (ManagedChannel channel : channels) {
channel.shutdownNow();
}
if (channelRefreshExecutorService != null) {
channelRefreshExecutorService.shutdownNow();
}
return this;
}

Expand All @@ -164,7 +206,10 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE
}
channel.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
}

if (channelRefreshExecutorService != null) {
long awaitTimeNanos = endTimeNanos - System.nanoTime();
channelRefreshExecutorService.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
}
return isTerminated();
}

Expand Down
Expand Up @@ -204,9 +204,7 @@ public ManagedChannel createSingleChannel() throws IOException {
};
ManagedChannel outerChannel;
if (channelPrimer != null) {
outerChannel =
ChannelPool.createRefreshing(
realPoolSize, channelFactory, executorProvider.getExecutor());
outerChannel = ChannelPool.createRefreshing(realPoolSize, channelFactory);
} else {
outerChannel = ChannelPool.create(realPoolSize, channelFactory);
}
Expand Down
Expand Up @@ -36,7 +36,6 @@

import com.google.api.core.ApiFunction;
import com.google.api.gax.core.ExecutorProvider;
import com.google.api.gax.core.FixedExecutorProvider;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.EnvironmentProvider;
import com.google.api.gax.rpc.HeaderProvider;
Expand All @@ -47,20 +46,14 @@
import io.grpc.ManagedChannelBuilder;
import io.grpc.alts.ComputeEngineChannelBuilder;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.threeten.bp.Duration;

@RunWith(JUnit4.class)
Expand Down Expand Up @@ -368,52 +361,28 @@ public void testWithIPv6Address() throws IOException {
provider.getTransportChannel().shutdownNow();
}

// Test that if ChannelPrimer is provided, it is called during creation and periodically
// Test that if ChannelPrimer is provided, it is called during creation
@Test
public void testWithPrimeChannel() throws IOException {

final ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class);
final List<Runnable> channelRefreshers = new ArrayList<>();

ScheduledExecutorService scheduledExecutorService =
Mockito.mock(ScheduledExecutorService.class);

Answer extractChannelRefresher =
new Answer() {
public Object answer(InvocationOnMock invocation) {
channelRefreshers.add((Runnable) invocation.getArgument(0));
return Mockito.mock(ScheduledFuture.class);
}
};

Mockito.doAnswer(extractChannelRefresher)
.when(scheduledExecutorService)
.schedule(
Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS));

InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setPoolSize(2)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutorProvider(FixedExecutorProvider.create(scheduledExecutorService))
.setChannelPrimer(mockChannelPrimer)
.build();

provider.getTransportChannel().shutdownNow();

// 2 calls during the creation, 2 more calls when they get scheduled
Mockito.verify(mockChannelPrimer, Mockito.times(2))
.primeChannel(Mockito.any(ManagedChannel.class));
assertThat(channelRefreshers).hasSize(2);
Mockito.verify(scheduledExecutorService, Mockito.times(2))
.schedule(
Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS));
channelRefreshers.get(0).run();
Mockito.verify(mockChannelPrimer, Mockito.times(3))
.primeChannel(Mockito.any(ManagedChannel.class));
channelRefreshers.get(1).run();
Mockito.verify(mockChannelPrimer, Mockito.times(4))
.primeChannel(Mockito.any(ManagedChannel.class));
// create channelProvider with different pool sizes to verify ChannelPrimer is called the
// correct number of times
for (int poolSize = 1; poolSize < 5; poolSize++) {
final ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class);

InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setPoolSize(poolSize)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutorProvider(Mockito.mock(ExecutorProvider.class))
.setChannelPrimer(mockChannelPrimer)
.build();

provider.getTransportChannel().shutdownNow();

// every channel in the pool should call primeChannel during creation.
Mockito.verify(mockChannelPrimer, Mockito.times(poolSize))
.primeChannel(Mockito.any(ManagedChannel.class));
}
}
}