diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index a889a040da6..e71a69559bc 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -19,7 +19,6 @@ import io.grpc.ChannelLogger; import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler; import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler; -import io.grpc.netty.ProtocolNegotiators.ProtocolNegotiationHandler; import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; @@ -40,20 +39,6 @@ public static ChannelLogger negotiationLogger(ChannelHandlerContext ctx) { return ProtocolNegotiators.negotiationLogger(ctx); } - /** - * Buffers all writes until either {@link #writeBufferedAndRemove(ChannelHandlerContext)} or - * {@link #fail(ChannelHandlerContext, Throwable)} is called. This handler allows us to - * write to a {@link io.netty.channel.Channel} before we are allowed to write to it officially - * i.e. before it's active or the TLS Handshake is complete. - */ - public abstract static class AbstractBufferingHandler - extends ProtocolNegotiators.AbstractBufferingHandler { - - protected AbstractBufferingHandler(ChannelHandler... handlers) { - super(handlers); - } - } - /** * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 51c30fd14d8..feeaf9cf358 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -34,13 +34,10 @@ import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpClientUpgradeHandler; @@ -50,7 +47,6 @@ import io.netty.handler.codec.http2.Http2ClientUpgradeCodec; import io.netty.handler.proxy.HttpProxyHandler; import io.netty.handler.proxy.ProxyConnectionEvent; -import io.netty.handler.proxy.ProxyHandler; import io.netty.handler.ssl.OpenSsl; import io.netty.handler.ssl.OpenSslEngine; import io.netty.handler.ssl.SslContext; @@ -59,12 +55,9 @@ import io.netty.util.AsciiString; import io.netty.util.Attribute; import io.netty.util.AttributeMap; -import io.netty.util.ReferenceCountUtil; import java.net.SocketAddress; import java.net.URI; -import java.util.ArrayDeque; import java.util.Arrays; -import java.util.Queue; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -195,20 +188,16 @@ private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress, final @Nullable String proxyUsername, final @Nullable String proxyPassword, final ProtocolNegotiator negotiator) { + checkNotNull(negotiator, "negotiator"); + checkNotNull(proxyAddress, "proxyAddress"); final AsciiString scheme = negotiator.scheme(); - Preconditions.checkNotNull(proxyAddress, "proxyAddress"); - Preconditions.checkNotNull(negotiator, "negotiator"); class ProxyNegotiator implements ProtocolNegotiator { @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler http2Handler) { - HttpProxyHandler proxyHandler; - if (proxyUsername == null || proxyPassword == null) { - proxyHandler = new HttpProxyHandler(proxyAddress); - } else { - proxyHandler = new HttpProxyHandler(proxyAddress, proxyUsername, proxyPassword); - } - return new BufferUntilProxyTunnelledHandler( - proxyHandler, negotiator.newHandler(http2Handler)); + ChannelHandler protocolNegotiationHandler = negotiator.newHandler(http2Handler); + ChannelHandler ppnh = new ProxyProtocolNegotiationHandler( + proxyAddress, proxyUsername, proxyPassword, protocolNegotiationHandler); + return ppnh; } @Override @@ -228,34 +217,45 @@ public void close() { } /** - * Buffers all writes until the HTTP CONNECT tunnel is established. + * A Proxy handler follows {@link ProtocolNegotiationHandler} pattern. Upon successful proxy + * connection, this handler will install {@code next} handler which should be a handler from + * other type of {@link ProtocolNegotiator} to continue negotiating protocol using proxy. */ - static final class BufferUntilProxyTunnelledHandler extends AbstractBufferingHandler { + static final class ProxyProtocolNegotiationHandler extends ProtocolNegotiationHandler { - public BufferUntilProxyTunnelledHandler(ProxyHandler proxyHandler, ChannelHandler handler) { - super(proxyHandler, handler); - } + private final SocketAddress address; + @Nullable private final String userName; + @Nullable private final String password; - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - if (evt instanceof ProxyConnectionEvent) { - writeBufferedAndRemove(ctx); - } - super.userEventTriggered(ctx, evt); + public ProxyProtocolNegotiationHandler( + SocketAddress address, + @Nullable String userName, + @Nullable String password, + ChannelHandler next) { + super(next); + this.address = checkNotNull(address, "address"); + this.userName = userName; + this.password = password; } @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - fail(ctx, unavailableException("Connection broken while trying to CONNECT through proxy")); - super.channelInactive(ctx); + protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { + HttpProxyHandler nettyProxyHandler; + if (userName == null || password == null) { + nettyProxyHandler = new HttpProxyHandler(address); + } else { + nettyProxyHandler = new HttpProxyHandler(address, userName, password); + } + ctx.pipeline().addBefore(ctx.name(), /* newName= */ null, nettyProxyHandler); } @Override - public void close(ChannelHandlerContext ctx, ChannelPromise future) throws Exception { - if (ctx.channel().isActive()) { // This may be a notification that the socket was closed - fail(ctx, unavailableException("Channel closed while trying to CONNECT through proxy")); + protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof ProxyConnectionEvent) { + fireProtocolNegotiationEvent(ctx); + } else { + super.userEventTriggered(ctx, evt); } - super.close(ctx, future); } } @@ -527,208 +527,6 @@ static void logSslEngineDetails(Level level, ChannelHandlerContext ctx, String m log.log(level, builder.toString(), t); } - /** - * Buffers all writes until either {@link #writeBufferedAndRemove(ChannelHandlerContext)} or - * {@link #fail(ChannelHandlerContext, Throwable)} is called. This handler allows us to - * write to a {@link io.netty.channel.Channel} before we are allowed to write to it officially - * i.e. before it's active or the TLS Handshake is complete. - */ - public abstract static class AbstractBufferingHandler extends ChannelDuplexHandler { - - private ChannelHandler[] handlers; - private Queue bufferedWrites = new ArrayDeque<>(); - private boolean writing; - private boolean flushRequested; - private Throwable failCause; - - /** - * @param handlers the ChannelHandlers are added to the pipeline on channelRegistered and - * before this handler. - */ - protected AbstractBufferingHandler(ChannelHandler... handlers) { - this.handlers = handlers; - } - - /** - * When this channel is registered, we will add all the ChannelHandlers passed into our - * constructor to the pipeline. - */ - @Override - public void channelRegistered(ChannelHandlerContext ctx) throws Exception { - /** - * This check is necessary as a channel may be registered with different event loops during it - * lifetime and we only want to configure it once. - */ - if (handlers != null && handlers.length > 0) { - for (ChannelHandler handler : handlers) { - ctx.pipeline().addBefore(ctx.name(), null, handler); - } - ChannelHandler handler0 = handlers[0]; - ChannelHandlerContext handler0Ctx = ctx.pipeline().context(handlers[0]); - handlers = null; - if (handler0Ctx != null) { // The handler may have removed itself immediately - if (handler0 instanceof ChannelInboundHandler) { - ((ChannelInboundHandler) handler0).channelRegistered(handler0Ctx); - } else { - handler0Ctx.fireChannelRegistered(); - } - } - } else { - super.channelRegistered(ctx); - } - } - - /** - * Do not rely on channel handlers to propagate exceptions to us. - * {@link NettyClientHandler} is an example of a class that does not propagate exceptions. - * Add a listener to the connect future directly and do appropriate error handling. - */ - @Override - public void connect(final ChannelHandlerContext ctx, SocketAddress remoteAddress, - SocketAddress localAddress, ChannelPromise promise) throws Exception { - super.connect(ctx, remoteAddress, localAddress, promise); - promise.addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (!future.isSuccess()) { - fail(ctx, future.cause()); - } - } - }); - } - - /** - * If we encounter an exception, then notify all buffered writes that we failed. - */ - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - fail(ctx, cause); - } - - /** - * If this channel becomes inactive, then notify all buffered writes that we failed. - */ - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - fail(ctx, unavailableException("Connection broken while performing protocol negotiation")); - super.channelInactive(ctx); - } - - /** - * Buffers the write until either {@link #writeBufferedAndRemove(ChannelHandlerContext)} is - * called, or we have somehow failed. If we have already failed in the past, then the write - * will fail immediately. - */ - @Override - @SuppressWarnings("FutureReturnValueIgnored") - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) - throws Exception { - /** - * This check handles a race condition between Channel.write (in the calling thread) and the - * removal of this handler (in the event loop thread). - * The problem occurs in e.g. this sequence: - * 1) [caller thread] The write method identifies the context for this handler - * 2) [event loop] This handler removes itself from the pipeline - * 3) [caller thread] The write method delegates to the invoker to call the write method in - * the event loop thread. When this happens, we identify that this handler has been - * removed with "bufferedWrites == null". - */ - if (failCause != null) { - promise.setFailure(failCause); - ReferenceCountUtil.release(msg); - } else if (bufferedWrites == null) { - super.write(ctx, msg, promise); - } else { - bufferedWrites.add(new ChannelWrite(msg, promise)); - } - } - - /** - * Calls to this method will not trigger an immediate flush. The flush will be deferred until - * {@link #writeBufferedAndRemove(ChannelHandlerContext)}. - */ - @Override - public void flush(ChannelHandlerContext ctx) { - /** - * Swallowing any flushes is not only an optimization but also required - * for the SslHandler to work correctly. If the SslHandler receives multiple - * flushes while the handshake is still ongoing, then the handshake "randomly" - * times out. Not sure at this point why this is happening. Doing a single flush - * seems to work but multiple flushes don't ... - */ - if (bufferedWrites == null) { - ctx.flush(); - } else { - flushRequested = true; - } - } - - /** - * If we are still performing protocol negotiation, then this will propagate failures to all - * buffered writes. - */ - @Override - public void close(ChannelHandlerContext ctx, ChannelPromise future) throws Exception { - if (ctx.channel().isActive()) { // This may be a notification that the socket was closed - fail(ctx, unavailableException("Channel closed while performing protocol negotiation")); - } - super.close(ctx, future); - } - - /** - * Propagate failures to all buffered writes. - */ - @SuppressWarnings("FutureReturnValueIgnored") - protected final void fail(ChannelHandlerContext ctx, Throwable cause) { - if (failCause == null) { - failCause = cause; - } - if (bufferedWrites != null) { - while (!bufferedWrites.isEmpty()) { - ChannelWrite write = bufferedWrites.poll(); - write.promise.setFailure(cause); - ReferenceCountUtil.release(write.msg); - } - bufferedWrites = null; - } - - ctx.fireExceptionCaught(cause); - } - - @SuppressWarnings("FutureReturnValueIgnored") - protected final void writeBufferedAndRemove(ChannelHandlerContext ctx) { - if (!ctx.channel().isActive() || writing) { - return; - } - // Make sure that method can't be reentered, so that the ordering - // in the queue can't be messed up. - writing = true; - while (!bufferedWrites.isEmpty()) { - ChannelWrite write = bufferedWrites.poll(); - ctx.write(write.msg, write.promise); - } - assert bufferedWrites.isEmpty(); - bufferedWrites = null; - if (flushRequested) { - ctx.flush(); - } - // Removal has to happen last as the above writes will likely trigger - // new writes that have to be added to the end of queue in order to not - // mess up the ordering. - ctx.pipeline().remove(this); - } - - private static class ChannelWrite { - Object msg; - ChannelPromise promise; - - ChannelWrite(Object msg, ChannelPromise promise) { - this.msg = msg; - this.promise = promise; - } - } - } - /** * Adapts a {@link ProtocolNegotiationEvent} to the {@link GrpcHttp2ConnectionHandler}. */ diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 75e40bb2449..5897d3f76c2 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -67,8 +67,10 @@ import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; import io.netty.channel.Channel; import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFactory; import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.ReflectiveChannelFactory; @@ -915,9 +917,21 @@ public String parse(InputStream stream) { } } - private static class NoopHandler extends ProtocolNegotiators.AbstractBufferingHandler { + private static class NoopHandler extends ChannelDuplexHandler { + + private final GrpcHttp2ConnectionHandler grpcHandler; + public NoopHandler(GrpcHttp2ConnectionHandler grpcHandler) { - super(grpcHandler); + this.grpcHandler = grpcHandler; + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + ctx.pipeline().addBefore(ctx.name(), null, grpcHandler); + } + + public void fail(ChannelHandlerContext ctx, Throwable cause) { + ctx.fireExceptionCaught(cause); } } diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 6c9f3526376..8fde613fff9 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -35,7 +35,6 @@ import io.grpc.SecurityLevel; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.testing.TestUtils; -import io.grpc.netty.ProtocolNegotiators.AbstractBufferingHandler; import io.grpc.netty.ProtocolNegotiators.ClientTlsProtocolNegotiator; import io.grpc.netty.ProtocolNegotiators.HostPort; import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler; @@ -45,6 +44,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; @@ -465,7 +465,10 @@ public void httpProxy_completes() throws Exception { ProtocolNegotiator nego = ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext()); - ChannelHandler handler = nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler()); + // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, + // mocking the behavior using KickStartHandler. + ChannelHandler handler = + new KickStartHandler(nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler())); Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler) .register().sync().channel(); pipeline = channel.pipeline(); @@ -525,7 +528,10 @@ public void httpProxy_500() throws Exception { ProtocolNegotiator nego = ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext()); - ChannelHandler handler = nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler()); + // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, + // mocking the behavior using KickStartHandler. + ChannelHandler handler = + new KickStartHandler(nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler())); Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler) .register().sync().channel(); pipeline = channel.pipeline(); @@ -604,24 +610,6 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { elg.shutdownGracefully(); } - @Test(expected = Test.None.class /* no exception expected */) - @SuppressWarnings("TestExceptionChecker") - public void bufferingHandler_shouldNotThrowForEmptyHandler() throws Exception { - LocalAddress addr = new LocalAddress("local"); - ChannelFuture unused = new Bootstrap() - .channel(LocalChannel.class) - .handler(new BufferingHandlerWithoutHandlers()) - .group(group) - .register().sync(); - ChannelFuture sf = new ServerBootstrap() - .channel(LocalServerChannel.class) - .childHandler(new ChannelHandlerAdapter() {}) - .group(group) - .bind(addr); - // sync will trigger client's NoHandlerBufferingHandler which should not throw - sf.sync(); - } - @Test public void clientTlsHandler_firesNegotiation() throws Exception { SelfSignedCertificate cert = new SelfSignedCertificate("authority"); @@ -815,10 +803,18 @@ private static ByteBuf bb(String s, Channel c) { return ByteBufUtil.writeUtf8(c.alloc(), s); } - private static class BufferingHandlerWithoutHandlers extends AbstractBufferingHandler { + private static final class KickStartHandler extends ChannelDuplexHandler { + + private final ChannelHandler next; - public BufferingHandlerWithoutHandlers(ChannelHandler... handlers) { - super(handlers); + public KickStartHandler(ChannelHandler next) { + this.next = checkNotNull(next, "next"); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + ctx.pipeline().replace(ctx.name(), null, next); + ctx.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); } } }