diff --git a/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java b/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java index a37221b12841..d340a402f8ff 100644 --- a/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java +++ b/netty/src/main/java/io/grpc/netty/StreamBufferingEncoder.java @@ -153,10 +153,14 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2 if (closed) { return promise.setFailure(new Http2ChannelClosedException()); } - if (isExistingStream(streamId) || connection().goAwayReceived()) { + if (isExistingStream(streamId)) { return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, promise); } + if (connection().goAwayReceived()) { + promise.setFailure(new Http2Exception(Http2Error.NO_ERROR, "GOAWAY received")); + return promise; + } if (canCreateStream()) { return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, promise); diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index 223f4781538e..99dfe905b5a1 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -384,23 +384,20 @@ public void receivedAbruptGoAwayShouldFailRacingQueuedStreamid() throws Exceptio } @Test - public void receivedGoAway_shouldNotFailBufferedStreamWithStreamIdLessThanLastId() + public void receivedGoAway_shouldFailBufferedStreams() throws Exception { - ClientStreamListener streamListener1 = mock(ClientStreamListener.class); NettyClientStream.TransportState streamTransportState1 = new TransportStateImpl( handler(), channel().eventLoop(), DEFAULT_MAX_MESSAGE_SIZE, transportTracer); - streamTransportState1.setListener(streamListener1); - ClientStreamListener streamListener2 = mock(ClientStreamListener.class); + streamTransportState1.setListener(mock(ClientStreamListener.class)); NettyClientStream.TransportState streamTransportState2 = new TransportStateImpl( handler(), channel().eventLoop(), DEFAULT_MAX_MESSAGE_SIZE, transportTracer); - streamTransportState2.setListener(streamListener2); - // MAX_CONCURRENT_STREAMS=1 + streamTransportState2.setListener(mock(ClientStreamListener.class)); receiveMaxConcurrentStreams(1); ChannelFuture future1 = writeQueue().enqueue( newCreateStreamCommand(grpcHeaders, streamTransportState1), true); @@ -409,67 +406,10 @@ public void receivedGoAway_shouldNotFailBufferedStreamWithStreamIdLessThanLastId // GOAWAY channelRead(goAwayFrame(Integer.MAX_VALUE)); - assertTrue(future1.isSuccess()); - verify(streamListener1).onReady(); - channel().runPendingTasks(); - if (future2.cause() != null ) { - throw new AssertionError(future2.cause()); - } - assertFalse(future2.isDone()); - verify(streamListener2, never()).onReady(); - - // Let the first stream complete, then the pending stream will be activated. - Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK) - .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC); - channelRead(headersFrame(streamId, headers)); - channelRead(grpcDataFrame(streamId, false, contentAsArray())); - channelRead(trailersFrame( - streamId, - new DefaultHttp2Headers().set(AsciiString.of("grpc-status"), AsciiString.of("0")))); - streamTransportState1.requestMessagesFromDeframerForTesting(1); - verify(streamListener1).closed(eq(Status.OK), any(RpcProgress.class), - any(Metadata.class)); - channel().runPendingTasks(); - if (future2.cause() != null) { - throw new AssertionError(future2.cause()); - } - assertTrue(future2.isSuccess()); - verify(streamListener2).onReady(); - } - - @Test - public void receivedGoAway_shouldRefuseBufferedStreamWithStreamIdGreaterThanLastId() - throws Exception { - ClientStreamListener streamListener1 = mock(ClientStreamListener.class); - NettyClientStream.TransportState streamTransportState1 = new TransportStateImpl( - handler(), - channel().eventLoop(), - DEFAULT_MAX_MESSAGE_SIZE, - transportTracer); - streamTransportState1.setListener(streamListener1); - ClientStreamListener streamListener2 = mock(ClientStreamListener.class); - NettyClientStream.TransportState streamTransportState2 = new TransportStateImpl( - handler(), - channel().eventLoop(), - DEFAULT_MAX_MESSAGE_SIZE, - transportTracer); - streamTransportState2.setListener(streamListener2); - // MAX_CONCURRENT_STREAMS=1 - receiveMaxConcurrentStreams(1); - ChannelFuture future1 = writeQueue().enqueue( - newCreateStreamCommand(grpcHeaders, streamTransportState1), true); - ChannelFuture future2 = writeQueue().enqueue( - newCreateStreamCommand(grpcHeaders, streamTransportState2), true); - - // GOAWAY - channelRead(goAwayFrame(streamId)); - assertTrue(future1.isSuccess()); - verify(streamListener1).onReady(); - assertThat(future2.cause()).isNotNull(); - assertThat(Status.fromThrowable(future2.cause()).getCode()).isEqualTo(Status.Code.UNAVAILABLE); - assertThat(Status.fromThrowable(future2.cause()).getDescription()).isEqualTo( - "Abrupt GOAWAY closed unsent stream. HTTP/2 error code: NO_ERROR"); - verify(streamListener2).closed(any(Status.class), eq(RpcProgress.REFUSED), any(Metadata.class)); + assertTrue(future1.isDone()); + assertThat(future1.cause().getMessage()).contains("GOAWAY received"); + assertTrue(future2.isDone()); + assertThat(future2.cause().getMessage()).contains("GOAWAY received"); } @Test