From 6a185919f9888aa86c6e284eb64a06fda89673a7 Mon Sep 17 00:00:00 2001 From: Jacob Kiefer Date: Fri, 19 May 2023 17:32:11 -0400 Subject: [PATCH] core: Rework retry memory leak fix in https://github.com/grpc/grpc-java/pull/9360 to send fewer FIN packets. Without the explicit stream flush after writes in RetriableStream, buffer data is being orphaned on RetriableStream cancellation since calls to writeRaw() can happen after dispose() calls. We should verify the framer is not closed before continuing with writes since writes and dispose can be interleaved. --- .../grpc/internal/AbstractClientStream.java | 1 + .../java/io/grpc/internal/MessageFramer.java | 14 +++- .../io/grpc/internal/RetriableStream.java | 4 - .../internal/AbstractClientStreamTest.java | 47 +++++++++++ .../io/grpc/internal/MessageFramerTest.java | 10 +++ .../io/grpc/internal/RetriableStreamTest.java | 77 +++---------------- 6 files changed, 81 insertions(+), 72 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index 4ef743bf96dd..8a7c8ea2aa73 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -199,6 +199,7 @@ public final void halfClose() { public final void cancel(Status reason) { Preconditions.checkArgument(!reason.isOk(), "Should not cancel with OK status"); cancelled = true; + framer().dispose(); abstractClientStreamSink().cancel(reason); } diff --git a/core/src/main/java/io/grpc/internal/MessageFramer.java b/core/src/main/java/io/grpc/internal/MessageFramer.java index 93d35250a0fe..cbf248238a22 100644 --- a/core/src/main/java/io/grpc/internal/MessageFramer.java +++ b/core/src/main/java/io/grpc/internal/MessageFramer.java @@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkState; import static java.lang.Math.min; +import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; import io.grpc.Codec; import io.grpc.Compressor; @@ -163,7 +164,7 @@ public void writePayload(InputStream message) { statsTraceCtx.outboundMessageSent(currentMessageSeqNo, currentMessageWireSize, written); } - private int writeUncompressed(InputStream message, int messageLength) throws IOException { + int writeUncompressed(InputStream message, int messageLength) throws IOException { if (messageLength != -1) { currentMessageWireSize = messageLength; return writeKnownLengthUncompressed(message, messageLength); @@ -213,7 +214,7 @@ private int getKnownLength(InputStream inputStream) throws IOException { /** * Write an unserialized message with a known length, uncompressed. */ - private int writeKnownLengthUncompressed(InputStream message, int messageLength) + int writeKnownLengthUncompressed(InputStream message, int messageLength) throws IOException { if (maxOutboundMessageSize >= 0 && messageLength > maxOutboundMessageSize) { throw Status.RESOURCE_EXHAUSTED @@ -279,8 +280,15 @@ private static int writeToOutputStream(InputStream message, OutputStream outputS } } - private void writeRaw(byte[] b, int off, int len) { + @VisibleForTesting + void writeRaw(byte[] b, int off, int len) { while (len > 0) { + if (isClosed()) { + if (buffer != null) { + dispose(); + } + return; + } if (buffer != null && buffer.writableBytes() == 0) { commitToSink(false, false); } diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 1cb2a668a456..ed92ccb59401 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -565,10 +565,6 @@ class SendMessageEntry implements BufferEntry { @Override public void runWith(Substream substream) { substream.stream.writeMessage(method.streamRequest(message)); - // TODO(ejona): Workaround Netty memory leak. Message writes always need to be followed by - // flushes (or half close), but retry appears to have a code path that the flushes may - // not happen. The code needs to be fixed and this removed. See #9340. - substream.stream.flush(); } } diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java index 4ce8a467d9f8..1399714c3a98 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java @@ -52,6 +52,8 @@ import java.io.IOException; import java.io.InputStream; import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Rule; @@ -158,6 +160,18 @@ public void cancel(Status errorStatus) { verify(mockListener).closed(any(Status.class), same(PROCESSED), any(Metadata.class)); } + @Test + public void cancel_closesFramerAndReleasesBuffers() { + TrackingWritableBufferAllocator trackingAllocator = new TrackingWritableBufferAllocator(); + AbstractClientStream stream = + new BaseAbstractClientStream(trackingAllocator, statsTraceCtx, transportTracer); + stream.start(mockListener); + stream.writeMessage(new ByteArrayInputStream(new byte[1])); + stream.cancel(Status.DEADLINE_EXCEEDED); + assertTrue(trackingAllocator.allocatedBuffersReleased()); + assertTrue(stream.framer().isClosed()); + } + @Test public void startFailsOnNullListener() { AbstractClientStream stream = @@ -584,4 +598,37 @@ public void runOnTransportThread(Runnable r) { r.run(); } } + + private static class TrackingWritableBufferAllocator implements WritableBufferAllocator { + List allocatedBuffers = new ArrayList<>(); + + @Override + public WritableBuffer allocate(int capacityHint) { + ReleaseVerifyingBuffer buf = new ReleaseVerifyingBuffer(capacityHint); + allocatedBuffers.add(buf); + return buf; + } + + boolean allocatedBuffersReleased() { + return allocatedBuffers.stream().allMatch(ReleaseVerifyingBuffer::isReleased); + } + } + + private static class ReleaseVerifyingBuffer extends ByteWritableBuffer { + boolean isReleased; + + ReleaseVerifyingBuffer(int maxFrameSize) { + super(maxFrameSize); + } + + @Override + public void release() { + super.release(); + isReleased = true; + } + + boolean isReleased() { + return isReleased; + } + } } diff --git a/core/src/test/java/io/grpc/internal/MessageFramerTest.java b/core/src/test/java/io/grpc/internal/MessageFramerTest.java index 07f717cb81da..bfa595d56f1f 100644 --- a/core/src/test/java/io/grpc/internal/MessageFramerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageFramerTest.java @@ -31,6 +31,7 @@ import io.grpc.internal.testing.TestStreamTracer.TestBaseStreamTracer; import java.io.BufferedInputStream; import java.io.ByteArrayInputStream; +import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; @@ -323,6 +324,15 @@ public void dontCompressIfNotRequested() { checkStats(1000, 1000); } + @Test + public void writeRawNoopIfDisposed() { + byte[] bytes = new byte[]{3, 14}; + framer.dispose(); + framer.writeRaw(bytes, 0, bytes.length); + verifyNoMoreInteractions(sink); + assertEquals(0, allocator.allocCount); + } + @Test public void closeIsRentrantSafe() { MessageFramer.Sink reentrant = new MessageFramer.Sink() { diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 9b1f6284bcc5..e028d0d17945 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -279,14 +279,10 @@ public Void answer(InvocationOnMock in) { retriableStream.sendMessage("msg3"); retriableStream.request(456); - inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround - inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround + inOrder.verify(mockStream1, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream1).request(345); inOrder.verify(mockStream1, times(2)).flush(); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream1).request(456); inOrder.verifyNoMoreInteractions(); @@ -319,19 +315,12 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround + inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream2).request(345); inOrder.verify(mockStream2, times(2)).flush(); inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream2).request(456); - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround + inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); @@ -341,10 +330,7 @@ public Void answer(InvocationOnMock in) { // mockStream1 is closed so it is not in the drainedSubstreams verifyNoMoreInteractions(mockStream1); - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround + inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); // retry2 doReturn(mockStream3).when(retriableStreamRecorder).newSubstream(2); @@ -378,19 +364,12 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround + inOrder.verify(mockStream3, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream3).request(345); inOrder.verify(mockStream3, times(2)).flush(); inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround inOrder.verify(mockStream3).request(456); - for (int i = 0; i < 7; i++) { - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround - } + inOrder.verify(mockStream3, times(7)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); @@ -2103,14 +2082,10 @@ public Void answer(InvocationOnMock in) { hedgingStream.sendMessage("msg3"); hedgingStream.request(456); - inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround - inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround + inOrder.verify(mockStream1, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream1).request(345); inOrder.verify(mockStream1, times(2)).flush(); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream1).request(456); inOrder.verifyNoMoreInteractions(); @@ -2133,14 +2108,10 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround - inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround + inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream2).request(345); inOrder.verify(mockStream2, times(2)).flush(); inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream2).request(456); inOrder.verify(mockStream1).isReady(); inOrder.verify(mockStream2).isReady(); @@ -2151,13 +2122,9 @@ public Void answer(InvocationOnMock in) { hedgingStream.sendMessage("msg2 after hedge2 starts"); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verifyNoMoreInteractions(); @@ -2179,19 +2146,12 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround + inOrder.verify(mockStream3, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream3).request(345); inOrder.verify(mockStream3, times(2)).flush(); inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround inOrder.verify(mockStream3).request(456); - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround - inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround + inOrder.verify(mockStream3, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream1).isReady(); inOrder.verify(mockStream2).isReady(); inOrder.verify(mockStream3).isReady(); @@ -2200,11 +2160,8 @@ public Void answer(InvocationOnMock in) { // send one more message hedgingStream.sendMessage("msg1 after hedge3 starts"); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround // hedge3 receives nonFatalStatus sublistenerCaptor3.getValue().closed( @@ -2214,9 +2171,7 @@ public Void answer(InvocationOnMock in) { // send one more message hedgingStream.sendMessage("msg1 after hedge3 fails"); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream1).flush(); // Memory leak workaround inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream2).flush(); // Memory leak workaround // the hedge mockStream4 starts fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2236,19 +2191,12 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor4 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream4).start(sublistenerCaptor4.capture()); - inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream4).flush(); // Memory leak workaround - inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream4).flush(); // Memory leak workaround + inOrder.verify(mockStream4, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream4).request(345); inOrder.verify(mockStream4, times(2)).flush(); inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream4).flush(); // Memory leak workaround inOrder.verify(mockStream4).request(456); - for (int i = 0; i < 4; i++) { - inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream4).flush(); // Memory leak workaround - } + inOrder.verify(mockStream4, times(4)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream1).isReady(); inOrder.verify(mockStream2).isReady(); inOrder.verify(mockStream4).isReady(); @@ -2371,7 +2319,6 @@ public void hedging_maxAttempts() { hedgingStream.sendMessage("msg1 after commit"); inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); - inOrder.verify(mockStream3).flush(); // Memory leak workaround inOrder.verifyNoMoreInteractions(); Metadata heders = new Metadata();