diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 23725788466..396c7cedfe2 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -104,6 +104,7 @@ abstract class RetriableStream implements ClientStream { @GuardedBy("lock") private FutureCanceller scheduledHedging; private long nextBackoffIntervalNanos; + private Status cancellationStatus; RetriableStream( MethodDescriptor method, Metadata headers, @@ -244,14 +245,16 @@ private void drain(Substream substream) { int index = 0; int chunk = 0x80; List list = null; + boolean streamStarted = false; while (true) { State savedState; synchronized (lock) { savedState = state; - if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { - // committed but not me + if (savedState.winningSubstream != null && savedState.winningSubstream != substream + && streamStarted) { + // committed but not me, to be cancelled break; } if (index == savedState.buffer.size()) { // I'm drained @@ -275,17 +278,22 @@ private void drain(Substream substream) { for (BufferEntry bufferEntry : list) { savedState = state; - if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { - // committed but not me + if (savedState.winningSubstream != null && savedState.winningSubstream != substream + && streamStarted) { + // committed but not me, to be cancelled break; } - if (savedState.cancelled) { + if (savedState.cancelled && streamStarted) { checkState( savedState.winningSubstream == substream, "substream should be CANCELLED_BECAUSE_COMMITTED already"); + substream.stream.cancel(cancellationStatus); return; } bufferEntry.runWith(substream); + if (bufferEntry instanceof RetriableStream.StartEntry) { + streamStarted = true; + } } } @@ -299,6 +307,13 @@ private void drain(Substream substream) { @Nullable abstract Status prestart(); + class StartEntry implements BufferEntry { + @Override + public void runWith(Substream substream) { + substream.stream.start(new Sublistener(substream)); + } + } + /** Starts the first PRC attempt. */ @Override public final void start(ClientStreamListener listener) { @@ -311,13 +326,6 @@ public final void start(ClientStreamListener listener) { return; } - class StartEntry implements BufferEntry { - @Override - public void runWith(Substream substream) { - substream.stream.start(new Sublistener(substream)); - } - } - synchronized (lock) { state.buffer.add(new StartEntry()); } @@ -450,11 +458,18 @@ public final void cancel(Status reason) { return; } - state.winningSubstream.stream.cancel(reason); + Substream winningSubstreamToCancel = null; synchronized (lock) { - // This is not required, but causes a short-circuit in the draining process. + if (state.drainedSubstreams.contains(state.winningSubstream)) { + winningSubstreamToCancel = state.winningSubstream; + } else { // the winningSubstream will be cancelled while draining + cancellationStatus = reason; + } state = state.cancelled(); } + if (winningSubstreamToCancel != null) { + winningSubstreamToCancel.stream.cancel(reason); + } } private void delayOrExecute(BufferEntry bufferEntry) {