From b0594737e718c7bb0977ac38fe26fa7ce84b2702 Mon Sep 17 00:00:00 2001 From: Tony Tang Date: Thu, 21 Nov 2019 15:14:46 -0500 Subject: [PATCH] Allow channel pool to refresh its channels periodically (#805) * add refresh capability to channelpool * add tests * fix minor issues * respond to comments and add synchronization to newcall and channel swap * clean up concurrency test * Move refreshing logic to a wrapper class of managed channel * change channel pool test to remove reliance on timer * working solution * Move shutdown logic to sub class to safely perform to ensure non-blocking shutdown * Add comments * Inline jitter calculation * Add comments and documentation * Respond to comments * add comment on synchronization * simplified channel pool creation logic * Add for internal use only comment --- .../google/api/gax/grpc/ChannelFactory.java | 44 ++++ .../com/google/api/gax/grpc/ChannelPool.java | 39 +++- .../google/api/gax/grpc/ChannelPrimer.java | 43 ++++ .../InstantiatingGrpcChannelProvider.java | 54 ++++- .../gax/grpc/RefreshingManagedChannel.java | 215 ++++++++++++++++++ .../gax/grpc/SafeShutdownManagedChannel.java | 170 ++++++++++++++ .../google/api/gax/grpc/ChannelPoolTest.java | 86 ++++++- .../api/gax/grpc/GrpcClientCallsTest.java | 7 +- .../InstantiatingGrpcChannelProviderTest.java | 56 +++++ .../grpc/RefreshingManagedChannelTest.java | 206 +++++++++++++++++ .../grpc/SafeShutdownManagedChannelTest.java | 186 +++++++++++++++ .../gax/grpc/testing/FakeChannelFactory.java | 58 +++++ 12 files changed, 1142 insertions(+), 22 deletions(-) create mode 100644 gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelFactory.java create mode 100644 gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPrimer.java create mode 100644 gax-grpc/src/main/java/com/google/api/gax/grpc/RefreshingManagedChannel.java create mode 100644 gax-grpc/src/main/java/com/google/api/gax/grpc/SafeShutdownManagedChannel.java create mode 100644 gax-grpc/src/test/java/com/google/api/gax/grpc/RefreshingManagedChannelTest.java create mode 100644 gax-grpc/src/test/java/com/google/api/gax/grpc/SafeShutdownManagedChannelTest.java create mode 100644 gax-grpc/src/test/java/com/google/api/gax/grpc/testing/FakeChannelFactory.java diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelFactory.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelFactory.java new file mode 100644 index 000000000..416ce4469 --- /dev/null +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelFactory.java @@ -0,0 +1,44 @@ +/* + * Copyright 2019 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.google.api.gax.grpc; + +import com.google.api.core.InternalApi; +import io.grpc.ManagedChannel; +import java.io.IOException; + +/** + * This interface represents a factory for creating one ManagedChannel + * + *

This is public only for technical reasons, for advanced usage. + */ +@InternalApi("For internal use by google-cloud-java clients only") +public interface ChannelFactory { + ManagedChannel createSingleChannel() throws IOException; +} diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index 8f265b686..da4dfac69 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -34,7 +34,10 @@ import io.grpc.ClientCall; import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; +import java.io.IOException; +import java.util.ArrayList; import java.util.List; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -48,12 +51,45 @@ class ChannelPool extends ManagedChannel { private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; + /** + * Factory method to create a non-refreshing channel pool + * + * @param poolSize number of channels in the pool + * @param channelFactory method to create the channels + * @return ChannelPool of non refreshing channels + */ + static ChannelPool create(int poolSize, final ChannelFactory channelFactory) throws IOException { + List channels = new ArrayList<>(); + for (int i = 0; i < poolSize; i++) { + channels.add(channelFactory.createSingleChannel()); + } + return new ChannelPool(channels); + } + + /** + * Factory method to create a refreshing channel pool + * + * @param poolSize number of channels in the pool + * @param channelFactory method to create the channels + * @param executorService periodically refreshes the channels + * @return ChannelPool of refreshing channels + */ + static ChannelPool createRefreshing( + int poolSize, final ChannelFactory channelFactory, ScheduledExecutorService executorService) + throws IOException { + List channels = new ArrayList<>(); + for (int i = 0; i < poolSize; i++) { + channels.add(new RefreshingManagedChannel(channelFactory, executorService)); + } + return new ChannelPool(channels); + } + /** * Initializes the channel pool. Assumes that all channels have the same authority. * * @param channels a List of channels to pool. */ - ChannelPool(List channels) { + private ChannelPool(List channels) { this.channels = ImmutableList.copyOf(channels); authority = channels.get(0).authority(); } @@ -73,7 +109,6 @@ public String authority() { @Override public ClientCall newCall( MethodDescriptor methodDescriptor, CallOptions callOptions) { - return getNextChannel().newCall(methodDescriptor, callOptions); } diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPrimer.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPrimer.java new file mode 100644 index 000000000..5964de957 --- /dev/null +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPrimer.java @@ -0,0 +1,43 @@ +/* + * Copyright 2019 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.google.api.gax.grpc; + +import com.google.api.core.InternalApi; +import io.grpc.ManagedChannel; + +/** + * An interface to prepare a ManagedChannel for normal requests by priming the channel + * + *

This is public only for technical reasons, for advanced usage. + */ +@InternalApi("For internal use by google-cloud-java clients only") +public interface ChannelPrimer { + void primeChannel(ManagedChannel managedChannel); +} diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java index 214dd058c..211da3683 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java @@ -31,6 +31,7 @@ import com.google.api.core.ApiFunction; import com.google.api.core.BetaApi; +import com.google.api.core.InternalApi; import com.google.api.core.InternalExtensionOnly; import com.google.api.gax.core.ExecutorProvider; import com.google.api.gax.core.FixedExecutorProvider; @@ -40,6 +41,7 @@ import com.google.api.gax.rpc.TransportChannelProvider; import com.google.auth.Credentials; import com.google.auth.oauth2.ComputeEngineCredentials; +import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -70,6 +72,8 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP static final String DIRECT_PATH_ENV_VAR = "GOOGLE_CLOUD_ENABLE_DIRECT_PATH"; static final long DIRECT_PATH_KEEP_ALIVE_TIME_SECONDS = 3600; static final long DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS = 20; + // reduce the thundering herd problem of too many channels trying to (re)connect at the same time + static final int MAX_POOL_SIZE = 1000; private final int processorCount; private final ExecutorProvider executorProvider; @@ -84,6 +88,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP @Nullable private final Boolean keepAliveWithoutCalls; @Nullable private final Integer poolSize; @Nullable private final Credentials credentials; + @Nullable private final ChannelPrimer channelPrimer; @Nullable private final ApiFunction channelConfigurator; @@ -103,6 +108,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) { this.poolSize = builder.poolSize; this.channelConfigurator = builder.channelConfigurator; this.credentials = builder.credentials; + this.channelPrimer = builder.channelPrimer; } @Override @@ -188,19 +194,22 @@ public TransportChannel getTransportChannel() throws IOException { } private TransportChannel createChannel() throws IOException { - ManagedChannel outerChannel; - if (poolSize == null || poolSize == 1) { - outerChannel = createSingleChannel(); + int realPoolSize = MoreObjects.firstNonNull(poolSize, 1); + ChannelFactory channelFactory = + new ChannelFactory() { + public ManagedChannel createSingleChannel() throws IOException { + return InstantiatingGrpcChannelProvider.this.createSingleChannel(); + } + }; + ManagedChannel outerChannel; + if (channelPrimer != null) { + outerChannel = + ChannelPool.createRefreshing( + realPoolSize, channelFactory, executorProvider.getExecutor()); } else { - ImmutableList.Builder channels = ImmutableList.builder(); - - for (int i = 0; i < poolSize; i++) { - channels.add(createSingleChannel()); - } - outerChannel = new ChannelPool(channels.build()); + outerChannel = ChannelPool.create(realPoolSize, channelFactory); } - return GrpcTransportChannel.create(outerChannel); } @@ -293,7 +302,11 @@ private ManagedChannel createSingleChannel() throws IOException { builder = channelConfigurator.apply(builder); } - return builder.build(); + ManagedChannel managedChannel = builder.build(); + if (channelPrimer != null) { + channelPrimer.primeChannel(managedChannel); + } + return managedChannel; } /** The endpoint to be used for the channel. */ @@ -350,6 +363,7 @@ public static final class Builder { @Nullable private Integer poolSize; @Nullable private ApiFunction channelConfigurator; @Nullable private Credentials credentials; + @Nullable private ChannelPrimer channelPrimer; private Builder() { processorCount = Runtime.getRuntime().availableProcessors(); @@ -371,6 +385,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) { this.poolSize = provider.poolSize; this.channelConfigurator = provider.channelConfigurator; this.credentials = provider.credentials; + this.channelPrimer = provider.channelPrimer; } /** Sets the number of available CPUs, used internally for testing. */ @@ -509,6 +524,8 @@ public int getPoolSize() { */ public Builder setPoolSize(int poolSize) { Preconditions.checkArgument(poolSize > 0, "Pool size must be positive"); + Preconditions.checkArgument( + poolSize <= MAX_POOL_SIZE, "Pool size must be less than %d", MAX_POOL_SIZE); this.poolSize = poolSize; return this; } @@ -534,6 +551,21 @@ public Builder setCredentials(Credentials credentials) { return this; } + /** + * By setting a channelPrimer, the ChannelPool created by the provider will be refreshing + * ChannelPool. channelPrimer will be invoked periodically when the channels are refreshed + * + *

This is public only for technical reasons, for advanced usage. + * + * @param channelPrimer invoked when the channels are refreshed + * @return builder for the provider + */ + @InternalApi("For internal use by google-cloud-java clients only") + public Builder setChannelPrimer(ChannelPrimer channelPrimer) { + this.channelPrimer = channelPrimer; + return this; + } + public InstantiatingGrpcChannelProvider build() { return new InstantiatingGrpcChannelProvider(this); } diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/RefreshingManagedChannel.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/RefreshingManagedChannel.java new file mode 100644 index 000000000..ec8680d4c --- /dev/null +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/RefreshingManagedChannel.java @@ -0,0 +1,215 @@ +/* + * Copyright 2019 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.google.api.gax.grpc; + +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import java.io.IOException; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.threeten.bp.Duration; + +/** + * A {@link ManagedChannel} that will refresh the underlying channel by swapping the underlying + * channel with a new one periodically + * + *

Package-private for internal use. + * + *

A note on the synchronization logic. refreshChannel is called periodically which updates + * delegate and nextScheduledRefresh. lock is needed to provide atomic access and update of delegate + * and nextScheduledRefresh. One example is newCall needs to be atomic to avoid context switching to + * refreshChannel that shuts down delegate before newCall is completed. + */ +class RefreshingManagedChannel extends ManagedChannel { + private static final Logger LOG = Logger.getLogger(RefreshingManagedChannel.class.getName()); + // refresh every 50 minutes with 15% jitter for a range of 42.5min to 57.5min + private static final Duration refreshPeriod = Duration.ofMinutes(50); + private static final double jitterPercentage = 0.15; + private volatile SafeShutdownManagedChannel delegate; + private volatile ScheduledFuture nextScheduledRefresh; + // Read: method calls on delegate and nextScheduledRefresh + // Write: updating references of delegate and nextScheduledRefresh + private final ReadWriteLock lock; + private final ChannelFactory channelFactory; + private final ScheduledExecutorService scheduledExecutorService; + + RefreshingManagedChannel( + ChannelFactory channelFactory, ScheduledExecutorService scheduledExecutorService) + throws IOException { + this.delegate = new SafeShutdownManagedChannel(channelFactory.createSingleChannel()); + this.channelFactory = channelFactory; + this.scheduledExecutorService = scheduledExecutorService; + this.lock = new ReentrantReadWriteLock(); + this.nextScheduledRefresh = scheduleNextRefresh(); + } + + /** + * Refresh the existing channel by swapping the current channel with a new channel and schedule + * the next refresh + * + *

refreshChannel can only be called by scheduledExecutorService and not any other methods in + * this class. This is important so no threads will try to acquire the write lock while holding + * the read lock. + */ + private void refreshChannel() { + SafeShutdownManagedChannel newChannel; + try { + newChannel = new SafeShutdownManagedChannel(channelFactory.createSingleChannel()); + } catch (IOException ioException) { + LOG.log( + Level.WARNING, + "Failed to create a new channel when refreshing channel. This has no effect on the " + + "existing channels. The existing channel will continue to be used", + ioException); + return; + } + + SafeShutdownManagedChannel oldChannel = delegate; + lock.writeLock().lock(); + try { + // This thread can be interrupted by invoking cancel on nextScheduledRefresh + // Interrupt happens when this thread is blocked on acquiring the write lock because shutdown + // was called and that thread holds the read lock. + // When shutdown completes and releases the read lock and this thread acquires the write lock. + // This thread should not continue because the channel has shutdown. This check ensures that + // this thread terminates without swapping the channel and do not schedule the next refresh. + if (Thread.currentThread().isInterrupted()) { + newChannel.shutdownNow(); + return; + } + delegate = newChannel; + nextScheduledRefresh = scheduleNextRefresh(); + } finally { + lock.writeLock().unlock(); + } + oldChannel.shutdownSafely(); + } + + /** Schedule the next instance of refreshing this channel */ + private ScheduledFuture scheduleNextRefresh() { + long delayPeriod = refreshPeriod.toMillis(); + long jitter = (long) ((Math.random() - 0.5) * jitterPercentage * delayPeriod); + long delay = jitter + delayPeriod; + return scheduledExecutorService.schedule( + new Runnable() { + @Override + public void run() { + refreshChannel(); + } + }, + delay, + TimeUnit.MILLISECONDS); + } + + /** {@inheritDoc} */ + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + lock.readLock().lock(); + try { + return delegate.newCall(methodDescriptor, callOptions); + } finally { + lock.readLock().unlock(); + } + } + + /** {@inheritDoc} */ + @Override + public String authority() { + // no lock here because authority is constant across all channels + return delegate.authority(); + } + + /** {@inheritDoc} */ + @Override + public ManagedChannel shutdown() { + lock.readLock().lock(); + try { + nextScheduledRefresh.cancel(true); + delegate.shutdown(); + return this; + } finally { + lock.readLock().unlock(); + } + } + + /** {@inheritDoc} */ + @Override + public ManagedChannel shutdownNow() { + lock.readLock().lock(); + try { + nextScheduledRefresh.cancel(true); + delegate.shutdownNow(); + return this; + } finally { + lock.readLock().unlock(); + } + } + + /** {@inheritDoc} */ + @Override + public boolean isShutdown() { + lock.readLock().lock(); + try { + return delegate.isShutdown(); + } finally { + lock.readLock().unlock(); + } + } + + /** {@inheritDoc} */ + @Override + public boolean isTerminated() { + lock.readLock().lock(); + try { + return delegate.isTerminated(); + } finally { + lock.readLock().unlock(); + } + } + + /** {@inheritDoc} */ + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + lock.readLock().lock(); + try { + return delegate.awaitTermination(timeout, unit); + } finally { + lock.readLock().unlock(); + } + } +} diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/SafeShutdownManagedChannel.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/SafeShutdownManagedChannel.java new file mode 100644 index 000000000..0ae71d2e3 --- /dev/null +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/SafeShutdownManagedChannel.java @@ -0,0 +1,170 @@ +/* + * Copyright 2019 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.google.api.gax.grpc; + +import com.google.common.base.Preconditions; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ClientCall.Listener; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A {@link ManagedChannel} that will complete all calls started on the underlying channel before + * shutting down. + * + *

This class is not thread-safe. Caller must synchronize in order to ensure no new calls if safe + * shutdown has started. + * + *

Package-private for internal use. + */ +class SafeShutdownManagedChannel extends ManagedChannel { + private final ManagedChannel delegate; + private final AtomicInteger outstandingCalls = new AtomicInteger(0); + private volatile boolean isShutdownSafely = false; + + SafeShutdownManagedChannel(ManagedChannel managedChannel) { + this.delegate = managedChannel; + } + + /** + * Safely shutdown channel by checking that there are no more outstanding calls. If there are + * outstanding calls, the last call will invoke this method again when it complete + * + *

Caller should take care to synchronize with newCall so no new calls are started after + * shutdownSafely is called + */ + void shutdownSafely() { + isShutdownSafely = true; + if (outstandingCalls.get() == 0) { + delegate.shutdown(); + } + } + + /** {@inheritDoc} */ + @Override + public ManagedChannel shutdown() { + delegate.shutdown(); + return this; + } + + /** {@inheritDoc} */ + @Override + public boolean isShutdown() { + return delegate.isShutdown(); + } + + /** {@inheritDoc} */ + @Override + public ManagedChannel shutdownNow() { + delegate.shutdownNow(); + return this; + } + + /** {@inheritDoc} */ + @Override + public boolean isTerminated() { + return delegate.isTerminated(); + } + + /** {@inheritDoc} */ + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return delegate.awaitTermination(timeout, unit); + } + + /** + * Decrement outstanding call counter and shutdown if there are no more outstanding calls and + * {@link SafeShutdownManagedChannel#shutdownSafely()} has been invoked + */ + private void onClientCallClose() { + if (outstandingCalls.decrementAndGet() == 0 && isShutdownSafely) { + shutdownSafely(); + } + } + + /** Listener that's responsible for decrementing outstandingCalls when the call closes */ + private class DecrementOutstandingCalls extends SimpleForwardingClientCallListener { + DecrementOutstandingCalls(Listener delegate) { + super(delegate); + } + + @Override + public void onClose(Status status, Metadata trailers) { + // decrement in finally block in case onClose throws an exception + try { + super.onClose(status, trailers); + } finally { + onClientCallClose(); + } + } + } + + /** To wrap around delegate to hook in {@link DecrementOutstandingCalls} */ + private class ClientCallProxy extends SimpleForwardingClientCall { + ClientCallProxy(ClientCall delegate) { + super(delegate); + } + + @Override + public void start(Listener responseListener, Metadata headers) { + super.start(new DecrementOutstandingCalls<>(responseListener), headers); + } + } + + /** + * Caller must take care to synchronize newCall and shutdownSafely in order to avoid race + * conditions of starting new calls after shutdownSafely is called + * + * @see io.grpc.ManagedChannel#newCall(MethodDescriptor, CallOptions) + */ + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + Preconditions.checkState(!isShutdownSafely); + // increment after client call in case newCall throws an exception + ClientCall clientCall = + new ClientCallProxy<>(delegate.newCall(methodDescriptor, callOptions)); + outstandingCalls.incrementAndGet(); + return clientCall; + } + + /** {@inheritDoc} */ + @Override + public String authority() { + return delegate.authority(); + } +} diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index 33d24d545..a660ca1c6 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -29,6 +29,7 @@ */ package com.google.api.gax.grpc; +import com.google.api.gax.grpc.testing.FakeChannelFactory; import com.google.api.gax.grpc.testing.FakeServiceGrpc; import com.google.common.collect.Lists; import com.google.common.truth.Truth; @@ -38,11 +39,14 @@ import io.grpc.ClientCall; import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; @@ -55,24 +59,25 @@ @RunWith(JUnit4.class) public class ChannelPoolTest { @Test - public void testAuthority() { + public void testAuthority() throws IOException { ManagedChannel sub1 = Mockito.mock(ManagedChannel.class); ManagedChannel sub2 = Mockito.mock(ManagedChannel.class); Mockito.when(sub1.authority()).thenReturn("myAuth"); - ChannelPool pool = new ChannelPool(Lists.newArrayList(sub1, sub2)); + ChannelPool pool = ChannelPool.create(2, new FakeChannelFactory(Arrays.asList(sub1, sub2))); Truth.assertThat(pool.authority()).isEqualTo("myAuth"); } @Test - public void testRoundRobin() { + public void testRoundRobin() throws IOException { ManagedChannel sub1 = Mockito.mock(ManagedChannel.class); ManagedChannel sub2 = Mockito.mock(ManagedChannel.class); + Mockito.when(sub1.authority()).thenReturn("myAuth"); ArrayList channels = Lists.newArrayList(sub1, sub2); - ChannelPool pool = new ChannelPool(channels); + ChannelPool pool = ChannelPool.create(channels.size(), new FakeChannelFactory(channels)); verifyTargetChannel(pool, channels, sub1); verifyTargetChannel(pool, channels, sub2); @@ -104,7 +109,7 @@ private void verifyTargetChannel( } @Test - public void ensureEvenDistribution() throws InterruptedException { + public void ensureEvenDistribution() throws InterruptedException, IOException { int numChannels = 10; final ManagedChannel[] channels = new ManagedChannel[numChannels]; final AtomicInteger[] counts = new AtomicInteger[numChannels]; @@ -132,7 +137,8 @@ public ClientCall answer(InvocationOnMock invocationOnMock) }); } - final ChannelPool pool = new ChannelPool(Arrays.asList(channels)); + final ChannelPool pool = + ChannelPool.create(numChannels, new FakeChannelFactory(Arrays.asList(channels))); int numThreads = 20; final int numPerThread = 1000; @@ -158,4 +164,72 @@ public void run() { Truth.assertThat(count.get()).isAnyOf(expectedCount, expectedCount + 1); } } + + // Test channelPrimer is called same number of times as poolSize if executorService is set to null + @Test + public void channelPrimerShouldCallPoolConstruction() throws IOException { + ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class); + ManagedChannel channel1 = Mockito.mock(ManagedChannel.class); + ManagedChannel channel2 = Mockito.mock(ManagedChannel.class); + + ChannelPool.create( + 2, new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer)); + Mockito.verify(mockChannelPrimer, Mockito.times(2)) + .primeChannel(Mockito.any(ManagedChannel.class)); + } + + // Test channelPrimer is called periodically, if there's an executorService + @Test + public void channelPrimerIsCalledPeriodically() throws IOException, InterruptedException { + ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class); + ManagedChannel channel1 = Mockito.mock(RefreshingManagedChannel.class); + ManagedChannel channel2 = Mockito.mock(RefreshingManagedChannel.class); + ManagedChannel channel3 = Mockito.mock(RefreshingManagedChannel.class); + + final List channelRefreshers = new ArrayList<>(); + + ScheduledExecutorService scheduledExecutorService = + Mockito.mock(ScheduledExecutorService.class); + + Answer extractChannelRefresher = + new Answer() { + public Object answer(InvocationOnMock invocation) { + channelRefreshers.add((Runnable) invocation.getArgument(0)); + return Mockito.mock(ScheduledFuture.class); + } + }; + + Mockito.doAnswer(extractChannelRefresher) + .when(scheduledExecutorService) + .schedule( + Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); + + FakeChannelFactory channelFactory = + new FakeChannelFactory(Arrays.asList(channel1, channel2, channel3), mockChannelPrimer); + + ChannelPool.createRefreshing(1, channelFactory, scheduledExecutorService); + // 1 call during the creation + Mockito.verify(mockChannelPrimer, Mockito.times(1)) + .primeChannel(Mockito.any(ManagedChannel.class)); + Mockito.verify(scheduledExecutorService, Mockito.times(1)) + .schedule( + Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); + + channelRefreshers.get(0).run(); + // 1 more call during channel refresh + Mockito.verify(mockChannelPrimer, Mockito.times(2)) + .primeChannel(Mockito.any(ManagedChannel.class)); + Mockito.verify(scheduledExecutorService, Mockito.times(2)) + .schedule( + Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); + + channelRefreshers.get(0).run(); + // 1 more call during channel refresh + Mockito.verify(mockChannelPrimer, Mockito.times(3)) + .primeChannel(Mockito.any(ManagedChannel.class)); + Mockito.verify(scheduledExecutorService, Mockito.times(3)) + .schedule( + Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); + scheduledExecutorService.shutdown(); + } } diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java index 82d9712f2..9722f8634 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java @@ -31,6 +31,7 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.api.gax.grpc.testing.FakeChannelFactory; import com.google.api.gax.grpc.testing.FakeServiceGrpc; import com.google.common.collect.ImmutableList; import com.google.common.truth.Truth; @@ -43,6 +44,7 @@ import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import java.io.IOException; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -57,7 +59,7 @@ public class GrpcClientCallsTest { @Test - public void testAffinity() { + public void testAffinity() throws IOException { MethodDescriptor descriptor = FakeServiceGrpc.METHOD_RECOGNIZE; @SuppressWarnings("unchecked") @@ -73,8 +75,7 @@ public void testAffinity() { .thenReturn(clientCall0); Mockito.when(channel1.newCall(Mockito.eq(descriptor), Mockito.any())) .thenReturn(clientCall1); - - Channel pool = new ChannelPool(Arrays.asList(channel0, channel1)); + Channel pool = ChannelPool.create(2, new FakeChannelFactory(Arrays.asList(channel0, channel1))); GrpcCallContext context = GrpcCallContext.createDefault().withChannel(pool); ClientCall gotCallA = diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java index ee3c9daa8..aec7df9b7 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java @@ -36,6 +36,7 @@ import com.google.api.core.ApiFunction; import com.google.api.gax.core.ExecutorProvider; +import com.google.api.gax.core.FixedExecutorProvider; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.EnvironmentProvider; import com.google.api.gax.rpc.HeaderProvider; @@ -46,14 +47,20 @@ import io.grpc.ManagedChannelBuilder; import io.grpc.alts.ComputeEngineChannelBuilder; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import org.threeten.bp.Duration; @RunWith(JUnit4.class) @@ -360,4 +367,53 @@ public void testWithIPv6Address() throws IOException { // Make sure we can create channels OK. provider.getTransportChannel().shutdownNow(); } + + // Test that if ChannelPrimer is provided, it is called during creation and periodically + @Test + public void testWithPrimeChannel() throws IOException { + + final ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class); + final List channelRefreshers = new ArrayList<>(); + + ScheduledExecutorService scheduledExecutorService = + Mockito.mock(ScheduledExecutorService.class); + + Answer extractChannelRefresher = + new Answer() { + public Object answer(InvocationOnMock invocation) { + channelRefreshers.add((Runnable) invocation.getArgument(0)); + return Mockito.mock(ScheduledFuture.class); + } + }; + + Mockito.doAnswer(extractChannelRefresher) + .when(scheduledExecutorService) + .schedule( + Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); + + InstantiatingGrpcChannelProvider provider = + InstantiatingGrpcChannelProvider.newBuilder() + .setEndpoint("localhost:8080") + .setPoolSize(2) + .setHeaderProvider(Mockito.mock(HeaderProvider.class)) + .setExecutorProvider(FixedExecutorProvider.create(scheduledExecutorService)) + .setChannelPrimer(mockChannelPrimer) + .build(); + + provider.getTransportChannel().shutdownNow(); + + // 2 calls during the creation, 2 more calls when they get scheduled + Mockito.verify(mockChannelPrimer, Mockito.times(2)) + .primeChannel(Mockito.any(ManagedChannel.class)); + assertThat(channelRefreshers).hasSize(2); + Mockito.verify(scheduledExecutorService, Mockito.times(2)) + .schedule( + Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); + channelRefreshers.get(0).run(); + Mockito.verify(mockChannelPrimer, Mockito.times(3)) + .primeChannel(Mockito.any(ManagedChannel.class)); + channelRefreshers.get(1).run(); + Mockito.verify(mockChannelPrimer, Mockito.times(4)) + .primeChannel(Mockito.any(ManagedChannel.class)); + } } diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/RefreshingManagedChannelTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/RefreshingManagedChannelTest.java new file mode 100644 index 000000000..956d6b52d --- /dev/null +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/RefreshingManagedChannelTest.java @@ -0,0 +1,206 @@ +/* + * Copyright 2019 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.google.api.gax.grpc; + +import com.google.api.gax.grpc.testing.FakeChannelFactory; +import com.google.api.gax.grpc.testing.FakeMethodDescriptor; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +@RunWith(JUnit4.class) +public class RefreshingManagedChannelTest { + @Test + public void channelRefreshShouldSwapChannels() throws IOException { + ManagedChannel underlyingChannel1 = Mockito.mock(ManagedChannel.class); + ManagedChannel underlyingChannel2 = Mockito.mock(ManagedChannel.class); + + // mock executor service to capture the runnable scheduled so we can invoke it when we want to + ScheduledExecutorService scheduledExecutorService = + Mockito.mock(ScheduledExecutorService.class); + final List channelRefreshers = new ArrayList<>(); + Answer extractChannelRefresher = + new Answer() { + public Object answer(InvocationOnMock invocation) { + channelRefreshers.add((Runnable) invocation.getArgument(0)); + return null; + } + }; + + Mockito.doAnswer(extractChannelRefresher) + .when(scheduledExecutorService) + .schedule( + Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); + + FakeChannelFactory channelFactory = + new FakeChannelFactory(Arrays.asList(underlyingChannel1, underlyingChannel2)); + + ManagedChannel refreshingManagedChannel = + new RefreshingManagedChannel(channelFactory, scheduledExecutorService); + + refreshingManagedChannel.newCall( + FakeMethodDescriptor.create(), CallOptions.DEFAULT); + + Mockito.verify(underlyingChannel1, Mockito.only()) + .newCall(Mockito.>any(), Mockito.any(CallOptions.class)); + + // swap channel + channelRefreshers.get(0).run(); + + refreshingManagedChannel.newCall( + FakeMethodDescriptor.create(), CallOptions.DEFAULT); + + Mockito.verify(underlyingChannel2, Mockito.only()) + .newCall(Mockito.>any(), Mockito.any(CallOptions.class)); + } + + @Test + public void randomizeTest() throws IOException, InterruptedException, ExecutionException { + int channelCount = 10; + final ManagedChannel[] underlyingChannels = new ManagedChannel[channelCount]; + final Random r = new Random(); + for (int i = 0; i < channelCount; i++) { + final ManagedChannel mockManagedChannel = Mockito.mock(ManagedChannel.class); + underlyingChannels[i] = mockManagedChannel; + + final Answer waitAndSendMessage = + new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + // add a little time to sleep so calls don't always complete right away + TimeUnit.MICROSECONDS.sleep(r.nextInt(1000)); + // when sending message on the call, the channel cannot be shutdown + Mockito.verify(mockManagedChannel, Mockito.never()).shutdown(); + return invocation.callRealMethod(); + } + }; + + Answer createNewCall = + new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + // create a new client call for every new call to the underlying channel + MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); + MockClientCall spyClientCall = Mockito.spy(mockClientCall); + + // spy into clientCall to verify that the channel is not shutdown + Mockito.doAnswer(waitAndSendMessage) + .when(spyClientCall) + .sendMessage(Mockito.anyString()); + + return spyClientCall; + } + }; + + // return a new mocked client call when requesting new call on the channel + Mockito.doAnswer(createNewCall) + .when(underlyingChannels[i]) + .newCall( + Mockito.>any(), Mockito.any(CallOptions.class)); + } + + // mock executor service to capture the runnable scheduled so we can invoke it when we want to + final List channelRefreshers = new ArrayList<>(); + ScheduledExecutorService scheduledExecutorService = + Mockito.mock(ScheduledExecutorService.class); + Answer extractChannelRefresher = + new Answer() { + public Object answer(InvocationOnMock invocation) { + channelRefreshers.add((Runnable) invocation.getArgument(0)); + return null; + } + }; + Mockito.doAnswer(extractChannelRefresher) + .when(scheduledExecutorService) + .schedule( + Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)); + + FakeChannelFactory channelFactory = new FakeChannelFactory(Arrays.asList(underlyingChannels)); + final ManagedChannel refreshingManagedChannel = + new RefreshingManagedChannel(channelFactory, scheduledExecutorService); + + // send a bunch of request to RefreshingManagedChannel, executor needs more than 1 thread to test out concurrency + ExecutorService executor = Executors.newFixedThreadPool(10); + + // channelCount - 1 because the last channel cannot be refreshed because the FakeChannelFactory + // has no more channel to create + for (int i = 0; i < channelCount - 1; i++) { + List> futures = new ArrayList<>(); + int requestCount = 100; + int whenToRefresh = r.nextInt(requestCount); + for (int j = 0; j < requestCount; j++) { + Runnable createNewCall = + new Runnable() { + @Override + public void run() { + // create a new call and send message on refreshingManagedChannel + ClientCall call = + refreshingManagedChannel.newCall( + FakeMethodDescriptor.create(), CallOptions.DEFAULT); + @SuppressWarnings("unchecked") + ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); + call.start(listener, new Metadata()); + call.sendMessage("message"); + } + }; + futures.add(executor.submit(createNewCall)); + // at the randomly chosen point, refresh the channel + if (j == whenToRefresh) { + futures.add(executor.submit(channelRefreshers.get(i))); + } + } + for (Future future : futures) { + future.get(); + } + Mockito.verify(underlyingChannels[i], Mockito.atLeastOnce()).shutdown(); + Mockito.verify(underlyingChannels[i + 1], Mockito.never()).shutdown(); + } + } +} diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/SafeShutdownManagedChannelTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/SafeShutdownManagedChannelTest.java new file mode 100644 index 000000000..ed5659b1f --- /dev/null +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/SafeShutdownManagedChannelTest.java @@ -0,0 +1,186 @@ +/* + * Copyright 2019 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.google.api.gax.grpc; + +import com.google.api.gax.grpc.testing.FakeMethodDescriptor; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +@RunWith(JUnit4.class) +public class SafeShutdownManagedChannelTest { + // call should be allowed to complete and the channel should not shutdown + @Test + public void callShouldCompleteAfterCreation() { + final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); + + SafeShutdownManagedChannel safeShutdownManagedChannel = + new SafeShutdownManagedChannel(underlyingChannel); + + // create a mock call when new call comes to the underlying channel + MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); + MockClientCall spyClientCall = Mockito.spy(mockClientCall); + Mockito.when( + underlyingChannel.newCall( + Mockito.>any(), Mockito.any(CallOptions.class))) + .thenReturn(spyClientCall); + + Answer verifyChannelNotShutdown = + new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); + return invocation.callRealMethod(); + } + }; + + // verify that underlying channel is not shutdown when clientCall is still sending message + Mockito.doAnswer(verifyChannelNotShutdown).when(spyClientCall).sendMessage(Mockito.anyString()); + + // create a new call on safeShutdownManagedChannel + @SuppressWarnings("unchecked") + ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); + ClientCall call = + safeShutdownManagedChannel.newCall( + FakeMethodDescriptor.create(), CallOptions.DEFAULT); + + safeShutdownManagedChannel.shutdownSafely(); + // shutdown is not called because there is still an outstanding call, even if it hasn't started + Mockito.verify(underlyingChannel, Mockito.after(200).never()).shutdown(); + + // start clientCall + call.start(listener, new Metadata()); + // send message and end the call + call.sendMessage("message"); + // shutdown is called because the outstanding call has completed + Mockito.verify(underlyingChannel, Mockito.atLeastOnce()).shutdown(); + } + + // call should be allowed to complete and the channel should not shutdown + @Test + public void callShouldCompleteAfterStarted() { + final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); + + SafeShutdownManagedChannel safeShutdownManagedChannel = + new SafeShutdownManagedChannel(underlyingChannel); + + // create a mock call when new call comes to the underlying channel + MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); + MockClientCall spyClientCall = Mockito.spy(mockClientCall); + Mockito.when( + underlyingChannel.newCall( + Mockito.>any(), Mockito.any(CallOptions.class))) + .thenReturn(spyClientCall); + + Answer verifyChannelNotShutdown = + new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); + return invocation.callRealMethod(); + } + }; + + // verify that underlying channel is not shutdown when clientCall is still sending message + Mockito.doAnswer(verifyChannelNotShutdown).when(spyClientCall).sendMessage(Mockito.anyString()); + + // create a new call on safeShutdownManagedChannel + @SuppressWarnings("unchecked") + ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); + ClientCall call = + safeShutdownManagedChannel.newCall( + FakeMethodDescriptor.create(), CallOptions.DEFAULT); + + // start clientCall + call.start(listener, new Metadata()); + safeShutdownManagedChannel.shutdownSafely(); + // shutdown is not called because there is still an outstanding call + Mockito.verify(underlyingChannel, Mockito.after(200).never()).shutdown(); + // send message and end the call + call.sendMessage("message"); + // shutdown is called because the outstanding call has completed + Mockito.verify(underlyingChannel, Mockito.atLeastOnce()).shutdown(); + } + + // Channel should shutdown after a refresh all the calls have completed + @Test + public void channelShouldShutdown() { + final ManagedChannel underlyingChannel = Mockito.mock(ManagedChannel.class); + + SafeShutdownManagedChannel safeShutdownManagedChannel = + new SafeShutdownManagedChannel(underlyingChannel); + + // create a mock call when new call comes to the underlying channel + MockClientCall mockClientCall = new MockClientCall<>(1, Status.OK); + MockClientCall spyClientCall = Mockito.spy(mockClientCall); + Mockito.when( + underlyingChannel.newCall( + Mockito.>any(), Mockito.any(CallOptions.class))) + .thenReturn(spyClientCall); + + Answer verifyChannelNotShutdown = + new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + Mockito.verify(underlyingChannel, Mockito.never()).shutdown(); + return invocation.callRealMethod(); + } + }; + + // verify that underlying channel is not shutdown when clientCall is still sending message + Mockito.doAnswer(verifyChannelNotShutdown).when(spyClientCall).sendMessage(Mockito.anyString()); + + // create a new call on safeShutdownManagedChannel + @SuppressWarnings("unchecked") + ClientCall.Listener listener = Mockito.mock(ClientCall.Listener.class); + ClientCall call = + safeShutdownManagedChannel.newCall( + FakeMethodDescriptor.create(), CallOptions.DEFAULT); + + // start clientCall + call.start(listener, new Metadata()); + // send message and end the call + call.sendMessage("message"); + // shutdown is not called because it has not been shutdown yet + Mockito.verify(underlyingChannel, Mockito.after(200).never()).shutdown(); + safeShutdownManagedChannel.shutdownSafely(); + // shutdown is called because the outstanding call has completed + Mockito.verify(underlyingChannel, Mockito.atLeastOnce()).shutdown(); + } +} diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/testing/FakeChannelFactory.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/testing/FakeChannelFactory.java new file mode 100644 index 000000000..5a98d23d0 --- /dev/null +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/testing/FakeChannelFactory.java @@ -0,0 +1,58 @@ +/* + * Copyright 2019 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.google.api.gax.grpc.testing; + +import com.google.api.gax.grpc.ChannelFactory; +import com.google.api.gax.grpc.ChannelPrimer; +import io.grpc.ManagedChannel; +import java.util.List; + +public class FakeChannelFactory implements ChannelFactory { + private int called = 0; + private List channels; + private ChannelPrimer channelPrimer; + + public FakeChannelFactory(List channels) { + this.channels = channels; + } + + public FakeChannelFactory(List channels, ChannelPrimer channelPrimer) { + this.channels = channels; + this.channelPrimer = channelPrimer; + } + + public ManagedChannel createSingleChannel() { + ManagedChannel managedChannel = channels.get(called++); + if (this.channelPrimer != null) { + this.channelPrimer.primeChannel(managedChannel); + } + return managedChannel; + } +}