Skip to content

Commit

Permalink
netty: fix a race for channelz at server transport creation
Browse files Browse the repository at this point in the history
A race condition was reported by a user in grpc#6601:

`ServerImpl.start()` calls `NettyServer.start()` while holding `ServerImpl.lock`. `NettyServer.start()` then blocks until a runnable it submitted to the event loop has run. However, this pending runnable may never be executed, because the event loop might be busy executing some other task, such as `ServerListenerImpl.transportCreated()`, that is itself trying to acquire `ServerImpl.lock`, resulting in a deadlock.

This PR resolves the particular issue reported in grpc#6601 for a server with a single port, but `NettyServer` (https://github.com/grpc/grpc-java/blob/v1.26.0/netty/src/main/java/io/grpc/netty/NettyServer.java#L251) and `ServerImpl` (https://github.com/grpc/grpc-java/blob/v1.26.0/core/src/main/java/io/grpc/internal/ServerImpl.java#L184) in general still carry the same potential risk of deadlock, which needs a further fix.
  • Loading branch information
dapengzhang0 committed Jan 22, 2020
1 parent 5acb70e commit a3af83b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 29 deletions.
23 changes: 8 additions & 15 deletions netty/src/main/java/io/grpc/netty/NettyServer.java
Expand Up @@ -56,7 +56,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;

Expand Down Expand Up @@ -93,9 +92,8 @@ class NettyServer implements InternalServer, InternalWithLogId {
private final List<? extends ServerStreamTracer.Factory> streamTracerFactories;
private final TransportTracer.Factory transportTracerFactory;
private final InternalChannelz channelz;
// Only modified in event loop but safe to read any time. Set at startup and unset at shutdown.
private final AtomicReference<InternalInstrumented<SocketStats>> listenSocketStats =
new AtomicReference<>();
// Only modified in event loop but safe to read any time.
private volatile InternalInstrumented<SocketStats> listenSocketStats;

NettyServer(
SocketAddress address, ChannelFactory<? extends ServerChannel> channelFactory,
Expand Down Expand Up @@ -149,7 +147,7 @@ public SocketAddress getListenSocketAddress() {

@Override
public InternalInstrumented<SocketStats> getListenSocketStats() {
return listenSocketStats.get();
return listenSocketStats;
}

@Override
Expand Down Expand Up @@ -251,19 +249,13 @@ public void operationComplete(ChannelFuture future) throws Exception {
throw new IOException("Failed to bind", future.cause());
}
channel = future.channel();
Future<?> channelzFuture = channel.eventLoop().submit(new Runnable() {
channel.eventLoop().execute(new Runnable() {
@Override
public void run() {
InternalInstrumented<SocketStats> listenSocket = new ListenSocket(channel);
listenSocketStats.set(listenSocket);
channelz.addListenSocket(listenSocket);
listenSocketStats = new ListenSocket(channel);
channelz.addListenSocket(listenSocketStats);
}
});
try {
channelzFuture.await();
} catch (InterruptedException ex) {
throw new RuntimeException("Interrupted while registering listen socket to channelz", ex);
}
}

@Override
Expand All @@ -278,7 +270,8 @@ public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
log.log(Level.WARNING, "Error shutting down server", future.cause());
}
InternalInstrumented<SocketStats> stats = listenSocketStats.getAndSet(null);
InternalInstrumented<SocketStats> stats = listenSocketStats;
listenSocketStats = null;
if (stats != null) {
channelz.removeListenSocket(stats);
}
Expand Down
46 changes: 32 additions & 14 deletions netty/src/test/java/io/grpc/netty/NettyServerTest.java
Expand Up @@ -29,16 +29,19 @@
import io.grpc.InternalInstrumented;
import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;
import io.grpc.internal.FixedObjectPool;
import io.grpc.internal.ServerListener;
import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerTransport;
import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.SharedResourcePool;
import io.grpc.internal.TransportTracer;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelOption;
import io.netty.channel.ReflectiveChannelFactory;
import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.AsciiString;
import java.net.InetSocketAddress;
import java.net.Socket;
Expand All @@ -48,13 +51,21 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.After;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class NettyServerTest {
private final InternalChannelz channelz = new InternalChannelz();
private final NioEventLoopGroup eventLoop = new NioEventLoopGroup(1);

/**
 * Shuts down the {@code eventLoop} group after each test and waits for it to terminate.
 */
@After
public void tearDown() throws Exception {
// Zero quiet period and zero timeout: request shutdown without lingering for new tasks.
eventLoop.shutdownGracefully(0, 0, TimeUnit.SECONDS);
// Wait up to 5 seconds for termination; the boolean result is not checked here.
eventLoop.awaitTermination(5, TimeUnit.SECONDS);
}

@Test
public void startStop() throws Exception {
Expand All @@ -79,10 +90,10 @@ class TestProtocolNegotiator implements ProtocolNegotiator {
TestProtocolNegotiator protocolNegotiator = new TestProtocolNegotiator();
NettyServer ns = new NettyServer(
addr,
Utils.DEFAULT_SERVER_CHANNEL_FACTORY,
new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(),
SharedResourcePool.forResource(Utils.DEFAULT_BOSS_EVENT_LOOP_GROUP),
SharedResourcePool.forResource(Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP),
new FixedObjectPool<>(eventLoop),
new FixedObjectPool<>(eventLoop),
protocolNegotiator,
Collections.<ServerStreamTracer.Factory>emptyList(),
TransportTracer.getDefaultFactory(),
Expand Down Expand Up @@ -119,14 +130,14 @@ public void serverShutdown() {
}

@Test
public void getPort_notStarted() throws Exception {
public void getPort_notStarted() {
InetSocketAddress addr = new InetSocketAddress(0);
NettyServer ns = new NettyServer(
addr,
Utils.DEFAULT_SERVER_CHANNEL_FACTORY,
new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(),
SharedResourcePool.forResource(Utils.DEFAULT_BOSS_EVENT_LOOP_GROUP),
SharedResourcePool.forResource(Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP),
new FixedObjectPool<>(eventLoop),
new FixedObjectPool<>(eventLoop),
ProtocolNegotiators.plaintext(),
Collections.<ServerStreamTracer.Factory>emptyList(),
TransportTracer.getDefaultFactory(),
Expand Down Expand Up @@ -161,10 +172,10 @@ public void childChannelOptions() throws Exception {
InetSocketAddress addr = new InetSocketAddress(0);
NettyServer ns = new NettyServer(
addr,
Utils.DEFAULT_SERVER_CHANNEL_FACTORY,
new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
channelOptions,
SharedResourcePool.forResource(Utils.DEFAULT_BOSS_EVENT_LOOP_GROUP),
SharedResourcePool.forResource(Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP),
new FixedObjectPool<>(eventLoop),
new FixedObjectPool<>(eventLoop),
ProtocolNegotiators.plaintext(),
Collections.<ServerStreamTracer.Factory>emptyList(),
TransportTracer.getDefaultFactory(),
Expand Down Expand Up @@ -211,10 +222,10 @@ public void channelzListenSocket() throws Exception {
InetSocketAddress addr = new InetSocketAddress(0);
NettyServer ns = new NettyServer(
addr,
Utils.DEFAULT_SERVER_CHANNEL_FACTORY,
new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(),
SharedResourcePool.forResource(Utils.DEFAULT_BOSS_EVENT_LOOP_GROUP),
SharedResourcePool.forResource(Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP),
new FixedObjectPool<>(eventLoop),
new FixedObjectPool<>(eventLoop),
ProtocolNegotiators.plaintext(),
Collections.<ServerStreamTracer.Factory>emptyList(),
TransportTracer.getDefaultFactory(),
Expand All @@ -239,8 +250,15 @@ public void serverShutdown() {
shutdownCompleted.set(null);
}
});

assertThat(((InetSocketAddress) ns.getListenSocketAddress()).getPort()).isGreaterThan(0);

// SocketStats won't be available until the event loop task scheduled by ns.start() to add the
// SocketStats completes. So submit a no-op task and await it, which drains the earlier task.
eventLoop.submit(new Runnable() {
@Override
public void run() {}
}).await(5, TimeUnit.SECONDS);
InternalInstrumented<SocketStats> listenSocket = ns.getListenSocketStats();
assertSame(listenSocket, channelz.getSocket(id(listenSocket)));

Expand Down

0 comments on commit a3af83b

Please sign in to comment.