Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

netty: fix a race for channelz at server transport creation #6610

Merged
merged 4 commits into from Jan 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 8 additions & 15 deletions netty/src/main/java/io/grpc/netty/NettyServer.java
Expand Up @@ -56,7 +56,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;

Expand Down Expand Up @@ -93,9 +92,8 @@ class NettyServer implements InternalServer, InternalWithLogId {
private final List<? extends ServerStreamTracer.Factory> streamTracerFactories;
private final TransportTracer.Factory transportTracerFactory;
private final InternalChannelz channelz;
// Only modified in event loop but safe to read any time. Set at startup and unset at shutdown.
private final AtomicReference<InternalInstrumented<SocketStats>> listenSocketStats =
new AtomicReference<>();
// Only modified in event loop but safe to read any time.
private volatile InternalInstrumented<SocketStats> listenSocketStats;

NettyServer(
SocketAddress address, ChannelFactory<? extends ServerChannel> channelFactory,
Expand Down Expand Up @@ -149,7 +147,7 @@ public SocketAddress getListenSocketAddress() {

@Override
public InternalInstrumented<SocketStats> getListenSocketStats() {
return listenSocketStats.get();
return listenSocketStats;
}

@Override
Expand Down Expand Up @@ -251,19 +249,13 @@ public void operationComplete(ChannelFuture future) throws Exception {
throw new IOException("Failed to bind", future.cause());
}
channel = future.channel();
Future<?> channelzFuture = channel.eventLoop().submit(new Runnable() {
channel.eventLoop().execute(new Runnable() {
@Override
public void run() {
InternalInstrumented<SocketStats> listenSocket = new ListenSocket(channel);
listenSocketStats.set(listenSocket);
channelz.addListenSocket(listenSocket);
listenSocketStats = new ListenSocket(channel);
channelz.addListenSocket(listenSocketStats);
}
});
try {
channelzFuture.await();
} catch (InterruptedException ex) {
throw new RuntimeException("Interrupted while registering listen socket to channelz", ex);
}
}

@Override
Expand All @@ -278,7 +270,8 @@ public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
log.log(Level.WARNING, "Error shutting down server", future.cause());
}
InternalInstrumented<SocketStats> stats = listenSocketStats.getAndSet(null);
InternalInstrumented<SocketStats> stats = listenSocketStats;
listenSocketStats = null;
if (stats != null) {
channelz.removeListenSocket(stats);
}
Expand Down
46 changes: 32 additions & 14 deletions netty/src/test/java/io/grpc/netty/NettyServerTest.java
Expand Up @@ -29,16 +29,19 @@
import io.grpc.InternalInstrumented;
import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;
import io.grpc.internal.FixedObjectPool;
import io.grpc.internal.ServerListener;
import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerTransport;
import io.grpc.internal.ServerTransportListener;
import io.grpc.internal.SharedResourcePool;
import io.grpc.internal.TransportTracer;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelOption;
import io.netty.channel.ReflectiveChannelFactory;
import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.AsciiString;
import java.net.InetSocketAddress;
import java.net.Socket;
Expand All @@ -48,13 +51,21 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.After;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class NettyServerTest {
private final InternalChannelz channelz = new InternalChannelz();
private final NioEventLoopGroup eventLoop = new NioEventLoopGroup(1);

@After
public void tearDown() throws Exception {
eventLoop.shutdownGracefully(0, 0, TimeUnit.SECONDS);
eventLoop.awaitTermination(5, TimeUnit.SECONDS);
}

@Test
public void startStop() throws Exception {
Expand All @@ -79,10 +90,10 @@ class TestProtocolNegotiator implements ProtocolNegotiator {
TestProtocolNegotiator protocolNegotiator = new TestProtocolNegotiator();
NettyServer ns = new NettyServer(
addr,
Utils.DEFAULT_SERVER_CHANNEL_FACTORY,
new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(),
SharedResourcePool.forResource(Utils.DEFAULT_BOSS_EVENT_LOOP_GROUP),
SharedResourcePool.forResource(Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP),
new FixedObjectPool<>(eventLoop),
new FixedObjectPool<>(eventLoop),
protocolNegotiator,
Collections.<ServerStreamTracer.Factory>emptyList(),
TransportTracer.getDefaultFactory(),
Expand Down Expand Up @@ -119,14 +130,14 @@ public void serverShutdown() {
}

@Test
public void getPort_notStarted() throws Exception {
public void getPort_notStarted() {
InetSocketAddress addr = new InetSocketAddress(0);
NettyServer ns = new NettyServer(
addr,
Utils.DEFAULT_SERVER_CHANNEL_FACTORY,
new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(),
SharedResourcePool.forResource(Utils.DEFAULT_BOSS_EVENT_LOOP_GROUP),
SharedResourcePool.forResource(Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP),
new FixedObjectPool<>(eventLoop),
new FixedObjectPool<>(eventLoop),
ProtocolNegotiators.plaintext(),
Collections.<ServerStreamTracer.Factory>emptyList(),
TransportTracer.getDefaultFactory(),
Expand Down Expand Up @@ -161,10 +172,10 @@ public void childChannelOptions() throws Exception {
InetSocketAddress addr = new InetSocketAddress(0);
NettyServer ns = new NettyServer(
addr,
Utils.DEFAULT_SERVER_CHANNEL_FACTORY,
new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
channelOptions,
SharedResourcePool.forResource(Utils.DEFAULT_BOSS_EVENT_LOOP_GROUP),
SharedResourcePool.forResource(Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP),
new FixedObjectPool<>(eventLoop),
new FixedObjectPool<>(eventLoop),
ProtocolNegotiators.plaintext(),
Collections.<ServerStreamTracer.Factory>emptyList(),
TransportTracer.getDefaultFactory(),
Expand Down Expand Up @@ -211,10 +222,10 @@ public void channelzListenSocket() throws Exception {
InetSocketAddress addr = new InetSocketAddress(0);
NettyServer ns = new NettyServer(
addr,
Utils.DEFAULT_SERVER_CHANNEL_FACTORY,
new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
new HashMap<ChannelOption<?>, Object>(),
SharedResourcePool.forResource(Utils.DEFAULT_BOSS_EVENT_LOOP_GROUP),
SharedResourcePool.forResource(Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP),
new FixedObjectPool<>(eventLoop),
new FixedObjectPool<>(eventLoop),
ProtocolNegotiators.plaintext(),
Collections.<ServerStreamTracer.Factory>emptyList(),
TransportTracer.getDefaultFactory(),
Expand All @@ -239,8 +250,15 @@ public void serverShutdown() {
shutdownCompleted.set(null);
}
});

assertThat(((InetSocketAddress) ns.getListenSocketAddress()).getPort()).isGreaterThan(0);

// SocketStats won't be available until the event loop task of adding SocketStats created by
// ns.start() complete. So submit a noop task and await until it's drained.
eventLoop.submit(new Runnable() {
@Override
public void run() {}
}).await(5, TimeUnit.SECONDS);
InternalInstrumented<SocketStats> listenSocket = ns.getListenSocketStats();
assertSame(listenSocket, channelz.getSocket(id(listenSocket)));

Expand Down