Skip to content

Commit

Permalink
check pending stream completion at delayed transport lifecycle (grpc…
Browse files Browse the repository at this point in the history
…#7720)

add onTransferComplete() at delayedStream and wait for all pending streams to complete transfer when shutting down delayedClientTransport
  • Loading branch information
YifeiZhuang authored and dfawley committed Jan 15, 2021
1 parent 5776c9f commit e40a7a1
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 31 deletions.
58 changes: 34 additions & 24 deletions core/src/main/java/io/grpc/internal/DelayedClientTransport.java
Expand Up @@ -66,6 +66,8 @@ final class DelayedClientTransport implements ManagedClientTransport {
@Nonnull
@GuardedBy("lock")
private Collection<PendingStream> pendingStreams = new LinkedHashSet<>();
@GuardedBy("lock")
private int pendingCompleteStreams;

/**
* When {@code shutdownStatus != null && !hasPendingStreams()}, then the transport is considered
Expand Down Expand Up @@ -175,6 +177,7 @@ public final ClientStream newStream(
private PendingStream createPendingStream(PickSubchannelArgs args) {
PendingStream pendingStream = new PendingStream(args);
pendingStreams.add(pendingStream);
pendingCompleteStreams++;
if (getPendingStreamsCount() == 1) {
syncContext.executeLater(reportTransportInUse);
}
Expand Down Expand Up @@ -211,7 +214,7 @@ public void run() {
listener.transportShutdown(status);
}
});
if (!hasPendingStreams() && reportTransportTerminated != null) {
if (pendingCompleteStreams == 0 && reportTransportTerminated != null) {
syncContext.executeLater(reportTransportTerminated);
reportTransportTerminated = null;
}
Expand All @@ -227,23 +230,15 @@ public void run() {
public final void shutdownNow(Status status) {
shutdown(status);
Collection<PendingStream> savedPendingStreams;
Runnable savedReportTransportTerminated;
synchronized (lock) {
savedPendingStreams = pendingStreams;
savedReportTransportTerminated = reportTransportTerminated;
reportTransportTerminated = null;
if (!pendingStreams.isEmpty()) {
pendingStreams = Collections.emptyList();
}
}
if (savedReportTransportTerminated != null) {
for (PendingStream stream : savedPendingStreams) {
stream.cancel(status);
}
syncContext.execute(savedReportTransportTerminated);
for (PendingStream stream : savedPendingStreams) {
stream.cancel(status);
}
// If savedReportTransportTerminated == null, transportTerminated() has already been called in
// shutdown().
}

public final boolean hasPendingStreams() {
Expand All @@ -259,6 +254,13 @@ final int getPendingStreamsCount() {
}
}

@VisibleForTesting
final int getPendingCompleteStreamsCount() {
synchronized (lock) {
return pendingCompleteStreams;
}
}

/**
* Use the picker to try picking a transport for every pending stream, proceed the stream if the
* pick is successful, otherwise keep it pending.
Expand Down Expand Up @@ -324,10 +326,6 @@ public void run() {
// (which would shutdown the transports and LoadBalancer) because the gap should be shorter
// than IDLE_MODE_DEFAULT_TIMEOUT_MILLIS (1 second).
syncContext.executeLater(reportTransportNotInUse);
if (shutdownStatus != null && reportTransportTerminated != null) {
syncContext.executeLater(reportTransportTerminated);
reportTransportTerminated = null;
}
}
}
syncContext.drain();
Expand All @@ -341,6 +339,8 @@ public InternalLogId getLogId() {
private class PendingStream extends DelayedStream {
private final PickSubchannelArgs args;
private final Context context = Context.current();
@GuardedBy("lock")
private boolean transferCompleted;

private PendingStream(PickSubchannelArgs args) {
this.args = args;
Expand All @@ -362,15 +362,25 @@ private void createRealStream(ClientTransport transport) {
public void cancel(Status reason) {
super.cancel(reason);
synchronized (lock) {
if (reportTransportTerminated != null) {
boolean justRemovedAnElement = pendingStreams.remove(this);
if (!hasPendingStreams() && justRemovedAnElement) {
syncContext.executeLater(reportTransportNotInUse);
if (shutdownStatus != null) {
syncContext.executeLater(reportTransportTerminated);
reportTransportTerminated = null;
}
}
boolean justRemovedAnElement = pendingStreams.remove(this);
if (!hasPendingStreams() && justRemovedAnElement && reportTransportTerminated != null) {
syncContext.executeLater(reportTransportNotInUse);
}
}
syncContext.drain();
}

@Override
public void onTransferComplete() {
synchronized (lock) {
if (transferCompleted) {
return;
}
transferCompleted = true;
pendingCompleteStreams--;
if (shutdownStatus != null && pendingCompleteStreams == 0) {
syncContext.executeLater(reportTransportTerminated);
reportTransportTerminated = null;
}
}
syncContext.drain();
Expand Down
11 changes: 10 additions & 1 deletion core/src/main/java/io/grpc/internal/DelayedStream.java
Expand Up @@ -39,7 +39,7 @@
* DelayedStream} may be internally altered by different threads, thus internal synchronization is
* necessary.
*/
class DelayedStream implements ClientStream {
abstract class DelayedStream implements ClientStream {
/** {@code true} once realStream is valid and all pending calls have been drained. */
private volatile boolean passThrough;
/**
Expand Down Expand Up @@ -221,12 +221,14 @@ public void start(ClientStreamListener listener) {

if (savedPassThrough) {
realStream.start(listener);
onTransferComplete();
} else {
final ClientStreamListener finalListener = listener;
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.start(finalListener);
onTransferComplete();
}
});
}
Expand Down Expand Up @@ -302,6 +304,7 @@ public void run() {
listenerToClose.closed(reason, new Metadata());
}
drainPendingCalls();
onTransferComplete();
}
}

Expand Down Expand Up @@ -407,6 +410,12 @@ ClientStream getRealStream() {
return realStream;
}

/**
* Provides the place to define actions at the point when transfer is done.
* Call this method to trigger those transfer completion activities.
*/
abstract void onTransferComplete();

private static class DelayedStreamListener implements ClientStreamListener {
private final ClientStreamListener realListener;
private volatile boolean passThrough;
Expand Down
9 changes: 7 additions & 2 deletions core/src/main/java/io/grpc/internal/MetadataApplierImpl.java
Expand Up @@ -47,7 +47,7 @@ final class MetadataApplierImpl extends MetadataApplier {
boolean finalized;

// not null if returnStream() was called before apply()
DelayedStream delayedStream;
ApplierDelayedStream delayedStream;

MetadataApplierImpl(
ClientTransport transport, MethodDescriptor<?, ?> method, Metadata origHeaders,
Expand Down Expand Up @@ -105,11 +105,16 @@ ClientStream returnStream() {
synchronized (lock) {
if (returnedStream == null) {
// apply() has not been called, needs to buffer the requests.
delayedStream = new DelayedStream();
delayedStream = new ApplierDelayedStream();
return returnedStream = delayedStream;
} else {
return returnedStream;
}
}
}

private static class ApplierDelayedStream extends DelayedStream {
@Override
void onTransferComplete() {}
}
}
Expand Up @@ -158,12 +158,14 @@ public void uncaughtException(Thread t, Throwable e) {
delayedTransport.reprocess(mockPicker);
assertEquals(0, delayedTransport.getPendingStreamsCount());
delayedTransport.shutdown(SHUTDOWN_STATUS);
verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS));
verify(transportListener).transportTerminated();
verify(transportListener, never()).transportTerminated();
assertEquals(1, fakeExecutor.runDueTasks());
verify(transportListener, never()).transportTerminated();
verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions));
stream.start(streamListener);
verify(mockRealStream).start(same(streamListener));
verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS));
verify(transportListener).transportTerminated();
}

@Test public void transportTerminatedThenAssignTransport() {
Expand Down Expand Up @@ -201,8 +203,10 @@ public void uncaughtException(Thread t, Throwable e) {
@Test public void cancelStreamWithoutSetTransport() {
ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT);
assertEquals(1, delayedTransport.getPendingStreamsCount());
assertEquals(1, delayedTransport.getPendingCompleteStreamsCount());
stream.cancel(Status.CANCELLED);
assertEquals(0, delayedTransport.getPendingStreamsCount());
assertEquals(0, delayedTransport.getPendingCompleteStreamsCount());
verifyNoMoreInteractions(mockRealTransport);
verifyNoMoreInteractions(mockRealStream);
}
Expand All @@ -211,13 +215,34 @@ public void uncaughtException(Thread t, Throwable e) {
ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT);
stream.start(streamListener);
assertEquals(1, delayedTransport.getPendingStreamsCount());
assertEquals(1, delayedTransport.getPendingCompleteStreamsCount());
stream.cancel(Status.CANCELLED);
assertEquals(0, delayedTransport.getPendingStreamsCount());
assertEquals(0, delayedTransport.getPendingCompleteStreamsCount());
verify(streamListener).closed(same(Status.CANCELLED), any(Metadata.class));
verifyNoMoreInteractions(mockRealTransport);
verifyNoMoreInteractions(mockRealStream);
}

@Test
public void cancelStreamShutdownThenStart() {
ClientStream stream = delayedTransport.newStream(method, headers, callOptions);
delayedTransport.shutdown(Status.UNAVAILABLE);
assertEquals(1, delayedTransport.getPendingStreamsCount());
assertEquals(1, delayedTransport.getPendingCompleteStreamsCount());
delayedTransport.reprocess(mockPicker);
assertEquals(1, fakeExecutor.runDueTasks());
assertEquals(0, delayedTransport.getPendingStreamsCount());
assertEquals(1, delayedTransport.getPendingCompleteStreamsCount());
stream.cancel(Status.CANCELLED);
verify(mockRealStream).cancel(same(Status.CANCELLED));
verify(transportListener, never()).transportTerminated();
stream.start(streamListener);
assertEquals(0, delayedTransport.getPendingCompleteStreamsCount());
verify(mockRealStream).start(streamListener);
verify(transportListener).transportTerminated();
}

@Test public void newStreamThenShutdownTransportThenAssignTransport() {
ClientStream stream = delayedTransport.newStream(method, headers, callOptions);
stream.start(streamListener);
Expand Down Expand Up @@ -353,6 +378,7 @@ public void uncaughtException(Thread t, Throwable e) {
waitForReadyCallOptions);

assertEquals(8, delayedTransport.getPendingStreamsCount());
assertEquals(8, delayedTransport.getPendingCompleteStreamsCount());

// First reprocess(). Some will proceed, some will fail and the rest will stay buffered.
SubchannelPicker picker = mock(SubchannelPicker.class);
Expand All @@ -370,6 +396,7 @@ public void uncaughtException(Thread t, Throwable e) {
delayedTransport.reprocess(picker);

assertEquals(5, delayedTransport.getPendingStreamsCount());
assertEquals(8, delayedTransport.getPendingCompleteStreamsCount());
inOrder.verify(picker).pickSubchannel(ff1args);
inOrder.verify(picker).pickSubchannel(ff2args);
inOrder.verify(picker).pickSubchannel(ff3args);
Expand All @@ -385,15 +412,21 @@ public void uncaughtException(Thread t, Throwable e) {
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class));
verify(mockRealTransport2, never()).newStream(
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class));
ff1.start(streamListener);
ff2.start(streamListener);
fakeExecutor.runDueTasks();
assertEquals(0, fakeExecutor.numPendingTasks());
// 8 - 2(runDueTask with start)
assertEquals(6, delayedTransport.getPendingCompleteStreamsCount());
// ff1 and wfr1 went through
verify(mockRealTransport).newStream(method, headers, failFastCallOptions);
verify(mockRealTransport2).newStream(method, headers, waitForReadyCallOptions);
assertSame(mockRealStream, ff1.getRealStream());
assertSame(mockRealStream2, wfr1.getRealStream());
// The ff2 has failed due to picker returning an error
assertSame(Status.UNAVAILABLE, ((FailingClientStream) ff2.getRealStream()).getError());
wfr1.start(streamListener);
assertEquals(5, delayedTransport.getPendingCompleteStreamsCount());
// Other streams are still buffered
assertNull(ff3.getRealStream());
assertNull(ff4.getRealStream());
Expand All @@ -414,17 +447,24 @@ public void uncaughtException(Thread t, Throwable e) {
assertEquals(0, wfr3Executor.numPendingTasks());
verify(transportListener, never()).transportInUse(false);

ff3.start(streamListener);
ff4.start(streamListener);
wfr2.start(streamListener);
wfr3.start(streamListener);
wfr4.start(streamListener);
delayedTransport.reprocess(picker);
assertEquals(0, delayedTransport.getPendingStreamsCount());
assertEquals(5, delayedTransport.getPendingCompleteStreamsCount());
verify(transportListener).transportInUse(false);
inOrder.verify(picker).pickSubchannel(ff3args); // ff3
inOrder.verify(picker).pickSubchannel(ff4args); // ff4
inOrder.verify(picker).pickSubchannel(wfr2args); // wfr2
inOrder.verify(picker).pickSubchannel(wfr3args); // wfr3
inOrder.verify(picker).pickSubchannel(wfr4args); // wfr4
inOrder.verifyNoMoreInteractions();
fakeExecutor.runDueTasks();
assertEquals(4, fakeExecutor.runDueTasks());
assertEquals(0, fakeExecutor.numPendingTasks());
assertEquals(1, delayedTransport.getPendingCompleteStreamsCount());
assertSame(mockRealStream, ff3.getRealStream());
assertSame(mockRealStream2, ff4.getRealStream());
assertSame(mockRealStream2, wfr2.getRealStream());
Expand All @@ -434,15 +474,18 @@ public void uncaughtException(Thread t, Throwable e) {
assertNull(wfr3.getRealStream());
wfr3Executor.runDueTasks();
assertSame(mockRealStream, wfr3.getRealStream());
assertEquals(0, delayedTransport.getPendingCompleteStreamsCount());

// New streams will use the last picker
DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream(
method, headers, waitForReadyCallOptions);
wfr5.start(streamListener);
assertNull(wfr5.getRealStream());
inOrder.verify(picker).pickSubchannel(
new PickSubchannelArgsImpl(method, headers, waitForReadyCallOptions));
inOrder.verifyNoMoreInteractions();
assertEquals(1, delayedTransport.getPendingStreamsCount());
assertEquals(1, delayedTransport.getPendingCompleteStreamsCount());

// wfr5 will stop delayed transport from terminating
delayedTransport.shutdown(SHUTDOWN_STATUS);
Expand Down
8 changes: 7 additions & 1 deletion core/src/test/java/io/grpc/internal/DelayedStreamTest.java
Expand Up @@ -65,7 +65,7 @@ public class DelayedStreamTest {
@Mock private ClientStreamListener listener;
@Mock private ClientStream realStream;
@Captor private ArgumentCaptor<ClientStreamListener> listenerCaptor;
private DelayedStream stream = new DelayedStream();
private DelayedStream stream = new SimpleDelayedStream();

@Test
public void setStream_setAuthority() {
Expand Down Expand Up @@ -378,4 +378,10 @@ public Void answer(InvocationOnMock in) {
assertThat(insight.toString())
.matches("\\[buffered_nanos=[0-9]+, remote_addr=127\\.0\\.0\\.1:443\\]");
}

private static class SimpleDelayedStream extends DelayedStream {
@Override
void onTransferComplete() {
}
}
}

0 comments on commit e40a7a1

Please sign in to comment.