diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 23725788466..d19a260049b 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -104,6 +104,7 @@ abstract class RetriableStream implements ClientStream { @GuardedBy("lock") private FutureCanceller scheduledHedging; private long nextBackoffIntervalNanos; + private Status cancellationStatus; RetriableStream( MethodDescriptor method, Metadata headers, @@ -244,15 +245,21 @@ private void drain(Substream substream) { int index = 0; int chunk = 0x80; List list = null; + boolean streamStarted = false; while (true) { State savedState; synchronized (lock) { savedState = state; - if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { - // committed but not me - break; + if (streamStarted) { + if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { + // committed but not me, to be cancelled + break; + } + if (savedState.cancelled) { + break; + } } if (index == savedState.buffer.size()) { // I'm drained state = savedState.substreamDrained(substream); @@ -274,22 +281,25 @@ private void drain(Substream substream) { } for (BufferEntry bufferEntry : list) { - savedState = state; - if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { - // committed but not me - break; + bufferEntry.runWith(substream); + if (bufferEntry instanceof RetriableStream.StartEntry) { + streamStarted = true; } - if (savedState.cancelled) { - checkState( - savedState.winningSubstream == substream, - "substream should be CANCELLED_BECAUSE_COMMITTED already"); - return; + if (streamStarted) { + savedState = state; + if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { + // committed but not me, to be cancelled + break; + } + if (savedState.cancelled) { + break; + } } - bufferEntry.runWith(substream); } } - substream.stream.cancel(CANCELLED_BECAUSE_COMMITTED); + substream.stream.cancel( + state.winningSubstream == substream ? cancellationStatus : CANCELLED_BECAUSE_COMMITTED); } /** @@ -299,6 +309,13 @@ private void drain(Substream substream) { @Nullable abstract Status prestart(); + class StartEntry implements BufferEntry { + @Override + public void runWith(Substream substream) { + substream.stream.start(new Sublistener(substream)); + } + } + /** Starts the first PRC attempt. */ @Override public final void start(ClientStreamListener listener) { @@ -311,13 +328,6 @@ public final void start(ClientStreamListener listener) { return; } - class StartEntry implements BufferEntry { - @Override - public void runWith(Substream substream) { - substream.stream.start(new Sublistener(substream)); - } - } - synchronized (lock) { state.buffer.add(new StartEntry()); } @@ -450,11 +460,18 @@ public final void cancel(Status reason) { return; } - state.winningSubstream.stream.cancel(reason); + Substream winningSubstreamToCancel = null; synchronized (lock) { - // This is not required, but causes a short-circuit in the draining process. + if (state.drainedSubstreams.contains(state.winningSubstream)) { + winningSubstreamToCancel = state.winningSubstream; + } else { // the winningSubstream will be cancelled while draining + cancellationStatus = reason; + } state = state.cancelled(); } + if (winningSubstreamToCancel != null) { + winningSubstreamToCancel.stream.cancel(reason); + } } private void delayOrExecute(BufferEntry bufferEntry) { diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 26c6fcf9b4e..95d2c2ba8b5 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -749,6 +749,91 @@ public void request(int numMessages) { inOrder.verify(mockStream2, never()).writeMessage(any(InputStream.class)); } + @Test + public void cancelWhileDraining() { + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + ClientStream mockStream1 = mock(ClientStream.class); + ClientStream mockStream2 = + mock( + ClientStream.class, + delegatesTo( + new NoopClientStream() { + @Override + public void request(int numMessages) { + retriableStream.cancel( + Status.CANCELLED.withDescription("cancelled while requesting")); + } + })); + + InOrder inOrder = inOrder(retriableStreamRecorder, mockStream1, mockStream2); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + retriableStream.start(masterListener); + inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + retriableStream.request(3); + inOrder.verify(mockStream1).request(3); + + // retry + doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); + sublistenerCaptor1.getValue().closed( + Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + + inOrder.verify(mockStream2).start(any(ClientStreamListener.class)); + inOrder.verify(mockStream2).request(3); + inOrder.verify(retriableStreamRecorder).postCommit(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + inOrder.verify(mockStream2).cancel(statusCaptor.capture()); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()) + .isEqualTo("Stream thrown away because RetriableStream committed"); + verify(masterListener).closed( + statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()).isEqualTo("cancelled while requesting"); + } + + @Test + public void cancelWhileRetryStart() { + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + ClientStream mockStream1 = mock(ClientStream.class); + ClientStream mockStream2 = + mock( + ClientStream.class, + delegatesTo( + new NoopClientStream() { + @Override + public void start(ClientStreamListener listener) { + retriableStream.cancel( + Status.CANCELLED.withDescription("cancelled while retry start")); + } + })); + + InOrder inOrder = inOrder(retriableStreamRecorder, mockStream1, mockStream2); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + retriableStream.start(masterListener); + inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + + // retry + doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); + sublistenerCaptor1.getValue().closed( + Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + + inOrder.verify(mockStream2).start(any(ClientStreamListener.class)); + inOrder.verify(retriableStreamRecorder).postCommit(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + inOrder.verify(mockStream2).cancel(statusCaptor.capture()); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()) + .isEqualTo("Stream thrown away because RetriableStream committed"); + verify(masterListener).closed( + statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()).isEqualTo("cancelled while retry start"); + } + @Test public void operationsAfterImmediateCommit() { ArgumentCaptor sublistenerCaptor1 = @@ -916,6 +1001,47 @@ public void start(ClientStreamListener listener) { verify(mockStream3).request(1); } + @Test + public void commitAndCancelWhileDraining() { + ClientStream mockStream1 = mock(ClientStream.class); + ClientStream mockStream2 = + mock( + ClientStream.class, + delegatesTo( + new NoopClientStream() { + @Override + public void start(ClientStreamListener listener) { + // commit while draining + listener.headersRead(new Metadata()); + // cancel while draining + retriableStream.cancel( + Status.CANCELLED.withDescription("cancelled while drained")); + } + })); + + when(retriableStreamRecorder.newSubstream(anyInt())) + .thenReturn(mockStream1, mockStream2); + + retriableStream.start(masterListener); + + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream1).start(sublistenerCaptor1.capture()); + + ClientStreamListener listener1 = sublistenerCaptor1.getValue(); + + // retry + listener1.closed( + Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + + verify(mockStream2).start(any(ClientStreamListener.class)); + verify(retriableStreamRecorder).postCommit(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(mockStream2).cancel(statusCaptor.capture()); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()).isEqualTo("cancelled while drained"); + } @Test public void perRpcBufferLimitExceeded() {