Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix StreamBufferingEncoder GOAWAY bug #11144

Merged
merged 13 commits into from Apr 19, 2021
Merged
Expand Up @@ -69,33 +69,45 @@ public Http2ChannelClosedException() {
}
}

private static final class GoAwayDetail {
private final int lastStreamId;
private final long errorCode;
private final byte[] debugData;

GoAwayDetail(int lastStreamId, long errorCode, byte[] debugData) {
this.lastStreamId = lastStreamId;
this.errorCode = errorCode;
this.debugData = debugData.clone();
}
}

/**
* Thrown by {@link StreamBufferingEncoder} if buffered streams are terminated due to
* receipt of a {@code GOAWAY}.
*/
public static final class Http2GoAwayException extends Http2Exception {
private static final long serialVersionUID = 1326785622777291198L;
private final int lastStreamId;
private final long errorCode;
private final byte[] debugData;
private GoAwayDetail goAwayDetail;
normanmaurer marked this conversation as resolved.
Show resolved Hide resolved

public Http2GoAwayException(int lastStreamId, long errorCode, byte[] debugData) {
this(new GoAwayDetail(lastStreamId, errorCode, debugData));
}

private Http2GoAwayException(GoAwayDetail goAwayDetail) {
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
super(Http2Error.STREAM_CLOSED);
this.lastStreamId = lastStreamId;
this.errorCode = errorCode;
this.debugData = debugData;
this.goAwayDetail = goAwayDetail;
}

public int lastStreamId() {
return lastStreamId;
return goAwayDetail.lastStreamId;
}

public long errorCode() {
return errorCode;
return goAwayDetail.errorCode;
}

public byte[] debugData() {
return debugData;
return goAwayDetail.debugData;
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand All @@ -106,6 +118,7 @@ public byte[] debugData() {
private final TreeMap<Integer, PendingStream> pendingStreams = new TreeMap<Integer, PendingStream>();
private int maxConcurrentStreams;
private boolean closed;
private GoAwayDetail goAwayDetail;

public StreamBufferingEncoder(Http2ConnectionEncoder delegate) {
this(delegate, SMALLEST_MAX_CONCURRENT_STREAMS);
Expand All @@ -118,7 +131,9 @@ public StreamBufferingEncoder(Http2ConnectionEncoder delegate, int initialMaxCon

@Override
public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) {
cancelGoAwayStreams(lastStreamId, errorCode, debugData);
goAwayDetail = new GoAwayDetail(
lastStreamId, errorCode, ByteBufUtil.getBytes(debugData));
normanmaurer marked this conversation as resolved.
Show resolved Hide resolved
cancelGoAwayStreams(goAwayDetail);
}

@Override
Expand Down Expand Up @@ -149,13 +164,14 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2
if (closed) {
return promise.setFailure(new Http2ChannelClosedException());
}
if (isExistingStream(streamId) || connection().goAwayReceived()) {
if (isExistingStream(streamId) || canCreateStream()) {
return super.writeHeaders(ctx, streamId, headers, streamDependency, weight,
exclusive, padding, endOfStream, promise);
}
if (canCreateStream()) {
return super.writeHeaders(ctx, streamId, headers, streamDependency, weight,
exclusive, padding, endOfStream, promise);
if (goAwayDetail != null) {
promise.setFailure(new Http2GoAwayException(
goAwayDetail.lastStreamId, goAwayDetail.lastStreamId, goAwayDetail.debugData));
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
return promise;
normanmaurer marked this conversation as resolved.
Show resolved Hide resolved
}
PendingStream pendingStream = pendingStreams.get(streamId);
if (pendingStream == null) {
Expand Down Expand Up @@ -248,12 +264,12 @@ private void tryCreatePendingStreams() {
}
}

private void cancelGoAwayStreams(int lastStreamId, long errorCode, ByteBuf debugData) {
private void cancelGoAwayStreams(GoAwayDetail goAwayDetail) {
Iterator<PendingStream> iter = pendingStreams.values().iterator();
Exception e = new Http2GoAwayException(lastStreamId, errorCode, ByteBufUtil.getBytes(debugData));
Exception e = new Http2GoAwayException(goAwayDetail);
while (iter.hasNext()) {
PendingStream stream = iter.next();
if (stream.streamId > lastStreamId) {
if (stream.streamId > goAwayDetail.lastStreamId) {
iter.remove();
stream.close(e);
}
Expand Down
Expand Up @@ -49,6 +49,7 @@
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.channel.DefaultMessageSizeEstimator;
import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2GoAwayException;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.ImmediateEventExecutor;
Expand Down Expand Up @@ -111,6 +112,11 @@ public void setup() throws Exception {
when(writer.writeGoAway(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class),
any(ChannelPromise.class)))
.thenAnswer(successAnswer());
when(writer.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class),
normanmaurer marked this conversation as resolved.
Show resolved Hide resolved
anyInt(), anyBoolean(), any(ChannelPromise.class))).thenAnswer(noopAnswer());
when(writer.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class),
anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean(), any(ChannelPromise.class)))
.thenAnswer(noopAnswer());

connection = new DefaultHttp2Connection(false);
connection.remote().flowController(new DefaultHttp2RemoteFlowController(connection));
Expand Down Expand Up @@ -167,7 +173,7 @@ public void multipleWritesToActiveStream() {
encoder.writeData(ctx, 3, data(), 0, false, newPromise());
encoderWriteHeaders(3, newPromise());

writeVerifyWriteHeaders(times(2), 3);
writeVerifyWriteHeaders(times(1), 3);
// Contiguous data writes are coalesced
ArgumentCaptor<ByteBuf> bufCaptor = ArgumentCaptor.forClass(ByteBuf.class);
verify(writer, times(1))
Expand Down Expand Up @@ -245,18 +251,32 @@ public void receivingGoAwayFailsBufferedStreams() throws Http2Exception {
futures.add(encoderWriteHeaders(streamId, newPromise()));
streamId += 2;
}
assertEquals(5, connection.numActiveStreams());
assertEquals(4, encoder.numBufferedStreams());

connection.goAwayReceived(11, 8, EMPTY_BUFFER);

assertEquals(5, connection.numActiveStreams());
assertEquals(0, encoder.numBufferedStreams());
int failCount = 0;
for (ChannelFuture f : futures) {
if (f.cause() != null) {
assertTrue(f.cause() instanceof Http2GoAwayException);
failCount++;
}
}
assertEquals(9, failCount);
assertEquals(4, failCount);
}

@Test
public void receivingGoAwayFailsNewStreamIfMaxConcurrentStreamsReached() throws Http2Exception {
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(1);
encoderWriteHeaders(3, newPromise());
connection.goAwayReceived(11, 8, EMPTY_BUFFER);
ChannelFuture f = encoderWriteHeaders(5, newPromise());

assertTrue(f.cause() instanceof Http2GoAwayException);
assertEquals(0, encoder.numBufferedStreams());
}

Expand Down Expand Up @@ -533,6 +553,20 @@ public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
};
}

private Answer<ChannelFuture> noopAnswer() {
return new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
for (Object a : invocation.getArguments()) {
if (a instanceof ChannelPromise) {
return (ChannelFuture) a;
}
}
return newPromise();
}
};
}

private ChannelPromise newPromise() {
return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
}
Expand Down