diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index 4ef743bf96dd..8a7c8ea2aa73 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -199,6 +199,7 @@ public final void halfClose() { public final void cancel(Status reason) { Preconditions.checkArgument(!reason.isOk(), "Should not cancel with OK status"); cancelled = true; + framer().dispose(); abstractClientStreamSink().cancel(reason); } diff --git a/core/src/main/java/io/grpc/internal/MessageFramer.java b/core/src/main/java/io/grpc/internal/MessageFramer.java index 93d35250a0fe..cbf248238a22 100644 --- a/core/src/main/java/io/grpc/internal/MessageFramer.java +++ b/core/src/main/java/io/grpc/internal/MessageFramer.java @@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkState; import static java.lang.Math.min; +import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; import io.grpc.Codec; import io.grpc.Compressor; @@ -163,7 +164,7 @@ public void writePayload(InputStream message) { statsTraceCtx.outboundMessageSent(currentMessageSeqNo, currentMessageWireSize, written); } - private int writeUncompressed(InputStream message, int messageLength) throws IOException { + int writeUncompressed(InputStream message, int messageLength) throws IOException { if (messageLength != -1) { currentMessageWireSize = messageLength; return writeKnownLengthUncompressed(message, messageLength); @@ -213,7 +214,7 @@ private int getKnownLength(InputStream inputStream) throws IOException { /** * Write an unserialized message with a known length, uncompressed. */ - private int writeKnownLengthUncompressed(InputStream message, int messageLength) + int writeKnownLengthUncompressed(InputStream message, int messageLength) throws IOException { if (maxOutboundMessageSize >= 0 && messageLength > maxOutboundMessageSize) { throw Status.RESOURCE_EXHAUSTED @@ -279,8 +280,15 @@ private static int writeToOutputStream(InputStream message, OutputStream outputS } } - private void writeRaw(byte[] b, int off, int len) { + @VisibleForTesting + void writeRaw(byte[] b, int off, int len) { while (len > 0) { + if (isClosed()) { + if (buffer != null) { + dispose(); + } + return; + } if (buffer != null && buffer.writableBytes() == 0) { commitToSink(false, false); } diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 1cb2a668a456..ed92ccb59401 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -565,10 +565,6 @@ class SendMessageEntry implements BufferEntry { @Override public void runWith(Substream substream) { substream.stream.writeMessage(method.streamRequest(message)); - // TODO(ejona): Workaround Netty memory leak. Message writes always need to be followed by - // flushes (or half close), but retry appears to have a code path that the flushes may - // not happen. The code needs to be fixed and this removed. See #9340. - substream.stream.flush(); } } diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java index 4ce8a467d9f8..1399714c3a98 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java @@ -52,6 +52,8 @@ import java.io.IOException; import java.io.InputStream; import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Rule; @@ -158,6 +160,18 @@ public void cancel(Status errorStatus) { verify(mockListener).closed(any(Status.class), same(PROCESSED), any(Metadata.class)); } + @Test + public void cancel_closesFramerAndReleasesBuffers() { + TrackingWritableBufferAllocator trackingAllocator = new TrackingWritableBufferAllocator(); + AbstractClientStream stream = + new BaseAbstractClientStream(trackingAllocator, statsTraceCtx, transportTracer); + stream.start(mockListener); + stream.writeMessage(new ByteArrayInputStream(new byte[1])); + stream.cancel(Status.DEADLINE_EXCEEDED); + assertTrue(trackingAllocator.allocatedBuffersReleased()); + assertTrue(stream.framer().isClosed()); + } + @Test public void startFailsOnNullListener() { AbstractClientStream stream = @@ -584,4 +598,37 @@ public void runOnTransportThread(Runnable r) { r.run(); } } + + private static class TrackingWritableBufferAllocator implements WritableBufferAllocator { + List allocatedBuffers = new ArrayList<>(); + + @Override + public WritableBuffer allocate(int capacityHint) { + ReleaseVerifyingBuffer buf = new ReleaseVerifyingBuffer(capacityHint); + allocatedBuffers.add(buf); + return buf; + } + + boolean allocatedBuffersReleased() { + return allocatedBuffers.stream().allMatch(ReleaseVerifyingBuffer::isReleased); + } + } + + private static class ReleaseVerifyingBuffer extends ByteWritableBuffer { + boolean isReleased; + + ReleaseVerifyingBuffer(int maxFrameSize) { + super(maxFrameSize); + } + + @Override + public void release() { + super.release(); + isReleased = true; + } + + boolean isReleased() { + return isReleased; + } + } } diff --git a/core/src/test/java/io/grpc/internal/MessageFramerTest.java b/core/src/test/java/io/grpc/internal/MessageFramerTest.java index 07f717cb81da..bfa595d56f1f 100644 --- a/core/src/test/java/io/grpc/internal/MessageFramerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageFramerTest.java @@ -31,6 +31,7 @@ import io.grpc.internal.testing.TestStreamTracer.TestBaseStreamTracer; import java.io.BufferedInputStream; import java.io.ByteArrayInputStream; +import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; @@ -323,6 +324,15 @@ public void dontCompressIfNotRequested() { checkStats(1000, 1000); } + @Test + public void writeRawNoopIfDisposed() { + byte[] bytes = new byte[]{3, 14}; + framer.dispose(); + framer.writeRaw(bytes, 0, bytes.length); + verifyNoMoreInteractions(sink); + assertEquals(0, allocator.allocCount); + } + @Test public void closeIsRentrantSafe() { MessageFramer.Sink reentrant = new MessageFramer.Sink() { diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 9b1f6284bcc5..e028d0d17945 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -279,14 +279,10 @@ public Void answer(InvocationOnMock in) { retriableStream.sendMessage("msg3"); retriableStream.request(456); - inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround - inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround + inOrder.verify(mockStream1, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream1).request(345); inOrder.verify(mockStream1, times(2)).flush(); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream1).request(456); inOrder.verifyNoMoreInteractions(); @@ -319,19 +315,12 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround + inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream2).request(345); inOrder.verify(mockStream2, times(2)).flush(); inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream2).request(456); - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround + inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); @@ -341,10 +330,7 @@ public Void answer(InvocationOnMock in) { // mockStream1 is closed so it is not in the drainedSubstreams verifyNoMoreInteractions(mockStream1); - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround + inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); // retry2 doReturn(mockStream3).when(retriableStreamRecorder).newSubstream(2); @@ -378,19 +364,12 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround + inOrder.verify(mockStream3, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream3).request(345); inOrder.verify(mockStream3, times(2)).flush(); inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround inOrder.verify(mockStream3).request(456); - for (int i = 0; i < 7; i++) { - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround - } + inOrder.verify(mockStream3, times(7)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); @@ -2103,14 +2082,10 @@ public Void answer(InvocationOnMock in) { hedgingStream.sendMessage("msg3"); hedgingStream.request(456); - inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround - inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround + inOrder.verify(mockStream1, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream1).request(345); inOrder.verify(mockStream1, times(2)).flush(); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream1).request(456); inOrder.verifyNoMoreInteractions(); @@ -2133,14 +2108,10 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround + inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream2).request(345); inOrder.verify(mockStream2, times(2)).flush(); inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream2).request(456); inOrder.verify(mockStream1).isReady(); inOrder.verify(mockStream2).isReady(); @@ -2151,13 +2122,9 @@ public Void answer(InvocationOnMock in) { hedgingStream.sendMessage("msg2 after hedge2 starts"); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verifyNoMoreInteractions(); @@ -2179,19 +2146,12 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround + inOrder.verify(mockStream3, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream3).request(345); inOrder.verify(mockStream3, times(2)).flush(); inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround inOrder.verify(mockStream3).request(456); - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround + inOrder.verify(mockStream3, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream1).isReady(); inOrder.verify(mockStream2).isReady(); inOrder.verify(mockStream3).isReady(); @@ -2200,11 +2160,8 @@ public Void answer(InvocationOnMock in) { // send one more message hedgingStream.sendMessage("msg1 after hedge3 starts"); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround // hedge3 receives nonFatalStatus sublistenerCaptor3.getValue().closed( @@ -2214,9 +2171,7 @@ public Void answer(InvocationOnMock in) { // send one more message hedgingStream.sendMessage("msg1 after hedge3 fails"); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround // the hedge mockStream4 starts fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2236,19 +2191,12 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor4 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream4).start(sublistenerCaptor4.capture()); - inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream4).flush(); // Memory leak workaround - inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream4).flush(); // Memory leak workaround + inOrder.verify(mockStream4, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream4).request(345); inOrder.verify(mockStream4, times(2)).flush(); inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream4).flush(); // Memory leak workaround inOrder.verify(mockStream4).request(456); - for (int i = 0; i < 4; i++) { - inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream4).flush(); // Memory leak workaround - } + inOrder.verify(mockStream4, times(4)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream1).isReady(); inOrder.verify(mockStream2).isReady(); inOrder.verify(mockStream4).isReady(); @@ -2371,7 +2319,6 @@ public void hedging_maxAttempts() { hedgingStream.sendMessage("msg1 after commit"); inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround inOrder.verifyNoMoreInteractions(); Metadata heders = new Metadata();