Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Create a executorService with a separate pool of threads for channelpool #836

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
53 changes: 50 additions & 3 deletions gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
Expand Up @@ -29,6 +29,7 @@
*/
package com.google.api.gax.grpc;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
Expand All @@ -37,19 +38,23 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;

/**
* A {@link ManagedChannel} that will send requests round robin via a set of channels.
*
* <p>Package-private for internal use.
*/
class ChannelPool extends ManagedChannel {
private static final int EXECUTOR_THREAD_POOL_SIZE = 2;
igorbernstein2 marked this conversation as resolved.
Show resolved Hide resolved
private final ImmutableList<ManagedChannel> channels;
private final AtomicInteger indexTicker = new AtomicInteger();
private final String authority;
@Nullable private ScheduledExecutorService executorService;
igorbernstein2 marked this conversation as resolved.
Show resolved Hide resolved

/**
* Factory method to create a non-refreshing channel pool
Expand All @@ -69,19 +74,35 @@ static ChannelPool create(int poolSize, final ChannelFactory channelFactory) thr
/**
* Factory method to create a refreshing channel pool
*
* <p>Package-private for testing purposes only
*
* @param poolSize number of channels in the pool
* @param channelFactory method to create the channels
* @param executorService periodically refreshes the channels
* @return ChannelPool of refreshing channels
*/
@VisibleForTesting
static ChannelPool createRefreshing(
int poolSize, final ChannelFactory channelFactory, ScheduledExecutorService executorService)
throws IOException {
List<ManagedChannel> channels = new ArrayList<>();
for (int i = 0; i < poolSize; i++) {
channels.add(new RefreshingManagedChannel(channelFactory, executorService));
}
return new ChannelPool(channels);
return new ChannelPool(channels, executorService);
}

/**
* Factory method to create a refreshing channel pool
*
* @param poolSize number of channels in the pool
* @param channelFactory method to create the channels
* @return ChannelPool of refreshing channels
*/
static ChannelPool createRefreshing(int poolSize, final ChannelFactory channelFactory)
throws IOException {
return createRefreshing(
poolSize, channelFactory, Executors.newScheduledThreadPool(EXECUTOR_THREAD_POOL_SIZE));
}

/**
Expand All @@ -90,8 +111,20 @@ static ChannelPool createRefreshing(
* @param channels a List of channels to pool.
*/
private ChannelPool(List<ManagedChannel> channels) {
this(channels, null);
}
igorbernstein2 marked this conversation as resolved.
Show resolved Hide resolved

/**
* Initializes the channel pool. Assumes that all channels have the same authority.
*
* @param channels a List of channels to pool.
* @param executorService periodically refreshes the channels
*/
private ChannelPool(
List<ManagedChannel> channels, @Nullable ScheduledExecutorService executorService) {
this.channels = ImmutableList.copyOf(channels);
authority = channels.get(0).authority();
this.executorService = executorService;
}

/** {@inheritDoc} */
Expand All @@ -118,7 +151,9 @@ public ManagedChannel shutdown() {
for (ManagedChannel channelWrapper : channels) {
channelWrapper.shutdown();
}

if (executorService != null) {
executorService.shutdown();
}
return this;
}

Expand All @@ -130,6 +165,9 @@ public boolean isShutdown() {
return false;
}
}
if (executorService != null && !executorService.isShutdown()) {
return false;
}
return true;
}

Expand All @@ -141,6 +179,9 @@ public boolean isTerminated() {
return false;
}
}
if (executorService != null && !executorService.isTerminated()) {
return false;
}
return true;
}

Expand All @@ -150,6 +191,9 @@ public ManagedChannel shutdownNow() {
for (ManagedChannel channel : channels) {
channel.shutdownNow();
}
if (executorService != null) {
executorService.shutdownNow();
}
return this;
}

Expand All @@ -164,7 +208,10 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE
}
channel.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
}

if (executorService != null) {
long awaitTimeNanos = endTimeNanos - System.nanoTime();
executorService.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
}
return isTerminated();
}

Expand Down
Expand Up @@ -204,9 +204,7 @@ public ManagedChannel createSingleChannel() throws IOException {
};
ManagedChannel outerChannel;
if (channelPrimer != null) {
outerChannel =
ChannelPool.createRefreshing(
realPoolSize, channelFactory, executorProvider.getExecutor());
outerChannel = ChannelPool.createRefreshing(realPoolSize, channelFactory);
} else {
outerChannel = ChannelPool.create(realPoolSize, channelFactory);
}
Expand Down
Expand Up @@ -36,7 +36,6 @@

import com.google.api.core.ApiFunction;
import com.google.api.gax.core.ExecutorProvider;
import com.google.api.gax.core.FixedExecutorProvider;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.EnvironmentProvider;
import com.google.api.gax.rpc.HeaderProvider;
Expand All @@ -47,20 +46,14 @@
import io.grpc.ManagedChannelBuilder;
import io.grpc.alts.ComputeEngineChannelBuilder;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.threeten.bp.Duration;

@RunWith(JUnit4.class)
Expand Down Expand Up @@ -368,52 +361,26 @@ public void testWithIPv6Address() throws IOException {
provider.getTransportChannel().shutdownNow();
}

// Test that if ChannelPrimer is provided, it is called during creation and periodically
// Test that if ChannelPrimer is provided, it is called during creation
@Test
public void testWithPrimeChannel() throws IOException {

final ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class);
final List<Runnable> channelRefreshers = new ArrayList<>();

ScheduledExecutorService scheduledExecutorService =
Mockito.mock(ScheduledExecutorService.class);

Answer extractChannelRefresher =
new Answer() {
public Object answer(InvocationOnMock invocation) {
channelRefreshers.add((Runnable) invocation.getArgument(0));
return Mockito.mock(ScheduledFuture.class);
}
};

Mockito.doAnswer(extractChannelRefresher)
.when(scheduledExecutorService)
.schedule(
Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS));

InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setPoolSize(2)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutorProvider(FixedExecutorProvider.create(scheduledExecutorService))
.setChannelPrimer(mockChannelPrimer)
.build();

provider.getTransportChannel().shutdownNow();

// 2 calls during the creation, 2 more calls when they get scheduled
Mockito.verify(mockChannelPrimer, Mockito.times(2))
.primeChannel(Mockito.any(ManagedChannel.class));
assertThat(channelRefreshers).hasSize(2);
Mockito.verify(scheduledExecutorService, Mockito.times(2))
.schedule(
Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS));
channelRefreshers.get(0).run();
Mockito.verify(mockChannelPrimer, Mockito.times(3))
.primeChannel(Mockito.any(ManagedChannel.class));
channelRefreshers.get(1).run();
Mockito.verify(mockChannelPrimer, Mockito.times(4))
.primeChannel(Mockito.any(ManagedChannel.class));
for (int poolSize = 1; poolSize < 5; poolSize++) {
final ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class);

InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setPoolSize(poolSize)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutorProvider(Mockito.mock(ExecutorProvider.class))
.setChannelPrimer(mockChannelPrimer)
.build();

provider.getTransportChannel().shutdownNow();

// every channel in the pool should call primeChannel during creation.
Mockito.verify(mockChannelPrimer, Mockito.times(poolSize))
.primeChannel(Mockito.any(ManagedChannel.class));
}
}
}