From 1b6987265c335236c00a2f8af5ec58352905e999 Mon Sep 17 00:00:00 2001 From: "Penn (Dapeng) Zhang" Date: Thu, 25 Mar 2021 17:32:26 -0700 Subject: [PATCH 1/8] add regression test --- .../io/grpc/netty/NettyClientHandlerTest.java | 89 +++++++++++++++++++ .../io/grpc/netty/NettyHandlerTestBase.java | 6 ++ 2 files changed, 95 insertions(+) diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index 25813621cc6..223f4781538 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -383,6 +383,95 @@ public void receivedAbruptGoAwayShouldFailRacingQueuedStreamid() throws Exceptio assertTrue(future.isDone()); } + @Test + public void receivedGoAway_shouldNotFailBufferedStreamWithStreamIdLessThanLastId() + throws Exception { + ClientStreamListener streamListener1 = mock(ClientStreamListener.class); + NettyClientStream.TransportState streamTransportState1 = new TransportStateImpl( + handler(), + channel().eventLoop(), + DEFAULT_MAX_MESSAGE_SIZE, + transportTracer); + streamTransportState1.setListener(streamListener1); + ClientStreamListener streamListener2 = mock(ClientStreamListener.class); + NettyClientStream.TransportState streamTransportState2 = new TransportStateImpl( + handler(), + channel().eventLoop(), + DEFAULT_MAX_MESSAGE_SIZE, + transportTracer); + streamTransportState2.setListener(streamListener2); + // MAX_CONCURRENT_STREAMS=1 + receiveMaxConcurrentStreams(1); + ChannelFuture future1 = writeQueue().enqueue( + newCreateStreamCommand(grpcHeaders, streamTransportState1), true); + ChannelFuture future2 = writeQueue().enqueue( + newCreateStreamCommand(grpcHeaders, streamTransportState2), true); + + // GOAWAY + channelRead(goAwayFrame(Integer.MAX_VALUE)); + assertTrue(future1.isSuccess()); + verify(streamListener1).onReady(); + channel().runPendingTasks(); + if (future2.cause() != null ) { + throw new AssertionError(future2.cause()); + } + assertFalse(future2.isDone()); + verify(streamListener2, never()).onReady(); + + // Let the first stream complete, then the pending stream will be activated. + Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK) + .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC); + channelRead(headersFrame(streamId, headers)); + channelRead(grpcDataFrame(streamId, false, contentAsArray())); + channelRead(trailersFrame( + streamId, + new DefaultHttp2Headers().set(AsciiString.of("grpc-status"), AsciiString.of("0")))); + streamTransportState1.requestMessagesFromDeframerForTesting(1); + verify(streamListener1).closed(eq(Status.OK), any(RpcProgress.class), + any(Metadata.class)); + channel().runPendingTasks(); + if (future2.cause() != null) { + throw new AssertionError(future2.cause()); + } + assertTrue(future2.isSuccess()); + verify(streamListener2).onReady(); + } + + @Test + public void receivedGoAway_shouldRefuseBufferedStreamWithStreamIdGreaterThanLastId() + throws Exception { + ClientStreamListener streamListener1 = mock(ClientStreamListener.class); + NettyClientStream.TransportState streamTransportState1 = new TransportStateImpl( + handler(), + channel().eventLoop(), + DEFAULT_MAX_MESSAGE_SIZE, + transportTracer); + streamTransportState1.setListener(streamListener1); + ClientStreamListener streamListener2 = mock(ClientStreamListener.class); + NettyClientStream.TransportState streamTransportState2 = new TransportStateImpl( + handler(), + channel().eventLoop(), + DEFAULT_MAX_MESSAGE_SIZE, + transportTracer); + streamTransportState2.setListener(streamListener2); + // MAX_CONCURRENT_STREAMS=1 + receiveMaxConcurrentStreams(1); + ChannelFuture future1 = writeQueue().enqueue( + newCreateStreamCommand(grpcHeaders, streamTransportState1), true); + ChannelFuture future2 = writeQueue().enqueue( + newCreateStreamCommand(grpcHeaders, streamTransportState2), true); + + // GOAWAY + channelRead(goAwayFrame(streamId)); + assertTrue(future1.isSuccess()); + verify(streamListener1).onReady(); + assertThat(future2.cause()).isNotNull(); + assertThat(Status.fromThrowable(future2.cause()).getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(Status.fromThrowable(future2.cause()).getDescription()).isEqualTo( + "Abrupt GOAWAY closed unsent stream. HTTP/2 error code: NO_ERROR"); + verify(streamListener2).closed(any(Status.class), eq(RpcProgress.REFUSED), any(Metadata.class)); + } + @Test public void receivedResetWithRefuseCode() throws Exception { ChannelFuture future = enqueue(newCreateStreamCommand(grpcHeaders, streamTransportState)); diff --git a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java index 04f65eed145..684f2050ac2 100644 --- a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java @@ -292,6 +292,12 @@ protected final ByteBuf headersFrame(int streamId, Http2Headers headers) { return captureWrite(ctx); } + protected final ByteBuf trailersFrame(int streamId, Http2Headers headers) { + ChannelHandlerContext ctx = newMockContext(); + new DefaultHttp2FrameWriter().writeHeaders(ctx, streamId, headers, 0, true, newPromise()); + return captureWrite(ctx); + } + protected final ByteBuf goAwayFrame(int lastStreamId) { return goAwayFrame(lastStreamId, 0, Unpooled.EMPTY_BUFFER); } From ee8fb12f230133f2ede33e35fc0ef72886f2062f Mon Sep 17 00:00:00 2001 From: "Penn (Dapeng) Zhang" Date: Thu, 25 Mar 2021 17:33:56 -0700 Subject: [PATCH 2/8] copy StreamBufferingEncoder from Netty --- .../io/grpc/netty/NettyClientHandler.java | 1 - .../io/grpc/netty/StreamBufferingEncoder.java | 375 ++++++++++++ .../netty/StreamBufferingEncoderTest.java | 572 ++++++++++++++++++ 3 files changed, 947 insertions(+), 1 deletion(-) create mode 100644 netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java create mode 100644 netty/src/test/java/io/grpc/netty/StreamBufferingEncoderTest.java diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index 07134939673..ecd3bb086b6 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -74,7 +74,6 @@ import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.codec.http2.Http2Stream; import io.netty.handler.codec.http2.Http2StreamVisitor; -import io.netty.handler.codec.http2.StreamBufferingEncoder; import io.netty.handler.codec.http2.WeightedFairQueueByteDistributor; import io.netty.handler.logging.LogLevel; import io.perfmark.PerfMark; diff --git a/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java b/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java new file mode 100644 index 00000000000..a37221b1284 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java @@ -0,0 +1,375 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoder; +import io.netty.handler.codec.http2.Http2CodecUtil; +import io.netty.handler.codec.http2.Http2ConnectionAdapter; +import io.netty.handler.codec.http2.Http2ConnectionEncoder; +import io.netty.handler.codec.http2.Http2Error; +import io.netty.handler.codec.http2.Http2Exception; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.handler.codec.http2.Http2Stream; +import io.netty.util.ReferenceCountUtil; +import java.util.ArrayDeque; +import java.util.Iterator; +import java.util.Map; +import java.util.Queue; +import java.util.TreeMap; + +/** + * Implementation of a {@link Http2ConnectionEncoder} that dispatches all method call to another + * {@link Http2ConnectionEncoder}, until {@code SETTINGS_MAX_CONCURRENT_STREAMS} is reached. + * + *

When this limit is hit, instead of rejecting any new streams this implementation buffers newly + * created streams and their corresponding frames. Once an active stream gets closed or the maximum + * number of concurrent streams is increased, this encoder will automatically try to empty its + * buffer and create as many new streams as possible. + * + *

If a {@code GOAWAY} frame is received from the remote endpoint, all buffered writes for + * streams with an ID less than the specified {@code lastStreamId} will immediately fail with a + * {@link io.netty.handler.codec.http2.StreamBufferingEncoder.Http2GoAwayException}. + * + *

If the channel/encoder gets closed, all new and buffered writes will immediately fail with a + * {@link io.netty.handler.codec.http2.StreamBufferingEncoder.Http2ChannelClosedException}. + * + *

This implementation makes the buffering mostly transparent and is expected to be used as a + * drop-in decorator of {@link io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder}. + */ +class StreamBufferingEncoder extends DecoratingHttp2ConnectionEncoder { + + /** + * Thrown if buffered streams are terminated due to this encoder being closed. + */ + public static final class Http2ChannelClosedException extends Http2Exception { + private static final long serialVersionUID = 4768543442094476971L; + + public Http2ChannelClosedException() { + super(Http2Error.REFUSED_STREAM, "Connection closed"); + } + } + + /** + * Thrown by {@link StreamBufferingEncoder} if buffered streams are terminated due to + * receipt of a {@code GOAWAY}. + */ + public static final class Http2GoAwayException extends Http2Exception { + private static final long serialVersionUID = 1326785622777291198L; + private final int lastStreamId; + private final long errorCode; + private final byte[] debugData; + + public Http2GoAwayException(int lastStreamId, long errorCode, byte[] debugData) { + super(Http2Error.STREAM_CLOSED); + this.lastStreamId = lastStreamId; + this.errorCode = errorCode; + this.debugData = debugData; + } + + public int lastStreamId() { + return lastStreamId; + } + + public long errorCode() { + return errorCode; + } + + public byte[] debugData() { + return debugData; + } + } + + /** + * Buffer for any streams and corresponding frames that could not be created due to the maximum + * concurrent stream limit being hit. + */ + private final TreeMap pendingStreams = new TreeMap<>(); + private int maxConcurrentStreams; + private boolean closed; + + public StreamBufferingEncoder(Http2ConnectionEncoder delegate) { + this(delegate, SMALLEST_MAX_CONCURRENT_STREAMS); + } + + public StreamBufferingEncoder(Http2ConnectionEncoder delegate, int initialMaxConcurrentStreams) { + super(delegate); + this.maxConcurrentStreams = initialMaxConcurrentStreams; + connection().addListener(new Http2ConnectionAdapter() { + + @Override + public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) { + cancelGoAwayStreams(lastStreamId, errorCode, debugData); + } + + @Override + public void onStreamClosed(Http2Stream stream) { + tryCreatePendingStreams(); + } + }); + } + + /** + * Indicates the number of streams that are currently buffered, awaiting creation. + */ + public int numBufferedStreams() { + return pendingStreams.size(); + } + + @Override + public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int padding, boolean endStream, ChannelPromise promise) { + return writeHeaders(ctx, streamId, headers, 0, Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT, + false, padding, endStream, promise); + } + + @Override + public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int streamDependency, short weight, boolean exclusive, + int padding, boolean endOfStream, ChannelPromise promise) { + if (closed) { + return promise.setFailure(new Http2ChannelClosedException()); + } + if (isExistingStream(streamId) || connection().goAwayReceived()) { + return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, + exclusive, padding, endOfStream, promise); + } + if (canCreateStream()) { + return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, + exclusive, padding, endOfStream, promise); + } + PendingStream pendingStream = pendingStreams.get(streamId); + if (pendingStream == null) { + pendingStream = new PendingStream(ctx, streamId); + pendingStreams.put(streamId, pendingStream); + } + pendingStream.frames.add(new HeadersFrame(headers, streamDependency, weight, exclusive, + padding, endOfStream, promise)); + return promise; + } + + @Override + public ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode, + ChannelPromise promise) { + if (isExistingStream(streamId)) { + return super.writeRstStream(ctx, streamId, errorCode, promise); + } + // Since the delegate doesn't know about any buffered streams we have to handle cancellation + // of the promises and releasing of the ByteBufs here. + PendingStream stream = pendingStreams.remove(streamId); + if (stream != null) { + // Sending a RST_STREAM to a buffered stream will succeed the promise of all frames + // associated with the stream, as sending a RST_STREAM means that someone "doesn't care" + // about the stream anymore and thus there is not point in failing the promises and invoking + // error handling routines. + stream.close(null); + promise.setSuccess(); + } else { + promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId)); + } + return promise; + } + + @Override + public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data, + int padding, boolean endOfStream, ChannelPromise promise) { + if (isExistingStream(streamId)) { + return super.writeData(ctx, streamId, data, padding, endOfStream, promise); + } + PendingStream pendingStream = pendingStreams.get(streamId); + if (pendingStream != null) { + pendingStream.frames.add(new DataFrame(data, padding, endOfStream, promise)); + } else { + ReferenceCountUtil.safeRelease(data); + promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId)); + } + return promise; + } + + @Override + public void remoteSettings(Http2Settings settings) throws Http2Exception { + // Need to let the delegate decoder handle the settings first, so that it sees the + // new setting before we attempt to create any new streams. + super.remoteSettings(settings); + + // Get the updated value for SETTINGS_MAX_CONCURRENT_STREAMS. + maxConcurrentStreams = connection().local().maxActiveStreams(); + + // Try to create new streams up to the new threshold. + tryCreatePendingStreams(); + } + + @Override + public void close() { + try { + if (!closed) { + closed = true; + + // Fail all buffered streams. + Http2ChannelClosedException e = new Http2ChannelClosedException(); + while (!pendingStreams.isEmpty()) { + PendingStream stream = pendingStreams.pollFirstEntry().getValue(); + stream.close(e); + } + } + } finally { + super.close(); + } + } + + private void tryCreatePendingStreams() { + while (!pendingStreams.isEmpty() && canCreateStream()) { + Map.Entry entry = pendingStreams.pollFirstEntry(); + PendingStream pendingStream = entry.getValue(); + try { + pendingStream.sendFrames(); + } catch (Throwable t) { + pendingStream.close(t); + } + } + } + + private void cancelGoAwayStreams(int lastStreamId, long errorCode, ByteBuf debugData) { + Iterator iter = pendingStreams.values().iterator(); + Exception e = + new Http2GoAwayException(lastStreamId, errorCode, ByteBufUtil.getBytes(debugData)); + while (iter.hasNext()) { + PendingStream stream = iter.next(); + if (stream.streamId > lastStreamId) { + iter.remove(); + stream.close(e); + } + } + } + + /** + * Determines whether or not we're allowed to create a new stream right now. + */ + private boolean canCreateStream() { + return connection().local().numActiveStreams() < maxConcurrentStreams; + } + + private boolean isExistingStream(int streamId) { + return streamId <= connection().local().lastStreamCreated(); + } + + private static final class PendingStream { + final ChannelHandlerContext ctx; + final int streamId; + final Queue frames = new ArrayDeque<>(2); + + PendingStream(ChannelHandlerContext ctx, int streamId) { + this.ctx = ctx; + this.streamId = streamId; + } + + void sendFrames() { + for (Frame frame : frames) { + frame.send(ctx, streamId); + } + } + + void close(Throwable t) { + for (Frame frame : frames) { + frame.release(t); + } + } + } + + private abstract static class Frame { + final ChannelPromise promise; + + Frame(ChannelPromise promise) { + this.promise = promise; + } + + /** + * Release any resources (features, buffers, ...) associated with the frame. + */ + void release(Throwable t) { + if (t == null) { + promise.setSuccess(); + } else { + promise.setFailure(t); + } + } + + abstract void send(ChannelHandlerContext ctx, int streamId); + } + + private final class HeadersFrame extends Frame { + final Http2Headers headers; + final int streamDependency; + final short weight; + final boolean exclusive; + final int padding; + final boolean endOfStream; + + HeadersFrame(Http2Headers headers, int streamDependency, short weight, boolean exclusive, + int padding, boolean endOfStream, ChannelPromise promise) { + super(promise); + this.headers = headers; + this.streamDependency = streamDependency; + this.weight = weight; + this.exclusive = exclusive; + this.padding = padding; + this.endOfStream = endOfStream; + } + + @Override + @SuppressWarnings("CheckReturnValue") + void send(ChannelHandlerContext ctx, int streamId) { + writeHeaders( + ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, + promise); + } + } + + private final class DataFrame extends Frame { + final ByteBuf data; + final int padding; + final boolean endOfStream; + + DataFrame(ByteBuf data, int padding, boolean endOfStream, ChannelPromise promise) { + super(promise); + this.data = data; + this.padding = padding; + this.endOfStream = endOfStream; + } + + @Override + void release(Throwable t) { + super.release(t); + ReferenceCountUtil.safeRelease(data); + } + + @Override + @SuppressWarnings("CheckReturnValue") + void send(ChannelHandlerContext ctx, int streamId) { + writeData(ctx, streamId, data, padding, endOfStream, promise); + } + } +} diff --git a/netty/src/test/java/io/grpc/netty/StreamBufferingEncoderTest.java b/netty/src/test/java/io/grpc/netty/StreamBufferingEncoderTest.java new file mode 100644 index 00000000000..d026ae7ca2f --- /dev/null +++ b/netty/src/test/java/io/grpc/netty/StreamBufferingEncoderTest.java @@ -0,0 +1,572 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS; +import static io.netty.handler.codec.http2.Http2Error.CANCEL; +import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.anyShort; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.channel.DefaultMessageSizeEstimator; +import io.netty.handler.codec.http2.DefaultHttp2Connection; +import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; +import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; +import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.DefaultHttp2LocalFlowController; +import io.netty.handler.codec.http2.DefaultHttp2RemoteFlowController; +import io.netty.handler.codec.http2.Http2Connection; +import io.netty.handler.codec.http2.Http2ConnectionHandler; +import io.netty.handler.codec.http2.Http2ConnectionHandlerBuilder; +import io.netty.handler.codec.http2.Http2Exception; +import io.netty.handler.codec.http2.Http2FrameListener; +import io.netty.handler.codec.http2.Http2FrameReader; +import io.netty.handler.codec.http2.Http2FrameSizePolicy; +import io.netty.handler.codec.http2.Http2FrameWriter; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.ImmediateEventExecutor; +import java.util.ArrayList; +import java.util.List; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.verification.VerificationMode; + +/** + * Tests for {@link StreamBufferingEncoder}. + */ +@SuppressWarnings("CheckReturnValue") // netty futures +public class StreamBufferingEncoderTest { + + private StreamBufferingEncoder encoder; + + private Http2Connection connection; + + @Mock + private Http2FrameWriter writer; + + @Mock + private ChannelHandlerContext ctx; + + @Mock + private Channel channel; + + @Mock + private Channel.Unsafe unsafe; + + @Mock + private ChannelConfig config; + + @Mock + private EventExecutor executor; + + /** + * Init fields and do mocking. + */ + @Before + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + + Http2FrameWriter.Configuration configuration = mock(Http2FrameWriter.Configuration.class); + Http2FrameSizePolicy frameSizePolicy = mock(Http2FrameSizePolicy.class); + when(writer.configuration()).thenReturn(configuration); + when(configuration.frameSizePolicy()).thenReturn(frameSizePolicy); + when(frameSizePolicy.maxFrameSize()).thenReturn(DEFAULT_MAX_FRAME_SIZE); + when(writer.writeData( + any(ChannelHandlerContext.class), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean(), + any(ChannelPromise.class))) + .thenAnswer(successAnswer()); + when(writer.writeRstStream(eq(ctx), anyInt(), anyLong(), any(ChannelPromise.class))).thenAnswer( + successAnswer()); + when(writer.writeGoAway( + any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class), + any(ChannelPromise.class))) + .thenAnswer(successAnswer()); + + connection = new DefaultHttp2Connection(false); + connection.remote().flowController(new DefaultHttp2RemoteFlowController(connection)); + connection.local() + .flowController(new DefaultHttp2LocalFlowController(connection).frameWriter(writer)); + + DefaultHttp2ConnectionEncoder defaultEncoder = + new DefaultHttp2ConnectionEncoder(connection, writer); + encoder = new StreamBufferingEncoder(defaultEncoder); + DefaultHttp2ConnectionDecoder decoder = + new DefaultHttp2ConnectionDecoder(connection, encoder, mock(Http2FrameReader.class)); + Http2ConnectionHandler handler = new Http2ConnectionHandlerBuilder() + .frameListener(mock(Http2FrameListener.class)) + .codec(decoder, encoder).build(); + + // Set LifeCycleManager on encoder and decoder + when(ctx.channel()).thenReturn(channel); + when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(executor.inEventLoop()).thenReturn(true); + doAnswer(new Answer() { + @Override + public ChannelPromise answer(InvocationOnMock invocation) throws Throwable { + return newPromise(); + } + }).when(ctx).newPromise(); + when(ctx.executor()).thenReturn(executor); + when(channel.isActive()).thenReturn(false); + when(channel.config()).thenReturn(config); + when(channel.isWritable()).thenReturn(true); + when(channel.bytesBeforeUnwritable()).thenReturn(Long.MAX_VALUE); + when(config.getWriteBufferHighWaterMark()).thenReturn(Integer.MAX_VALUE); + when(config.getMessageSizeEstimator()).thenReturn(DefaultMessageSizeEstimator.DEFAULT); + ChannelMetadata metadata = new ChannelMetadata(false, 16); + when(channel.metadata()).thenReturn(metadata); + when(channel.unsafe()).thenReturn(unsafe); + handler.handlerAdded(ctx); + } + + @After + public void teardown() { + // Close and release any buffered frames. + encoder.close(); + } + + @Test + public void multipleWritesToActiveStream() { + encoder.writeSettingsAck(ctx, newPromise()); + encoderWriteHeaders(3, newPromise()); + assertEquals(0, encoder.numBufferedStreams()); + ByteBuf data = data(); + final int expectedBytes = data.readableBytes() * 3; + encoder.writeData(ctx, 3, data, 0, false, newPromise()); + encoder.writeData(ctx, 3, data(), 0, false, newPromise()); + encoder.writeData(ctx, 3, data(), 0, false, newPromise()); + encoderWriteHeaders(3, newPromise()); + + writeVerifyWriteHeaders(times(2), 3); + // Contiguous data writes are coalesced + ArgumentCaptor bufCaptor = ArgumentCaptor.forClass(ByteBuf.class); + verify(writer, times(1)).writeData( + eq(ctx), eq(3), bufCaptor.capture(), eq(0), eq(false), any(ChannelPromise.class)); + assertEquals(expectedBytes, bufCaptor.getValue().readableBytes()); + } + + @Test + public void ensureCanCreateNextStreamWhenStreamCloses() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(1); + + encoderWriteHeaders(3, newPromise()); + assertEquals(0, encoder.numBufferedStreams()); + + // This one gets buffered. + encoderWriteHeaders(5, newPromise()); + assertEquals(1, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + + // Now prevent us from creating another stream. + setMaxConcurrentStreams(0); + + // Close the previous stream. + connection.stream(3).close(); + + // Ensure that no streams are currently active and that only the HEADERS from the first + // stream were written. + writeVerifyWriteHeaders(times(1), 3); + writeVerifyWriteHeaders(never(), 5); + assertEquals(0, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + } + + @Test + public void alternatingWritesToActiveAndBufferedStreams() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(1); + + encoderWriteHeaders(3, newPromise()); + assertEquals(0, encoder.numBufferedStreams()); + + encoderWriteHeaders(5, newPromise()); + assertEquals(1, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + + encoder.writeData(ctx, 3, EMPTY_BUFFER, 0, false, newPromise()); + writeVerifyWriteHeaders(times(1), 3); + encoder.writeData(ctx, 5, EMPTY_BUFFER, 0, false, newPromise()); + verify(writer, never()) + .writeData(eq(ctx), eq(5), any(ByteBuf.class), eq(0), eq(false), eq(newPromise())); + } + + @Test + public void bufferingNewStreamFailsAfterGoAwayReceived() throws Http2Exception { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(0); + connection.goAwayReceived(1, 8, EMPTY_BUFFER); + + ChannelPromise promise = newPromise(); + encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); + assertTrue(promise.isDone()); + assertFalse(promise.isSuccess()); + } + + @Test + public void receivingGoAwayFailsBufferedStreams() throws Http2Exception { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(5); + + int streamId = 3; + List futures = new ArrayList(); + for (int i = 0; i < 9; i++) { + futures.add(encoderWriteHeaders(streamId, newPromise())); + streamId += 2; + } + assertEquals(4, encoder.numBufferedStreams()); + + connection.goAwayReceived(11, 8, EMPTY_BUFFER); + + assertEquals(5, connection.numActiveStreams()); + int failCount = 0; + for (ChannelFuture f : futures) { + if (f.cause() != null) { + failCount++; + } + } + assertEquals(9, failCount); + assertEquals(0, encoder.numBufferedStreams()); + } + + @Test + public void sendingGoAwayShouldNotFailStreams() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(1); + + when(writer.writeHeaders( + any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), anyInt(), + anyBoolean(), any(ChannelPromise.class))) + .thenAnswer(successAnswer()); + when(writer.writeHeaders( + any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), anyInt(), + anyShort(), anyBoolean(), anyInt(), anyBoolean(), any(ChannelPromise.class))) + .thenAnswer(successAnswer()); + + ChannelFuture f1 = encoderWriteHeaders(3, newPromise()); + assertEquals(0, encoder.numBufferedStreams()); + ChannelFuture f2 = encoderWriteHeaders(5, newPromise()); + assertEquals(1, encoder.numBufferedStreams()); + ChannelFuture f3 = encoderWriteHeaders(7, newPromise()); + assertEquals(2, encoder.numBufferedStreams()); + + ByteBuf empty = Unpooled.buffer(0); + encoder.writeGoAway(ctx, 3, CANCEL.code(), empty, newPromise()); + + assertEquals(1, connection.numActiveStreams()); + assertEquals(2, encoder.numBufferedStreams()); + assertFalse(f1.isDone()); + assertFalse(f2.isDone()); + assertFalse(f3.isDone()); + } + + @Test + public void endStreamDoesNotFailBufferedStream() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(0); + + encoderWriteHeaders(3, newPromise()); + assertEquals(1, encoder.numBufferedStreams()); + + encoder.writeData(ctx, 3, EMPTY_BUFFER, 0, true, newPromise()); + + assertEquals(0, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + + // Simulate that we received a SETTINGS frame which + // increased MAX_CONCURRENT_STREAMS to 1. + setMaxConcurrentStreams(1); + encoder.writeSettingsAck(ctx, newPromise()); + + assertEquals(1, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); + assertEquals(HALF_CLOSED_LOCAL, connection.stream(3).state()); + } + + @Test + public void rstStreamClosesBufferedStream() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(0); + + encoderWriteHeaders(3, newPromise()); + assertEquals(1, encoder.numBufferedStreams()); + + ChannelPromise rstStreamPromise = newPromise(); + encoder.writeRstStream(ctx, 3, CANCEL.code(), rstStreamPromise); + assertTrue(rstStreamPromise.isSuccess()); + assertEquals(0, encoder.numBufferedStreams()); + } + + @Test + public void bufferUntilActiveStreamsAreReset() throws Exception { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(1); + + encoderWriteHeaders(3, newPromise()); + assertEquals(0, encoder.numBufferedStreams()); + encoderWriteHeaders(5, newPromise()); + assertEquals(1, encoder.numBufferedStreams()); + encoderWriteHeaders(7, newPromise()); + assertEquals(2, encoder.numBufferedStreams()); + + writeVerifyWriteHeaders(times(1), 3); + writeVerifyWriteHeaders(never(), 5); + writeVerifyWriteHeaders(never(), 7); + + encoder.writeRstStream(ctx, 3, CANCEL.code(), newPromise()); + connection.remote().flowController().writePendingBytes(); + writeVerifyWriteHeaders(times(1), 5); + writeVerifyWriteHeaders(never(), 7); + assertEquals(1, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + + encoder.writeRstStream(ctx, 5, CANCEL.code(), newPromise()); + connection.remote().flowController().writePendingBytes(); + writeVerifyWriteHeaders(times(1), 7); + assertEquals(1, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); + + encoder.writeRstStream(ctx, 7, CANCEL.code(), newPromise()); + assertEquals(0, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); + } + + @Test + public void bufferUntilMaxStreamsIncreased() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(2); + + encoderWriteHeaders(3, newPromise()); + encoderWriteHeaders(5, newPromise()); + encoderWriteHeaders(7, newPromise()); + encoderWriteHeaders(9, newPromise()); + assertEquals(2, encoder.numBufferedStreams()); + + writeVerifyWriteHeaders(times(1), 3); + writeVerifyWriteHeaders(times(1), 5); + writeVerifyWriteHeaders(never(), 7); + writeVerifyWriteHeaders(never(), 9); + + // Simulate that we received a SETTINGS frame which + // increased MAX_CONCURRENT_STREAMS to 5. + setMaxConcurrentStreams(5); + encoder.writeSettingsAck(ctx, newPromise()); + + assertEquals(0, encoder.numBufferedStreams()); + writeVerifyWriteHeaders(times(1), 7); + writeVerifyWriteHeaders(times(1), 9); + + encoderWriteHeaders(11, newPromise()); + + writeVerifyWriteHeaders(times(1), 11); + + assertEquals(5, connection.local().numActiveStreams()); + } + + @Test + public void bufferUntilSettingsReceived() throws Http2Exception { + int initialLimit = SMALLEST_MAX_CONCURRENT_STREAMS; + int numStreams = initialLimit * 2; + for (int ix = 0, nextStreamId = 3; ix < numStreams; ++ix, nextStreamId += 2) { + encoderWriteHeaders(nextStreamId, newPromise()); + if (ix < initialLimit) { + writeVerifyWriteHeaders(times(1), nextStreamId); + } else { + writeVerifyWriteHeaders(never(), nextStreamId); + } + } + assertEquals(numStreams / 2, encoder.numBufferedStreams()); + + // Simulate that we received a SETTINGS frame. + setMaxConcurrentStreams(initialLimit * 2); + + assertEquals(0, encoder.numBufferedStreams()); + assertEquals(numStreams, connection.local().numActiveStreams()); + } + + @Test + public void bufferUntilSettingsReceivedWithNoMaxConcurrentStreamValue() throws Http2Exception { + int initialLimit = SMALLEST_MAX_CONCURRENT_STREAMS; + int numStreams = initialLimit * 2; + for (int ix = 0, nextStreamId = 3; ix < numStreams; ++ix, nextStreamId += 2) { + encoderWriteHeaders(nextStreamId, newPromise()); + if (ix < initialLimit) { + writeVerifyWriteHeaders(times(1), nextStreamId); + } else { + writeVerifyWriteHeaders(never(), nextStreamId); + } + } + assertEquals(numStreams / 2, encoder.numBufferedStreams()); + + // Simulate that we received an empty SETTINGS frame. + encoder.remoteSettings(new Http2Settings()); + + assertEquals(0, encoder.numBufferedStreams()); + assertEquals(numStreams, connection.local().numActiveStreams()); + } + + @Test + public void exhaustedStreamsDoNotBuffer() throws Http2Exception { + // Write the highest possible stream ID for the client. + // This will cause the next stream ID to be negative. + encoderWriteHeaders(Integer.MAX_VALUE, newPromise()); + + // Disallow any further streams. + setMaxConcurrentStreams(0); + + // Simulate numeric overflow for the next stream ID. + ChannelFuture f = encoderWriteHeaders(-1, newPromise()); + + // Verify that the write fails. + assertNotNull(f.cause()); + } + + @Test + public void closedBufferedStreamReleasesByteBuf() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(0); + ByteBuf data = mock(ByteBuf.class); + ChannelFuture f1 = encoderWriteHeaders(3, newPromise()); + assertEquals(1, encoder.numBufferedStreams()); + ChannelFuture f2 = encoder.writeData(ctx, 3, data, 0, false, newPromise()); + + ChannelPromise rstPromise = mock(ChannelPromise.class); + encoder.writeRstStream(ctx, 3, CANCEL.code(), rstPromise); + + assertEquals(0, encoder.numBufferedStreams()); + verify(rstPromise).setSuccess(); + assertTrue(f1.isSuccess()); + assertTrue(f2.isSuccess()); + verify(data).release(); + } + + @Test + public void closeShouldCancelAllBufferedStreams() throws Http2Exception { + encoder.writeSettingsAck(ctx, newPromise()); + connection.local().maxActiveStreams(0); + + ChannelFuture f1 = encoderWriteHeaders(3, newPromise()); + ChannelFuture f2 = encoderWriteHeaders(5, newPromise()); + ChannelFuture f3 = encoderWriteHeaders(7, newPromise()); + + encoder.close(); + assertNotNull(f1.cause()); + assertNotNull(f2.cause()); + assertNotNull(f3.cause()); + } + + @Test + public void headersAfterCloseShouldImmediatelyFail() { + encoder.writeSettingsAck(ctx, newPromise()); + encoder.close(); + + ChannelFuture f = encoderWriteHeaders(3, newPromise()); + assertNotNull(f.cause()); + } + + private void setMaxConcurrentStreams(int newValue) { + try { + encoder.remoteSettings(new Http2Settings().maxConcurrentStreams(newValue)); + // Flush the remote flow controller to write data + encoder.flowController().writePendingBytes(); + } catch (Http2Exception e) { + throw new RuntimeException(e); + } + } + + private ChannelFuture encoderWriteHeaders(int streamId, ChannelPromise promise) { + encoder.writeHeaders(ctx, streamId, new DefaultHttp2Headers(), 0, DEFAULT_PRIORITY_WEIGHT, + false, 0, false, promise); + try { + encoder.flowController().writePendingBytes(); + return promise; + } catch (Http2Exception e) { + throw new RuntimeException(e); + } + } + + private void writeVerifyWriteHeaders(VerificationMode mode, int streamId) { + verify(writer, mode).writeHeaders(eq(ctx), eq(streamId), any(Http2Headers.class), eq(0), + eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), + eq(false), any(ChannelPromise.class)); + } + + private Answer successAnswer() { + return new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { + for (Object a : invocation.getArguments()) { + ReferenceCountUtil.safeRelease(a); + } + + ChannelPromise future = newPromise(); + future.setSuccess(); + return future; + } + }; + } + + private ChannelPromise newPromise() { + return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + } + + private static ByteBuf data() { + ByteBuf buf = Unpooled.buffer(10); + for (int i = 0; i < buf.writableBytes(); i++) { + buf.writeByte(i); + } + return buf; + } +} From ab719bd48029949d9e3232e4669275641443d78b Mon Sep 17 00:00:00 2001 From: "Penn (Dapeng) Zhang" Date: Thu, 25 Mar 2021 17:35:12 -0700 Subject: [PATCH 3/8] fix StreamBufferingEncoder --- .../main/java/io/grpc/netty/StreamBufferingEncoder.java | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java b/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java index a37221b1284..ab0376cd199 100644 --- a/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java +++ b/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java @@ -60,6 +60,8 @@ *

This implementation makes the buffering mostly transparent and is expected to be used as a * drop-in decorator of {@link io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder}. */ +// This is a temporary copy of {@link io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoder} +// with a bug fix that is not available yet in the latest netty release. class StreamBufferingEncoder extends DecoratingHttp2ConnectionEncoder { /** @@ -153,11 +155,7 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2 if (closed) { return promise.setFailure(new Http2ChannelClosedException()); } - if (isExistingStream(streamId) || connection().goAwayReceived()) { - return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, - exclusive, padding, endOfStream, promise); - } - if (canCreateStream()) { + if (isExistingStream(streamId) || canCreateStream()) { return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, promise); } From 43324efae0463cdbada292c86c744a758366282d Mon Sep 17 00:00:00 2001 From: "Penn (Dapeng) Zhang" Date: Mon, 5 Apr 2021 12:18:42 -0700 Subject: [PATCH 4/8] Revert "fix StreamBufferingEncoder" This reverts commit 95060b2c3125c9c65897883cae1244996a6a7845. --- .../main/java/io/grpc/netty/StreamBufferingEncoder.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java b/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java index ab0376cd199..a37221b1284 100644 --- a/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java +++ b/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java @@ -60,8 +60,6 @@ *

This implementation makes the buffering mostly transparent and is expected to be used as a * drop-in decorator of {@link io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder}. */ -// This is a temporary copy of {@link io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoder} -// with a bug fix that is not available yet in the latest netty release. class StreamBufferingEncoder extends DecoratingHttp2ConnectionEncoder { /** @@ -155,7 +153,11 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2 if (closed) { return promise.setFailure(new Http2ChannelClosedException()); } - if (isExistingStream(streamId) || canCreateStream()) { + if (isExistingStream(streamId) || connection().goAwayReceived()) { + return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, + exclusive, padding, endOfStream, promise); + } + if (canCreateStream()) { return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, promise); } From 723ef2534f9b1fdbd3e63cd585b23133bd1bef3b Mon Sep 17 00:00:00 2001 From: "Penn (Dapeng) Zhang" Date: Tue, 6 Apr 2021 12:12:34 -0700 Subject: [PATCH 5/8] fix StreamBufferingEncoder approach 2 --- .../io/grpc/netty/StreamBufferingEncoder.java | 8 +- .../io/grpc/netty/NettyClientHandlerTest.java | 74 ++----------------- .../netty/StreamBufferingEncoderTest.java | 1 + 3 files changed, 15 insertions(+), 68 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java b/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java index a37221b1284..f5a2dcc122b 100644 --- a/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java +++ b/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java @@ -41,6 +41,8 @@ import java.util.Queue; import java.util.TreeMap; +// This is a temporary copy of {@link io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoder} +// with a bug fix that is not available yet in the latest netty release. /** * Implementation of a {@link Http2ConnectionEncoder} that dispatches all method call to another * {@link Http2ConnectionEncoder}, until {@code SETTINGS_MAX_CONCURRENT_STREAMS} is reached. @@ -153,10 +155,14 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2 if (closed) { return promise.setFailure(new Http2ChannelClosedException()); } - if (isExistingStream(streamId) || connection().goAwayReceived()) { + if (isExistingStream(streamId)) { return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, promise); } + if (connection().goAwayReceived()) { + promise.setFailure(new Http2Exception(Http2Error.NO_ERROR, "GOAWAY received")); + return promise; + } if (canCreateStream()) { return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, promise); diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index 223f4781538..99dfe905b5a 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -384,23 +384,20 @@ public void receivedAbruptGoAwayShouldFailRacingQueuedStreamid() throws Exceptio } @Test - public void receivedGoAway_shouldNotFailBufferedStreamWithStreamIdLessThanLastId() + public void receivedGoAway_shouldFailBufferedStreams() throws Exception { - ClientStreamListener streamListener1 = mock(ClientStreamListener.class); NettyClientStream.TransportState streamTransportState1 = new TransportStateImpl( handler(), channel().eventLoop(), DEFAULT_MAX_MESSAGE_SIZE, transportTracer); - streamTransportState1.setListener(streamListener1); - ClientStreamListener streamListener2 = mock(ClientStreamListener.class); + streamTransportState1.setListener(mock(ClientStreamListener.class)); NettyClientStream.TransportState streamTransportState2 = new TransportStateImpl( handler(), channel().eventLoop(), DEFAULT_MAX_MESSAGE_SIZE, transportTracer); - streamTransportState2.setListener(streamListener2); - // MAX_CONCURRENT_STREAMS=1 + streamTransportState2.setListener(mock(ClientStreamListener.class)); receiveMaxConcurrentStreams(1); ChannelFuture future1 = writeQueue().enqueue( newCreateStreamCommand(grpcHeaders, streamTransportState1), true); @@ -409,67 +406,10 @@ public void receivedGoAway_shouldNotFailBufferedStreamWithStreamIdLessThanLastId // GOAWAY channelRead(goAwayFrame(Integer.MAX_VALUE)); - assertTrue(future1.isSuccess()); - verify(streamListener1).onReady(); - channel().runPendingTasks(); - if (future2.cause() != null ) { - throw new AssertionError(future2.cause()); - } - assertFalse(future2.isDone()); - verify(streamListener2, never()).onReady(); - - // Let the first stream complete, then the pending stream will be activated. - Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK) - .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC); - channelRead(headersFrame(streamId, headers)); - channelRead(grpcDataFrame(streamId, false, contentAsArray())); - channelRead(trailersFrame( - streamId, - new DefaultHttp2Headers().set(AsciiString.of("grpc-status"), AsciiString.of("0")))); - streamTransportState1.requestMessagesFromDeframerForTesting(1); - verify(streamListener1).closed(eq(Status.OK), any(RpcProgress.class), - any(Metadata.class)); - channel().runPendingTasks(); - if (future2.cause() != null) { - throw new AssertionError(future2.cause()); - } - assertTrue(future2.isSuccess()); - verify(streamListener2).onReady(); - } - - @Test - public void receivedGoAway_shouldRefuseBufferedStreamWithStreamIdGreaterThanLastId() - throws Exception { - ClientStreamListener streamListener1 = mock(ClientStreamListener.class); - NettyClientStream.TransportState streamTransportState1 = new TransportStateImpl( - handler(), - channel().eventLoop(), - DEFAULT_MAX_MESSAGE_SIZE, - transportTracer); - streamTransportState1.setListener(streamListener1); - ClientStreamListener streamListener2 = mock(ClientStreamListener.class); - NettyClientStream.TransportState streamTransportState2 = new TransportStateImpl( - handler(), - channel().eventLoop(), - DEFAULT_MAX_MESSAGE_SIZE, - transportTracer); - streamTransportState2.setListener(streamListener2); - // MAX_CONCURRENT_STREAMS=1 - receiveMaxConcurrentStreams(1); - ChannelFuture future1 = writeQueue().enqueue( - newCreateStreamCommand(grpcHeaders, streamTransportState1), true); - ChannelFuture future2 = writeQueue().enqueue( - newCreateStreamCommand(grpcHeaders, streamTransportState2), true); - - // GOAWAY - channelRead(goAwayFrame(streamId)); - assertTrue(future1.isSuccess()); - verify(streamListener1).onReady(); - assertThat(future2.cause()).isNotNull(); - assertThat(Status.fromThrowable(future2.cause()).getCode()).isEqualTo(Status.Code.UNAVAILABLE); - assertThat(Status.fromThrowable(future2.cause()).getDescription()).isEqualTo( - "Abrupt GOAWAY closed unsent stream. HTTP/2 error code: NO_ERROR"); - verify(streamListener2).closed(any(Status.class), eq(RpcProgress.REFUSED), any(Metadata.class)); + assertTrue(future1.isDone()); + assertThat(future1.cause().getMessage()).contains("GOAWAY received"); + assertTrue(future2.isDone()); + assertThat(future2.cause().getMessage()).contains("GOAWAY received"); } @Test diff --git a/netty/src/test/java/io/grpc/netty/StreamBufferingEncoderTest.java b/netty/src/test/java/io/grpc/netty/StreamBufferingEncoderTest.java index d026ae7ca2f..318d35f604b 100644 --- a/netty/src/test/java/io/grpc/netty/StreamBufferingEncoderTest.java +++ b/netty/src/test/java/io/grpc/netty/StreamBufferingEncoderTest.java @@ -81,6 +81,7 @@ import org.mockito.stubbing.Answer; import org.mockito.verification.VerificationMode; +// This is a temporary copy of io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoderTest. /** * Tests for {@link StreamBufferingEncoder}. */ From e16fb22d83a5382c403fe475e63985b8a86ad57c Mon Sep 17 00:00:00 2001 From: "Penn (Dapeng) Zhang" Date: Fri, 16 Apr 2021 10:58:11 -0700 Subject: [PATCH 6/8] fix locally in NettyClientHandler --- .../io/grpc/netty/NettyClientHandler.java | 37 +- .../io/grpc/netty/StreamBufferingEncoder.java | 381 ------------ .../io/grpc/netty/NettyClientHandlerTest.java | 9 +- .../netty/StreamBufferingEncoderTest.java | 573 ------------------ 4 files changed, 29 insertions(+), 971 deletions(-) delete mode 100644 netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java delete mode 100644 netty/src/test/java/io/grpc/netty/StreamBufferingEncoderTest.java diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index ecd3bb086b6..bdcac924252 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -74,6 +74,7 @@ import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.codec.http2.Http2Stream; import io.netty.handler.codec.http2.Http2StreamVisitor; +import io.netty.handler.codec.http2.StreamBufferingEncoder; import io.netty.handler.codec.http2.WeightedFairQueueByteDistributor; import io.netty.handler.logging.LogLevel; import io.perfmark.PerfMark; @@ -568,20 +569,22 @@ private void createStream(CreateStreamCommand command, ChannelPromise promise) } return; } - if (connection().goAwayReceived() - && streamId > connection().local().lastStreamKnownByPeer()) { - // This should only be reachable during onGoAwayReceived, as otherwise - // getShutdownThrowable() != null - command.stream().setNonExistent(); - Status s = abruptGoAwayStatus; - if (s == null) { - // Should be impossible, but handle psuedo-gracefully - s = Status.INTERNAL.withDescription( - "Failed due to abrupt GOAWAY, but can't find GOAWAY details"); + if (connection().goAwayReceived()) { + if (streamId > connection().local().lastStreamKnownByPeer() + || connection().local().numActiveStreams() == connection().local().maxActiveStreams()) { + // This should only be reachable during onGoAwayReceived, as otherwise + // getShutdownThrowable() != null + command.stream().setNonExistent(); + Status s = abruptGoAwayStatus; + if (s == null) { + // Should be impossible, but handle psuedo-gracefully + s = Status.INTERNAL.withDescription( + "Failed due to abrupt GOAWAY, but can't find GOAWAY details"); + } + command.stream().transportReportStatus(s, RpcProgress.REFUSED, true, new Metadata()); + promise.setFailure(s.asRuntimeException()); + return; } - command.stream().transportReportStatus(s, RpcProgress.REFUSED, true, new Metadata()); - promise.setFailure(s.asRuntimeException()); - return; } NettyClientStream.TransportState stream = command.stream(); @@ -608,6 +611,14 @@ private void createStreamTraced( // Create an intermediate promise so that we can intercept the failure reported back to the // application. ChannelPromise tempPromise = ctx().newPromise(); + if (connection().goAwayReceived() + && connection().local().numActiveStreams() == connection().local().maxActiveStreams()) { + Status status = Status.UNAVAILABLE.withCause( + new Http2Exception(Http2Error.REFUSED_STREAM, "GOAWAY received")); + stream.transportReportStatus(status, RpcProgress.REFUSED, true, new Metadata()); + promise.setFailure(status.asRuntimeException()); + return; + } encoder().writeHeaders(ctx(), streamId, headers, 0, isGet, tempPromise) .addListener(new ChannelFutureListener() { @Override diff --git a/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java b/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java deleted file mode 100644 index f5a2dcc122b..00000000000 --- a/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java +++ /dev/null @@ -1,381 +0,0 @@ -/* - * Copyright 2021 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.netty; - -import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS; -import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; -import static io.netty.handler.codec.http2.Http2Exception.connectionError; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoder; -import io.netty.handler.codec.http2.Http2CodecUtil; -import io.netty.handler.codec.http2.Http2ConnectionAdapter; -import io.netty.handler.codec.http2.Http2ConnectionEncoder; -import io.netty.handler.codec.http2.Http2Error; -import io.netty.handler.codec.http2.Http2Exception; -import io.netty.handler.codec.http2.Http2Headers; -import io.netty.handler.codec.http2.Http2Settings; -import io.netty.handler.codec.http2.Http2Stream; -import io.netty.util.ReferenceCountUtil; -import java.util.ArrayDeque; -import java.util.Iterator; -import java.util.Map; -import java.util.Queue; -import java.util.TreeMap; - -// This is a temporary copy of {@link io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoder} -// with a bug fix that is not available yet in the latest netty release. -/** - * Implementation of a {@link Http2ConnectionEncoder} that dispatches all method call to another - * {@link Http2ConnectionEncoder}, until {@code SETTINGS_MAX_CONCURRENT_STREAMS} is reached. - * - *

When this limit is hit, instead of rejecting any new streams this implementation buffers newly - * created streams and their corresponding frames. Once an active stream gets closed or the maximum - * number of concurrent streams is increased, this encoder will automatically try to empty its - * buffer and create as many new streams as possible. - * - *

If a {@code GOAWAY} frame is received from the remote endpoint, all buffered writes for - * streams with an ID less than the specified {@code lastStreamId} will immediately fail with a - * {@link io.netty.handler.codec.http2.StreamBufferingEncoder.Http2GoAwayException}. - * - *

If the channel/encoder gets closed, all new and buffered writes will immediately fail with a - * {@link io.netty.handler.codec.http2.StreamBufferingEncoder.Http2ChannelClosedException}. - * - *

This implementation makes the buffering mostly transparent and is expected to be used as a - * drop-in decorator of {@link io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder}. - */ -class StreamBufferingEncoder extends DecoratingHttp2ConnectionEncoder { - - /** - * Thrown if buffered streams are terminated due to this encoder being closed. - */ - public static final class Http2ChannelClosedException extends Http2Exception { - private static final long serialVersionUID = 4768543442094476971L; - - public Http2ChannelClosedException() { - super(Http2Error.REFUSED_STREAM, "Connection closed"); - } - } - - /** - * Thrown by {@link StreamBufferingEncoder} if buffered streams are terminated due to - * receipt of a {@code GOAWAY}. - */ - public static final class Http2GoAwayException extends Http2Exception { - private static final long serialVersionUID = 1326785622777291198L; - private final int lastStreamId; - private final long errorCode; - private final byte[] debugData; - - public Http2GoAwayException(int lastStreamId, long errorCode, byte[] debugData) { - super(Http2Error.STREAM_CLOSED); - this.lastStreamId = lastStreamId; - this.errorCode = errorCode; - this.debugData = debugData; - } - - public int lastStreamId() { - return lastStreamId; - } - - public long errorCode() { - return errorCode; - } - - public byte[] debugData() { - return debugData; - } - } - - /** - * Buffer for any streams and corresponding frames that could not be created due to the maximum - * concurrent stream limit being hit. - */ - private final TreeMap pendingStreams = new TreeMap<>(); - private int maxConcurrentStreams; - private boolean closed; - - public StreamBufferingEncoder(Http2ConnectionEncoder delegate) { - this(delegate, SMALLEST_MAX_CONCURRENT_STREAMS); - } - - public StreamBufferingEncoder(Http2ConnectionEncoder delegate, int initialMaxConcurrentStreams) { - super(delegate); - this.maxConcurrentStreams = initialMaxConcurrentStreams; - connection().addListener(new Http2ConnectionAdapter() { - - @Override - public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) { - cancelGoAwayStreams(lastStreamId, errorCode, debugData); - } - - @Override - public void onStreamClosed(Http2Stream stream) { - tryCreatePendingStreams(); - } - }); - } - - /** - * Indicates the number of streams that are currently buffered, awaiting creation. - */ - public int numBufferedStreams() { - return pendingStreams.size(); - } - - @Override - public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, - int padding, boolean endStream, ChannelPromise promise) { - return writeHeaders(ctx, streamId, headers, 0, Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT, - false, padding, endStream, promise); - } - - @Override - public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, - int streamDependency, short weight, boolean exclusive, - int padding, boolean endOfStream, ChannelPromise promise) { - if (closed) { - return promise.setFailure(new Http2ChannelClosedException()); - } - if (isExistingStream(streamId)) { - return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, - exclusive, padding, endOfStream, promise); - } - if (connection().goAwayReceived()) { - promise.setFailure(new Http2Exception(Http2Error.NO_ERROR, "GOAWAY received")); - return promise; - } - if (canCreateStream()) { - return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, - exclusive, padding, endOfStream, promise); - } - PendingStream pendingStream = pendingStreams.get(streamId); - if (pendingStream == null) { - pendingStream = new PendingStream(ctx, streamId); - pendingStreams.put(streamId, pendingStream); - } - pendingStream.frames.add(new HeadersFrame(headers, streamDependency, weight, exclusive, - padding, endOfStream, promise)); - return promise; - } - - @Override - public ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode, - ChannelPromise promise) { - if (isExistingStream(streamId)) { - return super.writeRstStream(ctx, streamId, errorCode, promise); - } - // Since the delegate doesn't know about any buffered streams we have to handle cancellation - // of the promises and releasing of the ByteBufs here. - PendingStream stream = pendingStreams.remove(streamId); - if (stream != null) { - // Sending a RST_STREAM to a buffered stream will succeed the promise of all frames - // associated with the stream, as sending a RST_STREAM means that someone "doesn't care" - // about the stream anymore and thus there is not point in failing the promises and invoking - // error handling routines. - stream.close(null); - promise.setSuccess(); - } else { - promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId)); - } - return promise; - } - - @Override - public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data, - int padding, boolean endOfStream, ChannelPromise promise) { - if (isExistingStream(streamId)) { - return super.writeData(ctx, streamId, data, padding, endOfStream, promise); - } - PendingStream pendingStream = pendingStreams.get(streamId); - if (pendingStream != null) { - pendingStream.frames.add(new DataFrame(data, padding, endOfStream, promise)); - } else { - ReferenceCountUtil.safeRelease(data); - promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId)); - } - return promise; - } - - @Override - public void remoteSettings(Http2Settings settings) throws Http2Exception { - // Need to let the delegate decoder handle the settings first, so that it sees the - // new setting before we attempt to create any new streams. - super.remoteSettings(settings); - - // Get the updated value for SETTINGS_MAX_CONCURRENT_STREAMS. - maxConcurrentStreams = connection().local().maxActiveStreams(); - - // Try to create new streams up to the new threshold. - tryCreatePendingStreams(); - } - - @Override - public void close() { - try { - if (!closed) { - closed = true; - - // Fail all buffered streams. - Http2ChannelClosedException e = new Http2ChannelClosedException(); - while (!pendingStreams.isEmpty()) { - PendingStream stream = pendingStreams.pollFirstEntry().getValue(); - stream.close(e); - } - } - } finally { - super.close(); - } - } - - private void tryCreatePendingStreams() { - while (!pendingStreams.isEmpty() && canCreateStream()) { - Map.Entry entry = pendingStreams.pollFirstEntry(); - PendingStream pendingStream = entry.getValue(); - try { - pendingStream.sendFrames(); - } catch (Throwable t) { - pendingStream.close(t); - } - } - } - - private void cancelGoAwayStreams(int lastStreamId, long errorCode, ByteBuf debugData) { - Iterator iter = pendingStreams.values().iterator(); - Exception e = - new Http2GoAwayException(lastStreamId, errorCode, ByteBufUtil.getBytes(debugData)); - while (iter.hasNext()) { - PendingStream stream = iter.next(); - if (stream.streamId > lastStreamId) { - iter.remove(); - stream.close(e); - } - } - } - - /** - * Determines whether or not we're allowed to create a new stream right now. - */ - private boolean canCreateStream() { - return connection().local().numActiveStreams() < maxConcurrentStreams; - } - - private boolean isExistingStream(int streamId) { - return streamId <= connection().local().lastStreamCreated(); - } - - private static final class PendingStream { - final ChannelHandlerContext ctx; - final int streamId; - final Queue frames = new ArrayDeque<>(2); - - PendingStream(ChannelHandlerContext ctx, int streamId) { - this.ctx = ctx; - this.streamId = streamId; - } - - void sendFrames() { - for (Frame frame : frames) { - frame.send(ctx, streamId); - } - } - - void close(Throwable t) { - for (Frame frame : frames) { - frame.release(t); - } - } - } - - private abstract static class Frame { - final ChannelPromise promise; - - Frame(ChannelPromise promise) { - this.promise = promise; - } - - /** - * Release any resources (features, buffers, ...) associated with the frame. - */ - void release(Throwable t) { - if (t == null) { - promise.setSuccess(); - } else { - promise.setFailure(t); - } - } - - abstract void send(ChannelHandlerContext ctx, int streamId); - } - - private final class HeadersFrame extends Frame { - final Http2Headers headers; - final int streamDependency; - final short weight; - final boolean exclusive; - final int padding; - final boolean endOfStream; - - HeadersFrame(Http2Headers headers, int streamDependency, short weight, boolean exclusive, - int padding, boolean endOfStream, ChannelPromise promise) { - super(promise); - this.headers = headers; - this.streamDependency = streamDependency; - this.weight = weight; - this.exclusive = exclusive; - this.padding = padding; - this.endOfStream = endOfStream; - } - - @Override - @SuppressWarnings("CheckReturnValue") - void send(ChannelHandlerContext ctx, int streamId) { - writeHeaders( - ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, - promise); - } - } - - private final class DataFrame extends Frame { - final ByteBuf data; - final int padding; - final boolean endOfStream; - - DataFrame(ByteBuf data, int padding, boolean endOfStream, ChannelPromise promise) { - super(promise); - this.data = data; - this.padding = padding; - this.endOfStream = endOfStream; - } - - @Override - void release(Throwable t) { - super.release(t); - ReferenceCountUtil.safeRelease(data); - } - - @Override - @SuppressWarnings("CheckReturnValue") - void send(ChannelHandlerContext ctx, int streamId) { - writeData(ctx, streamId, data, padding, endOfStream, promise); - } - } -} diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index 99dfe905b5a..b901ceeb642 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -384,7 +384,7 @@ public void receivedAbruptGoAwayShouldFailRacingQueuedStreamid() throws Exceptio } @Test - public void receivedGoAway_shouldFailBufferedStreams() + public void receivedGoAway_shouldFailBufferedStreamsExceedingMaxConcurrentStreams() throws Exception { NettyClientStream.TransportState streamTransportState1 = new TransportStateImpl( handler(), @@ -406,10 +406,11 @@ public void receivedGoAway_shouldFailBufferedStreams() // GOAWAY channelRead(goAwayFrame(Integer.MAX_VALUE)); - assertTrue(future1.isDone()); - assertThat(future1.cause().getMessage()).contains("GOAWAY received"); + assertTrue(future1.isSuccess()); assertTrue(future2.isDone()); - assertThat(future2.cause().getMessage()).contains("GOAWAY received"); + assertThat(Status.fromThrowable(future2.cause()).getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(future2.cause().getMessage()).contains( + "Abrupt GOAWAY closed unsent stream. HTTP/2 error code: NO_ERROR"); } @Test diff --git a/netty/src/test/java/io/grpc/netty/StreamBufferingEncoderTest.java b/netty/src/test/java/io/grpc/netty/StreamBufferingEncoderTest.java deleted file mode 100644 index 318d35f604b..00000000000 --- a/netty/src/test/java/io/grpc/netty/StreamBufferingEncoderTest.java +++ /dev/null @@ -1,573 +0,0 @@ -/* - * Copyright 2021 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.netty; - -import static io.netty.buffer.Unpooled.EMPTY_BUFFER; -import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE; -import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; -import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS; -import static io.netty.handler.codec.http2.Http2Error.CANCEL; -import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.anyBoolean; -import static org.mockito.Mockito.anyInt; -import static org.mockito.Mockito.anyLong; -import static org.mockito.Mockito.anyShort; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.buffer.UnpooledByteBufAllocator; -import io.netty.channel.Channel; -import io.netty.channel.ChannelConfig; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelMetadata; -import io.netty.channel.ChannelPromise; -import io.netty.channel.DefaultChannelPromise; -import io.netty.channel.DefaultMessageSizeEstimator; -import io.netty.handler.codec.http2.DefaultHttp2Connection; -import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; -import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; -import io.netty.handler.codec.http2.DefaultHttp2Headers; -import io.netty.handler.codec.http2.DefaultHttp2LocalFlowController; -import io.netty.handler.codec.http2.DefaultHttp2RemoteFlowController; -import io.netty.handler.codec.http2.Http2Connection; -import io.netty.handler.codec.http2.Http2ConnectionHandler; -import io.netty.handler.codec.http2.Http2ConnectionHandlerBuilder; -import io.netty.handler.codec.http2.Http2Exception; -import io.netty.handler.codec.http2.Http2FrameListener; -import io.netty.handler.codec.http2.Http2FrameReader; -import io.netty.handler.codec.http2.Http2FrameSizePolicy; -import io.netty.handler.codec.http2.Http2FrameWriter; -import io.netty.handler.codec.http2.Http2Headers; -import io.netty.handler.codec.http2.Http2Settings; -import io.netty.util.ReferenceCountUtil; -import io.netty.util.concurrent.EventExecutor; -import io.netty.util.concurrent.ImmediateEventExecutor; -import java.util.ArrayList; -import java.util.List; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; -import org.mockito.verification.VerificationMode; - -// This is a temporary copy of io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoderTest. -/** - * Tests for {@link StreamBufferingEncoder}. - */ -@SuppressWarnings("CheckReturnValue") // netty futures -public class StreamBufferingEncoderTest { - - private StreamBufferingEncoder encoder; - - private Http2Connection connection; - - @Mock - private Http2FrameWriter writer; - - @Mock - private ChannelHandlerContext ctx; - - @Mock - private Channel channel; - - @Mock - private Channel.Unsafe unsafe; - - @Mock - private ChannelConfig config; - - @Mock - private EventExecutor executor; - - /** - * Init fields and do mocking. - */ - @Before - public void setup() throws Exception { - MockitoAnnotations.initMocks(this); - - Http2FrameWriter.Configuration configuration = mock(Http2FrameWriter.Configuration.class); - Http2FrameSizePolicy frameSizePolicy = mock(Http2FrameSizePolicy.class); - when(writer.configuration()).thenReturn(configuration); - when(configuration.frameSizePolicy()).thenReturn(frameSizePolicy); - when(frameSizePolicy.maxFrameSize()).thenReturn(DEFAULT_MAX_FRAME_SIZE); - when(writer.writeData( - any(ChannelHandlerContext.class), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean(), - any(ChannelPromise.class))) - .thenAnswer(successAnswer()); - when(writer.writeRstStream(eq(ctx), anyInt(), anyLong(), any(ChannelPromise.class))).thenAnswer( - successAnswer()); - when(writer.writeGoAway( - any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class), - any(ChannelPromise.class))) - .thenAnswer(successAnswer()); - - connection = new DefaultHttp2Connection(false); - connection.remote().flowController(new DefaultHttp2RemoteFlowController(connection)); - connection.local() - .flowController(new DefaultHttp2LocalFlowController(connection).frameWriter(writer)); - - DefaultHttp2ConnectionEncoder defaultEncoder = - new DefaultHttp2ConnectionEncoder(connection, writer); - encoder = new StreamBufferingEncoder(defaultEncoder); - DefaultHttp2ConnectionDecoder decoder = - new DefaultHttp2ConnectionDecoder(connection, encoder, mock(Http2FrameReader.class)); - Http2ConnectionHandler handler = new Http2ConnectionHandlerBuilder() - .frameListener(mock(Http2FrameListener.class)) - .codec(decoder, encoder).build(); - - // Set LifeCycleManager on encoder and decoder - when(ctx.channel()).thenReturn(channel); - when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); - when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); - when(executor.inEventLoop()).thenReturn(true); - doAnswer(new Answer() { - @Override - public ChannelPromise answer(InvocationOnMock invocation) throws Throwable { - return newPromise(); - } - }).when(ctx).newPromise(); - when(ctx.executor()).thenReturn(executor); - when(channel.isActive()).thenReturn(false); - when(channel.config()).thenReturn(config); - when(channel.isWritable()).thenReturn(true); - when(channel.bytesBeforeUnwritable()).thenReturn(Long.MAX_VALUE); - when(config.getWriteBufferHighWaterMark()).thenReturn(Integer.MAX_VALUE); - when(config.getMessageSizeEstimator()).thenReturn(DefaultMessageSizeEstimator.DEFAULT); - ChannelMetadata metadata = new ChannelMetadata(false, 16); - when(channel.metadata()).thenReturn(metadata); - when(channel.unsafe()).thenReturn(unsafe); - handler.handlerAdded(ctx); - } - - @After - public void teardown() { - // Close and release any buffered frames. - encoder.close(); - } - - @Test - public void multipleWritesToActiveStream() { - encoder.writeSettingsAck(ctx, newPromise()); - encoderWriteHeaders(3, newPromise()); - assertEquals(0, encoder.numBufferedStreams()); - ByteBuf data = data(); - final int expectedBytes = data.readableBytes() * 3; - encoder.writeData(ctx, 3, data, 0, false, newPromise()); - encoder.writeData(ctx, 3, data(), 0, false, newPromise()); - encoder.writeData(ctx, 3, data(), 0, false, newPromise()); - encoderWriteHeaders(3, newPromise()); - - writeVerifyWriteHeaders(times(2), 3); - // Contiguous data writes are coalesced - ArgumentCaptor bufCaptor = ArgumentCaptor.forClass(ByteBuf.class); - verify(writer, times(1)).writeData( - eq(ctx), eq(3), bufCaptor.capture(), eq(0), eq(false), any(ChannelPromise.class)); - assertEquals(expectedBytes, bufCaptor.getValue().readableBytes()); - } - - @Test - public void ensureCanCreateNextStreamWhenStreamCloses() { - encoder.writeSettingsAck(ctx, newPromise()); - setMaxConcurrentStreams(1); - - encoderWriteHeaders(3, newPromise()); - assertEquals(0, encoder.numBufferedStreams()); - - // This one gets buffered. - encoderWriteHeaders(5, newPromise()); - assertEquals(1, connection.numActiveStreams()); - assertEquals(1, encoder.numBufferedStreams()); - - // Now prevent us from creating another stream. - setMaxConcurrentStreams(0); - - // Close the previous stream. - connection.stream(3).close(); - - // Ensure that no streams are currently active and that only the HEADERS from the first - // stream were written. - writeVerifyWriteHeaders(times(1), 3); - writeVerifyWriteHeaders(never(), 5); - assertEquals(0, connection.numActiveStreams()); - assertEquals(1, encoder.numBufferedStreams()); - } - - @Test - public void alternatingWritesToActiveAndBufferedStreams() { - encoder.writeSettingsAck(ctx, newPromise()); - setMaxConcurrentStreams(1); - - encoderWriteHeaders(3, newPromise()); - assertEquals(0, encoder.numBufferedStreams()); - - encoderWriteHeaders(5, newPromise()); - assertEquals(1, connection.numActiveStreams()); - assertEquals(1, encoder.numBufferedStreams()); - - encoder.writeData(ctx, 3, EMPTY_BUFFER, 0, false, newPromise()); - writeVerifyWriteHeaders(times(1), 3); - encoder.writeData(ctx, 5, EMPTY_BUFFER, 0, false, newPromise()); - verify(writer, never()) - .writeData(eq(ctx), eq(5), any(ByteBuf.class), eq(0), eq(false), eq(newPromise())); - } - - @Test - public void bufferingNewStreamFailsAfterGoAwayReceived() throws Http2Exception { - encoder.writeSettingsAck(ctx, newPromise()); - setMaxConcurrentStreams(0); - connection.goAwayReceived(1, 8, EMPTY_BUFFER); - - ChannelPromise promise = newPromise(); - encoderWriteHeaders(3, promise); - assertEquals(0, encoder.numBufferedStreams()); - assertTrue(promise.isDone()); - assertFalse(promise.isSuccess()); - } - - @Test - public void receivingGoAwayFailsBufferedStreams() throws Http2Exception { - encoder.writeSettingsAck(ctx, newPromise()); - setMaxConcurrentStreams(5); - - int streamId = 3; - List futures = new ArrayList(); - for (int i = 0; i < 9; i++) { - futures.add(encoderWriteHeaders(streamId, newPromise())); - streamId += 2; - } - assertEquals(4, encoder.numBufferedStreams()); - - connection.goAwayReceived(11, 8, EMPTY_BUFFER); - - assertEquals(5, connection.numActiveStreams()); - int failCount = 0; - for (ChannelFuture f : futures) { - if (f.cause() != null) { - failCount++; - } - } - assertEquals(9, failCount); - assertEquals(0, encoder.numBufferedStreams()); - } - - @Test - public void sendingGoAwayShouldNotFailStreams() { - encoder.writeSettingsAck(ctx, newPromise()); - setMaxConcurrentStreams(1); - - when(writer.writeHeaders( - any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), anyInt(), - anyBoolean(), any(ChannelPromise.class))) - .thenAnswer(successAnswer()); - when(writer.writeHeaders( - any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), anyInt(), - anyShort(), anyBoolean(), anyInt(), anyBoolean(), any(ChannelPromise.class))) - .thenAnswer(successAnswer()); - - ChannelFuture f1 = encoderWriteHeaders(3, newPromise()); - assertEquals(0, encoder.numBufferedStreams()); - ChannelFuture f2 = encoderWriteHeaders(5, newPromise()); - assertEquals(1, encoder.numBufferedStreams()); - ChannelFuture f3 = encoderWriteHeaders(7, newPromise()); - assertEquals(2, encoder.numBufferedStreams()); - - ByteBuf empty = Unpooled.buffer(0); - encoder.writeGoAway(ctx, 3, CANCEL.code(), empty, newPromise()); - - assertEquals(1, connection.numActiveStreams()); - assertEquals(2, encoder.numBufferedStreams()); - assertFalse(f1.isDone()); - assertFalse(f2.isDone()); - assertFalse(f3.isDone()); - } - - @Test - public void endStreamDoesNotFailBufferedStream() { - encoder.writeSettingsAck(ctx, newPromise()); - setMaxConcurrentStreams(0); - - encoderWriteHeaders(3, newPromise()); - assertEquals(1, encoder.numBufferedStreams()); - - encoder.writeData(ctx, 3, EMPTY_BUFFER, 0, true, newPromise()); - - assertEquals(0, connection.numActiveStreams()); - assertEquals(1, encoder.numBufferedStreams()); - - // Simulate that we received a SETTINGS frame which - // increased MAX_CONCURRENT_STREAMS to 1. - setMaxConcurrentStreams(1); - encoder.writeSettingsAck(ctx, newPromise()); - - assertEquals(1, connection.numActiveStreams()); - assertEquals(0, encoder.numBufferedStreams()); - assertEquals(HALF_CLOSED_LOCAL, connection.stream(3).state()); - } - - @Test - public void rstStreamClosesBufferedStream() { - encoder.writeSettingsAck(ctx, newPromise()); - setMaxConcurrentStreams(0); - - encoderWriteHeaders(3, newPromise()); - assertEquals(1, encoder.numBufferedStreams()); - - ChannelPromise rstStreamPromise = newPromise(); - encoder.writeRstStream(ctx, 3, CANCEL.code(), rstStreamPromise); - assertTrue(rstStreamPromise.isSuccess()); - assertEquals(0, encoder.numBufferedStreams()); - } - - @Test - public void bufferUntilActiveStreamsAreReset() throws Exception { - encoder.writeSettingsAck(ctx, newPromise()); - setMaxConcurrentStreams(1); - - encoderWriteHeaders(3, newPromise()); - assertEquals(0, encoder.numBufferedStreams()); - encoderWriteHeaders(5, newPromise()); - assertEquals(1, encoder.numBufferedStreams()); - encoderWriteHeaders(7, newPromise()); - assertEquals(2, encoder.numBufferedStreams()); - - writeVerifyWriteHeaders(times(1), 3); - writeVerifyWriteHeaders(never(), 5); - writeVerifyWriteHeaders(never(), 7); - - encoder.writeRstStream(ctx, 3, CANCEL.code(), newPromise()); - connection.remote().flowController().writePendingBytes(); - writeVerifyWriteHeaders(times(1), 5); - writeVerifyWriteHeaders(never(), 7); - assertEquals(1, connection.numActiveStreams()); - assertEquals(1, encoder.numBufferedStreams()); - - encoder.writeRstStream(ctx, 5, CANCEL.code(), newPromise()); - connection.remote().flowController().writePendingBytes(); - writeVerifyWriteHeaders(times(1), 7); - assertEquals(1, connection.numActiveStreams()); - assertEquals(0, encoder.numBufferedStreams()); - - encoder.writeRstStream(ctx, 7, CANCEL.code(), newPromise()); - assertEquals(0, connection.numActiveStreams()); - assertEquals(0, encoder.numBufferedStreams()); - } - - @Test - public void bufferUntilMaxStreamsIncreased() { - encoder.writeSettingsAck(ctx, newPromise()); - setMaxConcurrentStreams(2); - - encoderWriteHeaders(3, newPromise()); - encoderWriteHeaders(5, newPromise()); - encoderWriteHeaders(7, newPromise()); - encoderWriteHeaders(9, newPromise()); - assertEquals(2, encoder.numBufferedStreams()); - - writeVerifyWriteHeaders(times(1), 3); - writeVerifyWriteHeaders(times(1), 5); - writeVerifyWriteHeaders(never(), 7); - writeVerifyWriteHeaders(never(), 9); - - // Simulate that we received a SETTINGS frame which - // increased MAX_CONCURRENT_STREAMS to 5. - setMaxConcurrentStreams(5); - encoder.writeSettingsAck(ctx, newPromise()); - - assertEquals(0, encoder.numBufferedStreams()); - writeVerifyWriteHeaders(times(1), 7); - writeVerifyWriteHeaders(times(1), 9); - - encoderWriteHeaders(11, newPromise()); - - writeVerifyWriteHeaders(times(1), 11); - - assertEquals(5, connection.local().numActiveStreams()); - } - - @Test - public void bufferUntilSettingsReceived() throws Http2Exception { - int initialLimit = SMALLEST_MAX_CONCURRENT_STREAMS; - int numStreams = initialLimit * 2; - for (int ix = 0, nextStreamId = 3; ix < numStreams; ++ix, nextStreamId += 2) { - encoderWriteHeaders(nextStreamId, newPromise()); - if (ix < initialLimit) { - writeVerifyWriteHeaders(times(1), nextStreamId); - } else { - writeVerifyWriteHeaders(never(), nextStreamId); - } - } - assertEquals(numStreams / 2, encoder.numBufferedStreams()); - - // Simulate that we received a SETTINGS frame. - setMaxConcurrentStreams(initialLimit * 2); - - assertEquals(0, encoder.numBufferedStreams()); - assertEquals(numStreams, connection.local().numActiveStreams()); - } - - @Test - public void bufferUntilSettingsReceivedWithNoMaxConcurrentStreamValue() throws Http2Exception { - int initialLimit = SMALLEST_MAX_CONCURRENT_STREAMS; - int numStreams = initialLimit * 2; - for (int ix = 0, nextStreamId = 3; ix < numStreams; ++ix, nextStreamId += 2) { - encoderWriteHeaders(nextStreamId, newPromise()); - if (ix < initialLimit) { - writeVerifyWriteHeaders(times(1), nextStreamId); - } else { - writeVerifyWriteHeaders(never(), nextStreamId); - } - } - assertEquals(numStreams / 2, encoder.numBufferedStreams()); - - // Simulate that we received an empty SETTINGS frame. - encoder.remoteSettings(new Http2Settings()); - - assertEquals(0, encoder.numBufferedStreams()); - assertEquals(numStreams, connection.local().numActiveStreams()); - } - - @Test - public void exhaustedStreamsDoNotBuffer() throws Http2Exception { - // Write the highest possible stream ID for the client. - // This will cause the next stream ID to be negative. - encoderWriteHeaders(Integer.MAX_VALUE, newPromise()); - - // Disallow any further streams. - setMaxConcurrentStreams(0); - - // Simulate numeric overflow for the next stream ID. - ChannelFuture f = encoderWriteHeaders(-1, newPromise()); - - // Verify that the write fails. - assertNotNull(f.cause()); - } - - @Test - public void closedBufferedStreamReleasesByteBuf() { - encoder.writeSettingsAck(ctx, newPromise()); - setMaxConcurrentStreams(0); - ByteBuf data = mock(ByteBuf.class); - ChannelFuture f1 = encoderWriteHeaders(3, newPromise()); - assertEquals(1, encoder.numBufferedStreams()); - ChannelFuture f2 = encoder.writeData(ctx, 3, data, 0, false, newPromise()); - - ChannelPromise rstPromise = mock(ChannelPromise.class); - encoder.writeRstStream(ctx, 3, CANCEL.code(), rstPromise); - - assertEquals(0, encoder.numBufferedStreams()); - verify(rstPromise).setSuccess(); - assertTrue(f1.isSuccess()); - assertTrue(f2.isSuccess()); - verify(data).release(); - } - - @Test - public void closeShouldCancelAllBufferedStreams() throws Http2Exception { - encoder.writeSettingsAck(ctx, newPromise()); - connection.local().maxActiveStreams(0); - - ChannelFuture f1 = encoderWriteHeaders(3, newPromise()); - ChannelFuture f2 = encoderWriteHeaders(5, newPromise()); - ChannelFuture f3 = encoderWriteHeaders(7, newPromise()); - - encoder.close(); - assertNotNull(f1.cause()); - assertNotNull(f2.cause()); - assertNotNull(f3.cause()); - } - - @Test - public void headersAfterCloseShouldImmediatelyFail() { - encoder.writeSettingsAck(ctx, newPromise()); - encoder.close(); - - ChannelFuture f = encoderWriteHeaders(3, newPromise()); - assertNotNull(f.cause()); - } - - private void setMaxConcurrentStreams(int newValue) { - try { - encoder.remoteSettings(new Http2Settings().maxConcurrentStreams(newValue)); - // Flush the remote flow controller to write data - encoder.flowController().writePendingBytes(); - } catch (Http2Exception e) { - throw new RuntimeException(e); - } - } - - private ChannelFuture encoderWriteHeaders(int streamId, ChannelPromise promise) { - encoder.writeHeaders(ctx, streamId, new DefaultHttp2Headers(), 0, DEFAULT_PRIORITY_WEIGHT, - false, 0, false, promise); - try { - encoder.flowController().writePendingBytes(); - return promise; - } catch (Http2Exception e) { - throw new RuntimeException(e); - } - } - - private void writeVerifyWriteHeaders(VerificationMode mode, int streamId) { - verify(writer, mode).writeHeaders(eq(ctx), eq(streamId), any(Http2Headers.class), eq(0), - eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), - eq(false), any(ChannelPromise.class)); - } - - private Answer successAnswer() { - return new Answer() { - @Override - public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { - for (Object a : invocation.getArguments()) { - ReferenceCountUtil.safeRelease(a); - } - - ChannelPromise future = newPromise(); - future.setSuccess(); - return future; - } - }; - } - - private ChannelPromise newPromise() { - return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); - } - - private static ByteBuf data() { - ByteBuf buf = Unpooled.buffer(10); - for (int i = 0; i < buf.writableBytes(); i++) { - buf.writeByte(i); - } - return buf; - } -} From d9571dc3c8502f23d9cc008c609e16068a92c1bc Mon Sep 17 00:00:00 2001 From: "Penn (Dapeng) Zhang" Date: Fri, 16 Apr 2021 11:03:33 -0700 Subject: [PATCH 7/8] remove duplicate code --- netty/src/main/java/io/grpc/netty/NettyClientHandler.java | 8 -------- 1 file changed, 8 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index bdcac924252..393b3644961 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -611,14 +611,6 @@ private void createStreamTraced( // Create an intermediate promise so that we can intercept the failure reported back to the // application. ChannelPromise tempPromise = ctx().newPromise(); - if (connection().goAwayReceived() - && connection().local().numActiveStreams() == connection().local().maxActiveStreams()) { - Status status = Status.UNAVAILABLE.withCause( - new Http2Exception(Http2Error.REFUSED_STREAM, "GOAWAY received")); - stream.transportReportStatus(status, RpcProgress.REFUSED, true, new Metadata()); - promise.setFailure(status.asRuntimeException()); - return; - } encoder().writeHeaders(ctx(), streamId, headers, 0, isGet, tempPromise) .addListener(new ChannelFutureListener() { @Override From 2b2be63e653dd68be6b2a7071e6b985763808246 Mon Sep 17 00:00:00 2001 From: "Penn (Dapeng) Zhang" Date: Fri, 16 Apr 2021 11:05:48 -0700 Subject: [PATCH 8/8] remove unused code --- netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java | 6 ------ 1 file changed, 6 deletions(-) diff --git a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java index 684f2050ac2..04f65eed145 100644 --- a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java @@ -292,12 +292,6 @@ protected final ByteBuf headersFrame(int streamId, Http2Headers headers) { return captureWrite(ctx); } - protected final ByteBuf trailersFrame(int streamId, Http2Headers headers) { - ChannelHandlerContext ctx = newMockContext(); - new DefaultHttp2FrameWriter().writeHeaders(ctx, streamId, headers, 0, true, newPromise()); - return captureWrite(ctx); - } - protected final ByteBuf goAwayFrame(int lastStreamId) { return goAwayFrame(lastStreamId, 0, Unpooled.EMPTY_BUFFER); }