Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check pending stream completion at delayed transport lifecycle #7720

Merged
merged 11 commits into from Dec 21, 2020
58 changes: 34 additions & 24 deletions core/src/main/java/io/grpc/internal/DelayedClientTransport.java
Expand Up @@ -66,6 +66,8 @@ final class DelayedClientTransport implements ManagedClientTransport {
@Nonnull
@GuardedBy("lock")
private Collection<PendingStream> pendingStreams = new LinkedHashSet<>();
@GuardedBy("lock")
private int pendingCompleteStreams;

/**
* When {@code shutdownStatus != null && !hasPendingStreams()}, then the transport is considered
Expand Down Expand Up @@ -175,6 +177,7 @@ public final ClientStream newStream(
private PendingStream createPendingStream(PickSubchannelArgs args) {
PendingStream pendingStream = new PendingStream(args);
pendingStreams.add(pendingStream);
pendingCompleteStreams++;
if (getPendingStreamsCount() == 1) {
syncContext.executeLater(reportTransportInUse);
}
Expand Down Expand Up @@ -211,7 +214,7 @@ public void run() {
listener.transportShutdown(status);
}
});
if (!hasPendingStreams() && reportTransportTerminated != null) {
if (pendingCompleteStreams == 0 && reportTransportTerminated != null) {
syncContext.executeLater(reportTransportTerminated);
reportTransportTerminated = null;
}
Expand All @@ -227,23 +230,15 @@ public void run() {
public final void shutdownNow(Status status) {
shutdown(status);
Collection<PendingStream> savedPendingStreams;
Runnable savedReportTransportTerminated;
synchronized (lock) {
savedPendingStreams = pendingStreams;
savedReportTransportTerminated = reportTransportTerminated;
reportTransportTerminated = null;
if (!pendingStreams.isEmpty()) {
pendingStreams = Collections.emptyList();
}
}
if (savedReportTransportTerminated != null) {
for (PendingStream stream : savedPendingStreams) {
stream.cancel(status);
}
syncContext.execute(savedReportTransportTerminated);
for (PendingStream stream : savedPendingStreams) {
stream.cancel(status);
}
// If savedReportTransportTerminated == null, transportTerminated() has already been called in
// shutdown().
}

public final boolean hasPendingStreams() {
Expand All @@ -259,6 +254,13 @@ final int getPendingStreamsCount() {
}
}

@VisibleForTesting
final int getPendingCompleteStreamsCount() {
synchronized (lock) {
return pendingCompleteStreams;
}
}

/**
* Use the picker to try picking a transport for every pending stream, proceed the stream if the
* pick is successful, otherwise keep it pending.
Expand Down Expand Up @@ -324,10 +326,6 @@ public void run() {
// (which would shutdown the transports and LoadBalancer) because the gap should be shorter
// than IDLE_MODE_DEFAULT_TIMEOUT_MILLIS (1 second).
syncContext.executeLater(reportTransportNotInUse);
if (shutdownStatus != null && reportTransportTerminated != null) {
syncContext.executeLater(reportTransportTerminated);
reportTransportTerminated = null;
}
}
}
syncContext.drain();
Expand All @@ -341,6 +339,8 @@ public InternalLogId getLogId() {
private class PendingStream extends DelayedStream {
private final PickSubchannelArgs args;
private final Context context = Context.current();
@GuardedBy("lock")
private boolean transferCompleted;

private PendingStream(PickSubchannelArgs args) {
this.args = args;
Expand All @@ -362,15 +362,25 @@ private void createRealStream(ClientTransport transport) {
public void cancel(Status reason) {
super.cancel(reason);
synchronized (lock) {
if (reportTransportTerminated != null) {
boolean justRemovedAnElement = pendingStreams.remove(this);
if (!hasPendingStreams() && justRemovedAnElement) {
syncContext.executeLater(reportTransportNotInUse);
if (shutdownStatus != null) {
syncContext.executeLater(reportTransportTerminated);
reportTransportTerminated = null;
}
}
boolean justRemovedAnElement = pendingStreams.remove(this);
if (!hasPendingStreams() && justRemovedAnElement && reportTransportTerminated != null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to check reportTransportTerminated != null after the change. Previously if reportTransportTerminated == null, pendingStream is guaranteed reset to empty and justRemovedAnElement will never happen.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking that in time order: 1. super.cancel() 2.onTransferComplete(), reportTransportTerminated=null 3. this lock block might reportNotInUse, which won't happen previously.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think the check for reportTransportTerminated != null is unnecessary and the (!hasPendingStreams() && justRemovedAnElement) is the canonical invariant for reportTransportNotInUse. But ether way seems working anyway.

syncContext.executeLater(reportTransportNotInUse);
}
}
syncContext.drain();
}

@Override
public void onTransferComplete() {
synchronized (lock) {
if (transferCompleted) {
return;
}
transferCompleted = true;
pendingCompleteStreams--;
if (shutdownStatus != null && pendingCompleteStreams == 0) {
syncContext.executeLater(reportTransportTerminated);
reportTransportTerminated = null;
}
}
syncContext.drain();
Expand Down
11 changes: 10 additions & 1 deletion core/src/main/java/io/grpc/internal/DelayedStream.java
Expand Up @@ -39,7 +39,7 @@
* DelayedStream} may be internally altered by different threads, thus internal synchronization is
* necessary.
*/
class DelayedStream implements ClientStream {
abstract class DelayedStream implements ClientStream {
/** {@code true} once realStream is valid and all pending calls have been drained. */
private volatile boolean passThrough;
/**
Expand Down Expand Up @@ -221,12 +221,14 @@ public void start(ClientStreamListener listener) {

if (savedPassThrough) {
realStream.start(listener);
onTransferComplete();
} else {
final ClientStreamListener finalListener = listener;
delayOrExecute(new Runnable() {
@Override
public void run() {
realStream.start(finalListener);
onTransferComplete();
}
});
}
Expand Down Expand Up @@ -302,6 +304,7 @@ public void run() {
listenerToClose.closed(reason, new Metadata());
}
drainPendingCalls();
onTransferComplete();
}
}

Expand Down Expand Up @@ -407,6 +410,12 @@ ClientStream getRealStream() {
return realStream;
}

/**
* Provides the place to define actions at the point when transfer is done.
* Call this method to trigger those transfer completion activities.
*/
abstract void onTransferComplete();

private static class DelayedStreamListener implements ClientStreamListener {
private final ClientStreamListener realListener;
private volatile boolean passThrough;
Expand Down
9 changes: 7 additions & 2 deletions core/src/main/java/io/grpc/internal/MetadataApplierImpl.java
Expand Up @@ -47,7 +47,7 @@ final class MetadataApplierImpl extends MetadataApplier {
boolean finalized;

// not null if returnStream() was called before apply()
DelayedStream delayedStream;
ApplierDelayedStream delayedStream;

MetadataApplierImpl(
ClientTransport transport, MethodDescriptor<?, ?> method, Metadata origHeaders,
Expand Down Expand Up @@ -105,11 +105,16 @@ ClientStream returnStream() {
synchronized (lock) {
if (returnedStream == null) {
// apply() has not been called, needs to buffer the requests.
delayedStream = new DelayedStream();
delayedStream = new ApplierDelayedStream();
return returnedStream = delayedStream;
} else {
return returnedStream;
}
}
}

private static class ApplierDelayedStream extends DelayedStream {
@Override
void onTransferComplete() {}
}
}
Expand Up @@ -158,12 +158,14 @@ public void uncaughtException(Thread t, Throwable e) {
delayedTransport.reprocess(mockPicker);
assertEquals(0, delayedTransport.getPendingStreamsCount());
delayedTransport.shutdown(SHUTDOWN_STATUS);
verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS));
verify(transportListener).transportTerminated();
verify(transportListener, never()).transportTerminated();
assertEquals(1, fakeExecutor.runDueTasks());
verify(transportListener, never()).transportTerminated();
verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions));
stream.start(streamListener);
verify(mockRealStream).start(same(streamListener));
verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS));
verify(transportListener).transportTerminated();
}

@Test public void transportTerminatedThenAssignTransport() {
Expand Down Expand Up @@ -201,8 +203,10 @@ public void uncaughtException(Thread t, Throwable e) {
@Test public void cancelStreamWithoutSetTransport() {
ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT);
assertEquals(1, delayedTransport.getPendingStreamsCount());
assertEquals(1, delayedTransport.getPendingCompleteStreamsCount());
stream.cancel(Status.CANCELLED);
assertEquals(0, delayedTransport.getPendingStreamsCount());
assertEquals(0, delayedTransport.getPendingCompleteStreamsCount());
verifyNoMoreInteractions(mockRealTransport);
verifyNoMoreInteractions(mockRealStream);
}
Expand All @@ -211,13 +215,34 @@ public void uncaughtException(Thread t, Throwable e) {
ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT);
stream.start(streamListener);
assertEquals(1, delayedTransport.getPendingStreamsCount());
assertEquals(1, delayedTransport.getPendingCompleteStreamsCount());
stream.cancel(Status.CANCELLED);
assertEquals(0, delayedTransport.getPendingStreamsCount());
assertEquals(0, delayedTransport.getPendingCompleteStreamsCount());
verify(streamListener).closed(same(Status.CANCELLED), any(Metadata.class));
verifyNoMoreInteractions(mockRealTransport);
verifyNoMoreInteractions(mockRealStream);
}

@Test
public void cancelStreamShutdownThenStart() {
ClientStream stream = delayedTransport.newStream(method, headers, callOptions);
delayedTransport.shutdown(Status.UNAVAILABLE);
assertEquals(1, delayedTransport.getPendingStreamsCount());
assertEquals(1, delayedTransport.getPendingCompleteStreamsCount());
delayedTransport.reprocess(mockPicker);
assertEquals(1, fakeExecutor.runDueTasks());
assertEquals(0, delayedTransport.getPendingStreamsCount());
assertEquals(1, delayedTransport.getPendingCompleteStreamsCount());
stream.cancel(Status.CANCELLED);
verify(mockRealStream).cancel(same(Status.CANCELLED));
verify(transportListener, never()).transportTerminated();
stream.start(streamListener);
assertEquals(0, delayedTransport.getPendingCompleteStreamsCount());
verify(mockRealStream).start(streamListener);
verify(transportListener).transportTerminated();
}

@Test public void newStreamThenShutdownTransportThenAssignTransport() {
ClientStream stream = delayedTransport.newStream(method, headers, callOptions);
stream.start(streamListener);
Expand Down Expand Up @@ -353,6 +378,7 @@ public void uncaughtException(Thread t, Throwable e) {
waitForReadyCallOptions);

assertEquals(8, delayedTransport.getPendingStreamsCount());
assertEquals(8, delayedTransport.getPendingCompleteStreamsCount());

// First reprocess(). Some will proceed, some will fail and the rest will stay buffered.
SubchannelPicker picker = mock(SubchannelPicker.class);
Expand All @@ -370,6 +396,7 @@ public void uncaughtException(Thread t, Throwable e) {
delayedTransport.reprocess(picker);

assertEquals(5, delayedTransport.getPendingStreamsCount());
assertEquals(8, delayedTransport.getPendingCompleteStreamsCount());
inOrder.verify(picker).pickSubchannel(ff1args);
inOrder.verify(picker).pickSubchannel(ff2args);
inOrder.verify(picker).pickSubchannel(ff3args);
Expand All @@ -385,15 +412,21 @@ public void uncaughtException(Thread t, Throwable e) {
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class));
verify(mockRealTransport2, never()).newStream(
any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class));
ff1.start(streamListener);
ff2.start(streamListener);
fakeExecutor.runDueTasks();
assertEquals(0, fakeExecutor.numPendingTasks());
// 8 - 2(runDueTask with start)
assertEquals(6, delayedTransport.getPendingCompleteStreamsCount());
// ff1 and wfr1 went through
verify(mockRealTransport).newStream(method, headers, failFastCallOptions);
verify(mockRealTransport2).newStream(method, headers, waitForReadyCallOptions);
assertSame(mockRealStream, ff1.getRealStream());
assertSame(mockRealStream2, wfr1.getRealStream());
// The ff2 has failed due to picker returning an error
assertSame(Status.UNAVAILABLE, ((FailingClientStream) ff2.getRealStream()).getError());
wfr1.start(streamListener);
assertEquals(5, delayedTransport.getPendingCompleteStreamsCount());
// Other streams are still buffered
assertNull(ff3.getRealStream());
assertNull(ff4.getRealStream());
Expand All @@ -414,17 +447,24 @@ public void uncaughtException(Thread t, Throwable e) {
assertEquals(0, wfr3Executor.numPendingTasks());
verify(transportListener, never()).transportInUse(false);

ff3.start(streamListener);
ff4.start(streamListener);
wfr2.start(streamListener);
wfr3.start(streamListener);
wfr4.start(streamListener);
delayedTransport.reprocess(picker);
assertEquals(0, delayedTransport.getPendingStreamsCount());
assertEquals(5, delayedTransport.getPendingCompleteStreamsCount());
verify(transportListener).transportInUse(false);
inOrder.verify(picker).pickSubchannel(ff3args); // ff3
inOrder.verify(picker).pickSubchannel(ff4args); // ff4
inOrder.verify(picker).pickSubchannel(wfr2args); // wfr2
inOrder.verify(picker).pickSubchannel(wfr3args); // wfr3
inOrder.verify(picker).pickSubchannel(wfr4args); // wfr4
inOrder.verifyNoMoreInteractions();
fakeExecutor.runDueTasks();
assertEquals(4, fakeExecutor.runDueTasks());
assertEquals(0, fakeExecutor.numPendingTasks());
assertEquals(1, delayedTransport.getPendingCompleteStreamsCount());
assertSame(mockRealStream, ff3.getRealStream());
assertSame(mockRealStream2, ff4.getRealStream());
assertSame(mockRealStream2, wfr2.getRealStream());
Expand All @@ -434,15 +474,18 @@ public void uncaughtException(Thread t, Throwable e) {
assertNull(wfr3.getRealStream());
wfr3Executor.runDueTasks();
assertSame(mockRealStream, wfr3.getRealStream());
assertEquals(0, delayedTransport.getPendingCompleteStreamsCount());

// New streams will use the last picker
DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream(
method, headers, waitForReadyCallOptions);
wfr5.start(streamListener);
assertNull(wfr5.getRealStream());
inOrder.verify(picker).pickSubchannel(
new PickSubchannelArgsImpl(method, headers, waitForReadyCallOptions));
inOrder.verifyNoMoreInteractions();
assertEquals(1, delayedTransport.getPendingStreamsCount());
assertEquals(1, delayedTransport.getPendingCompleteStreamsCount());

// wfr5 will stop delayed transport from terminating
delayedTransport.shutdown(SHUTDOWN_STATUS);
Expand Down
8 changes: 7 additions & 1 deletion core/src/test/java/io/grpc/internal/DelayedStreamTest.java
Expand Up @@ -65,7 +65,7 @@ public class DelayedStreamTest {
@Mock private ClientStreamListener listener;
@Mock private ClientStream realStream;
@Captor private ArgumentCaptor<ClientStreamListener> listenerCaptor;
private DelayedStream stream = new DelayedStream();
private DelayedStream stream = new SimpleDelayedStream();

@Test
public void setStream_setAuthority() {
Expand Down Expand Up @@ -378,4 +378,10 @@ public Void answer(InvocationOnMock in) {
assertThat(insight.toString())
.matches("\\[buffered_nanos=[0-9]+, remote_addr=127\\.0\\.0\\.1:443\\]");
}

private static class SimpleDelayedStream extends DelayedStream {
@Override
void onTransferComplete() {
}
}
}