diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index 3922ee5b89e..df57a14bae6 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -66,6 +66,8 @@ final class DelayedClientTransport implements ManagedClientTransport { @Nonnull @GuardedBy("lock") private Collection pendingStreams = new LinkedHashSet<>(); + @GuardedBy("lock") + private int pendingCompleteStreams; /** * When {@code shutdownStatus != null && !hasPendingStreams()}, then the transport is considered @@ -175,6 +177,7 @@ public final ClientStream newStream( private PendingStream createPendingStream(PickSubchannelArgs args) { PendingStream pendingStream = new PendingStream(args); pendingStreams.add(pendingStream); + pendingCompleteStreams++; if (getPendingStreamsCount() == 1) { syncContext.executeLater(reportTransportInUse); } @@ -211,7 +214,7 @@ public void run() { listener.transportShutdown(status); } }); - if (!hasPendingStreams() && reportTransportTerminated != null) { + if (pendingCompleteStreams == 0 && reportTransportTerminated != null) { syncContext.executeLater(reportTransportTerminated); reportTransportTerminated = null; } @@ -227,23 +230,15 @@ public void run() { public final void shutdownNow(Status status) { shutdown(status); Collection savedPendingStreams; - Runnable savedReportTransportTerminated; synchronized (lock) { savedPendingStreams = pendingStreams; - savedReportTransportTerminated = reportTransportTerminated; - reportTransportTerminated = null; if (!pendingStreams.isEmpty()) { pendingStreams = Collections.emptyList(); } } - if (savedReportTransportTerminated != null) { - for (PendingStream stream : savedPendingStreams) { - stream.cancel(status); - } - syncContext.execute(savedReportTransportTerminated); + for (PendingStream stream : savedPendingStreams) { + stream.cancel(status); } - // If savedReportTransportTerminated == null, transportTerminated() has already been called in - // shutdown(). } public final boolean hasPendingStreams() { @@ -259,6 +254,13 @@ final int getPendingStreamsCount() { } } + @VisibleForTesting + final int getPendingCompleteStreamsCount() { + synchronized (lock) { + return pendingCompleteStreams; + } + } + /** * Use the picker to try picking a transport for every pending stream, proceed the stream if the * pick is successful, otherwise keep it pending. @@ -324,10 +326,6 @@ public void run() { // (which would shutdown the transports and LoadBalancer) because the gap should be shorter // than IDLE_MODE_DEFAULT_TIMEOUT_MILLIS (1 second). syncContext.executeLater(reportTransportNotInUse); - if (shutdownStatus != null && reportTransportTerminated != null) { - syncContext.executeLater(reportTransportTerminated); - reportTransportTerminated = null; - } } } syncContext.drain(); @@ -341,6 +339,8 @@ public InternalLogId getLogId() { private class PendingStream extends DelayedStream { private final PickSubchannelArgs args; private final Context context = Context.current(); + @GuardedBy("lock") + private boolean transferCompleted; private PendingStream(PickSubchannelArgs args) { this.args = args; @@ -362,15 +362,25 @@ private void createRealStream(ClientTransport transport) { public void cancel(Status reason) { super.cancel(reason); synchronized (lock) { - if (reportTransportTerminated != null) { - boolean justRemovedAnElement = pendingStreams.remove(this); - if (!hasPendingStreams() && justRemovedAnElement) { - syncContext.executeLater(reportTransportNotInUse); - if (shutdownStatus != null) { - syncContext.executeLater(reportTransportTerminated); - reportTransportTerminated = null; - } - } + boolean justRemovedAnElement = pendingStreams.remove(this); + if (!hasPendingStreams() && justRemovedAnElement && reportTransportTerminated != null) { + syncContext.executeLater(reportTransportNotInUse); + } + } + syncContext.drain(); + } + + @Override + public void onTransferComplete() { + synchronized (lock) { + if (transferCompleted) { + return; + } + transferCompleted = true; + pendingCompleteStreams--; + if (shutdownStatus != null && pendingCompleteStreams == 0) { + syncContext.executeLater(reportTransportTerminated); + reportTransportTerminated = null; } } syncContext.drain(); diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index be21b4991ba..1e42fabd163 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -39,7 +39,7 @@ * DelayedStream} may be internally altered by different threads, thus internal synchronization is * necessary. */ -class DelayedStream implements ClientStream { +abstract class DelayedStream implements ClientStream { /** {@code true} once realStream is valid and all pending calls have been drained. */ private volatile boolean passThrough; /** @@ -221,12 +221,14 @@ public void start(ClientStreamListener listener) { if (savedPassThrough) { realStream.start(listener); + onTransferComplete(); } else { final ClientStreamListener finalListener = listener; delayOrExecute(new Runnable() { @Override public void run() { realStream.start(finalListener); + onTransferComplete(); } }); } @@ -302,6 +304,7 @@ public void run() { listenerToClose.closed(reason, new Metadata()); } drainPendingCalls(); + onTransferComplete(); } } @@ -407,6 +410,12 @@ ClientStream getRealStream() { return realStream; } + /** + * Provides the place to define actions at the point when transfer is done. + * Call this method to trigger those transfer completion activities. + */ + abstract void onTransferComplete(); + private static class DelayedStreamListener implements ClientStreamListener { private final ClientStreamListener realListener; private volatile boolean passThrough; diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index c3196ddd107..fdbcc33fcb0 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -47,7 +47,7 @@ final class MetadataApplierImpl extends MetadataApplier { boolean finalized; // not null if returnStream() was called before apply() - DelayedStream delayedStream; + ApplierDelayedStream delayedStream; MetadataApplierImpl( ClientTransport transport, MethodDescriptor method, Metadata origHeaders, @@ -105,11 +105,16 @@ ClientStream returnStream() { synchronized (lock) { if (returnedStream == null) { // apply() has not been called, needs to buffer the requests. - delayedStream = new DelayedStream(); + delayedStream = new ApplierDelayedStream(); return returnedStream = delayedStream; } else { return returnedStream; } } } + + private static class ApplierDelayedStream extends DelayedStream { + @Override + void onTransferComplete() {} + } } diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index 41a97d62f9a..bf88ae5e3cb 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -158,12 +158,14 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.reprocess(mockPicker); assertEquals(0, delayedTransport.getPendingStreamsCount()); delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); - verify(transportListener).transportTerminated(); + verify(transportListener, never()).transportTerminated(); assertEquals(1, fakeExecutor.runDueTasks()); + verify(transportListener, never()).transportTerminated(); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); stream.start(streamListener); verify(mockRealStream).start(same(streamListener)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportTerminated(); } @Test public void transportTerminatedThenAssignTransport() { @@ -201,8 +203,10 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void cancelStreamWithoutSetTransport() { ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); assertEquals(1, delayedTransport.getPendingStreamsCount()); + assertEquals(1, delayedTransport.getPendingCompleteStreamsCount()); stream.cancel(Status.CANCELLED); assertEquals(0, delayedTransport.getPendingStreamsCount()); + assertEquals(0, delayedTransport.getPendingCompleteStreamsCount()); verifyNoMoreInteractions(mockRealTransport); verifyNoMoreInteractions(mockRealStream); } @@ -211,13 +215,34 @@ public void uncaughtException(Thread t, Throwable e) { ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); stream.start(streamListener); assertEquals(1, delayedTransport.getPendingStreamsCount()); + assertEquals(1, delayedTransport.getPendingCompleteStreamsCount()); stream.cancel(Status.CANCELLED); assertEquals(0, delayedTransport.getPendingStreamsCount()); + assertEquals(0, delayedTransport.getPendingCompleteStreamsCount()); verify(streamListener).closed(same(Status.CANCELLED), any(Metadata.class)); verifyNoMoreInteractions(mockRealTransport); verifyNoMoreInteractions(mockRealStream); } + @Test + public void cancelStreamShutdownThenStart() { + ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + delayedTransport.shutdown(Status.UNAVAILABLE); + assertEquals(1, delayedTransport.getPendingStreamsCount()); + assertEquals(1, delayedTransport.getPendingCompleteStreamsCount()); + delayedTransport.reprocess(mockPicker); + assertEquals(1, fakeExecutor.runDueTasks()); + assertEquals(0, delayedTransport.getPendingStreamsCount()); + assertEquals(1, delayedTransport.getPendingCompleteStreamsCount()); + stream.cancel(Status.CANCELLED); + verify(mockRealStream).cancel(same(Status.CANCELLED)); + verify(transportListener, never()).transportTerminated(); + stream.start(streamListener); + assertEquals(0, delayedTransport.getPendingCompleteStreamsCount()); + verify(mockRealStream).start(streamListener); + verify(transportListener).transportTerminated(); + } + @Test public void newStreamThenShutdownTransportThenAssignTransport() { ClientStream stream = delayedTransport.newStream(method, headers, callOptions); stream.start(streamListener); @@ -353,6 +378,7 @@ public void uncaughtException(Thread t, Throwable e) { waitForReadyCallOptions); assertEquals(8, delayedTransport.getPendingStreamsCount()); + assertEquals(8, delayedTransport.getPendingCompleteStreamsCount()); // First reprocess(). Some will proceed, some will fail and the rest will stay buffered. SubchannelPicker picker = mock(SubchannelPicker.class); @@ -370,6 +396,7 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.reprocess(picker); assertEquals(5, delayedTransport.getPendingStreamsCount()); + assertEquals(8, delayedTransport.getPendingCompleteStreamsCount()); inOrder.verify(picker).pickSubchannel(ff1args); inOrder.verify(picker).pickSubchannel(ff2args); inOrder.verify(picker).pickSubchannel(ff3args); @@ -385,8 +412,12 @@ public void uncaughtException(Thread t, Throwable e) { any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); verify(mockRealTransport2, never()).newStream( any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + ff1.start(streamListener); + ff2.start(streamListener); fakeExecutor.runDueTasks(); assertEquals(0, fakeExecutor.numPendingTasks()); + // 8 - 2(runDueTask with start) + assertEquals(6, delayedTransport.getPendingCompleteStreamsCount()); // ff1 and wfr1 went through verify(mockRealTransport).newStream(method, headers, failFastCallOptions); verify(mockRealTransport2).newStream(method, headers, waitForReadyCallOptions); @@ -394,6 +425,8 @@ public void uncaughtException(Thread t, Throwable e) { assertSame(mockRealStream2, wfr1.getRealStream()); // The ff2 has failed due to picker returning an error assertSame(Status.UNAVAILABLE, ((FailingClientStream) ff2.getRealStream()).getError()); + wfr1.start(streamListener); + assertEquals(5, delayedTransport.getPendingCompleteStreamsCount()); // Other streams are still buffered assertNull(ff3.getRealStream()); assertNull(ff4.getRealStream()); @@ -414,8 +447,14 @@ public void uncaughtException(Thread t, Throwable e) { assertEquals(0, wfr3Executor.numPendingTasks()); verify(transportListener, never()).transportInUse(false); + ff3.start(streamListener); + ff4.start(streamListener); + wfr2.start(streamListener); + wfr3.start(streamListener); + wfr4.start(streamListener); delayedTransport.reprocess(picker); assertEquals(0, delayedTransport.getPendingStreamsCount()); + assertEquals(5, delayedTransport.getPendingCompleteStreamsCount()); verify(transportListener).transportInUse(false); inOrder.verify(picker).pickSubchannel(ff3args); // ff3 inOrder.verify(picker).pickSubchannel(ff4args); // ff4 @@ -423,8 +462,9 @@ public void uncaughtException(Thread t, Throwable e) { inOrder.verify(picker).pickSubchannel(wfr3args); // wfr3 inOrder.verify(picker).pickSubchannel(wfr4args); // wfr4 inOrder.verifyNoMoreInteractions(); - fakeExecutor.runDueTasks(); + assertEquals(4, fakeExecutor.runDueTasks()); assertEquals(0, fakeExecutor.numPendingTasks()); + assertEquals(1, delayedTransport.getPendingCompleteStreamsCount()); assertSame(mockRealStream, ff3.getRealStream()); assertSame(mockRealStream2, ff4.getRealStream()); assertSame(mockRealStream2, wfr2.getRealStream()); @@ -434,15 +474,18 @@ public void uncaughtException(Thread t, Throwable e) { assertNull(wfr3.getRealStream()); wfr3Executor.runDueTasks(); assertSame(mockRealStream, wfr3.getRealStream()); + assertEquals(0, delayedTransport.getPendingCompleteStreamsCount()); // New streams will use the last picker DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream( method, headers, waitForReadyCallOptions); + wfr5.start(streamListener); assertNull(wfr5.getRealStream()); inOrder.verify(picker).pickSubchannel( new PickSubchannelArgsImpl(method, headers, waitForReadyCallOptions)); inOrder.verifyNoMoreInteractions(); assertEquals(1, delayedTransport.getPendingStreamsCount()); + assertEquals(1, delayedTransport.getPendingCompleteStreamsCount()); // wfr5 will stop delayed transport from terminating delayedTransport.shutdown(SHUTDOWN_STATUS); diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index 393a6c6e6d0..9b857c119fb 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -65,7 +65,7 @@ public class DelayedStreamTest { @Mock private ClientStreamListener listener; @Mock private ClientStream realStream; @Captor private ArgumentCaptor listenerCaptor; - private DelayedStream stream = new DelayedStream(); + private DelayedStream stream = new SimpleDelayedStream(); @Test public void setStream_setAuthority() { @@ -378,4 +378,10 @@ public Void answer(InvocationOnMock in) { assertThat(insight.toString()) .matches("\\[buffered_nanos=[0-9]+, remote_addr=127\\.0\\.0\\.1:443\\]"); } + + private static class SimpleDelayedStream extends DelayedStream { + @Override + void onTransferComplete() { + } + } }