From 115257b6f39c69fee7faf2ba66f414ab4552d8ee Mon Sep 17 00:00:00 2001 From: Pierre De Rop Date: Wed, 3 Aug 2022 13:44:01 +0200 Subject: [PATCH] Revise HttpMetricsHandlerTests and fix testServerConnectionsRecorder (#2391) Fixed race condition in HttpMetricsHandlerTests.testServerConnectionsRecorder, and refactored the HttpMetricsHandlerTests tests by suppressing observeDisconnect usage. Fixes #2368 --- .../netty/http/HttpMetricsHandlerTests.java | 417 ++++++++++++------ 1 file changed, 287 insertions(+), 130 deletions(-) diff --git a/reactor-netty-http/src/test/java/reactor/netty/http/HttpMetricsHandlerTests.java b/reactor-netty-http/src/test/java/reactor/netty/http/HttpMetricsHandlerTests.java index 6c68c8a17b..9f3eed711b 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/http/HttpMetricsHandlerTests.java +++ b/reactor-netty-http/src/test/java/reactor/netty/http/HttpMetricsHandlerTests.java @@ -23,10 +23,19 @@ import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http2.HttpConversionUtil; import io.netty.handler.ssl.SslProvider; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; @@ -39,7 +48,7 @@ import reactor.core.publisher.Mono; import reactor.netty.BaseHttpTest; import reactor.netty.ByteBufFlux; -import reactor.netty.ConnectionObserver; +import reactor.netty.NettyPipeline; import reactor.netty.http.client.ContextAwareHttpClientMetricsRecorder; import reactor.netty.http.client.HttpClient; import reactor.netty.http.server.ContextAwareHttpServerMetricsRecorder; @@ -65,6 +74,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; @@ -107,6 +117,7 @@ class HttpMetricsHandlerTests extends BaseHttpTest { @BeforeAll static void createSelfSignedCertificate() throws CertificateException { + Assertions.setMaxStackTraceElementsDisplayed(100); ssc = new SelfSignedCertificate(); serverCtx11 = Http11SslContextSpec.forServer(ssc.certificate(), ssc.privateKey()) .configure(builder -> builder.sslProvider(SslProvider.JDK)); @@ -169,18 +180,19 @@ void tearDown() { @MethodSource("httpCompatibleProtocols") void testExistingEndpoint(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols, @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx) throws Exception { - CountDownLatch latch1 = new CountDownLatch(4); // expect to observe 2 server disconnect + 2 client disconnect events - AtomicReference latchRef = new AtomicReference<>(latch1); - ConnectionObserver observerDisconnect = observeDisconnect(latchRef); - + CountDownLatch responseSent = new CountDownLatch(1); // response fully sent by the server + AtomicReference responseSentRef = new AtomicReference<>(responseSent); + ResponseSentHandler responseSentHandler = ResponseSentHandler.INSTANCE; disposableServer = customizeServerOptions(httpServer, serverCtx, serverProtocols) - .childObserve(observerDisconnect) + .doOnConnection(cnx -> responseSentHandler.register(responseSentRef, cnx.channel().pipeline())) .bindNow(); AtomicReference serverAddress = new AtomicReference<>(); - httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols).doAfterRequest((req, conn) -> - serverAddress.set(conn.channel().remoteAddress()) - ).observe(observerDisconnect); + CountDownLatch clientCompleted = new CountDownLatch(1); // client received full response + AtomicReference clientCompletedRef = new AtomicReference<>(clientCompleted); + httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols) + .doAfterRequest((req, conn) -> serverAddress.set(conn.channel().remoteAddress())) + .doAfterResponseSuccess((resp, conn) -> clientCompletedRef.get().countDown()); StepVerifier.create(httpClient.post() .uri("/1") @@ -192,7 +204,8 @@ void testExistingEndpoint(HttpProtocol[] serverProtocols, HttpProtocol[] clientP .expectComplete() .verify(Duration.ofSeconds(30)); - assertThat(latch1.await(30, TimeUnit.SECONDS)).as("latch await").isTrue(); + assertThat(responseSentRef.get().await(30, TimeUnit.SECONDS)).as("responseSentRef latch await").isTrue(); + assertThat(clientCompletedRef.get().await(30, TimeUnit.SECONDS)).as("clientCompletedRef latch await").isTrue(); InetSocketAddress sa = (InetSocketAddress) serverAddress.get(); @@ -213,8 +226,8 @@ else if (clientProtocols.length == 2 && checkExpectationsExisting("/1", sa.getHostString() + ":" + sa.getPort(), 1, serverCtx != null, numWrites[0], bytesWrite[0]); - CountDownLatch latch2 = new CountDownLatch(4); // expect to observe 2 server disconnect + 2 client disconnect events - latchRef.set(latch2); + responseSentRef.set(new CountDownLatch(1)); + clientCompletedRef.set(new CountDownLatch(1)); StepVerifier.create(httpClient.post() .uri("/2?i=1&j=2") @@ -226,7 +239,8 @@ else if (clientProtocols.length == 2 && .expectComplete() .verify(Duration.ofSeconds(30)); - assertThat(latch2.await(30, TimeUnit.SECONDS)).as("latch await").isTrue(); + assertThat(responseSentRef.get().await(30, TimeUnit.SECONDS)).as("responseSentRef latch await").isTrue(); + assertThat(clientCompletedRef.get().await(30, TimeUnit.SECONDS)).as("clientCompletedRef latch await").isTrue(); sa = (InetSocketAddress) serverAddress.get(); @@ -285,23 +299,28 @@ void testRecordingFailsClientSide(HttpProtocol[] serverProtocols, HttpProtocol[] @MethodSource("httpCompatibleProtocols") void testNonExistingEndpoint(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols, @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx) throws Exception { - // For HTTP11, we expect to observe 2 DISCONNECTS for client, and 2 DISCONNECT for server. - // Else, we expect to observe 2 DISCONNECTS for client, and 1 DISCONNECT for server. - boolean isHTTP11 = clientProtocols.length == 1 && clientProtocols[0] == HttpProtocol.HTTP11; - int expectedDisconnects = isHTTP11 ? 4 : 3; - - CountDownLatch latch = new CountDownLatch(expectedDisconnects); - AtomicReference latchRef = new AtomicReference<>(latch); - ConnectionObserver observerDisconnect = observeDisconnect(latchRef); - + CountDownLatch responseSent = new CountDownLatch(1); // response fully sent by the server + AtomicReference responseSentRef = new AtomicReference<>(responseSent); + ResponseSentHandler responseSentHandler = ResponseSentHandler.INSTANCE; + CountDownLatch requestReceived = new CountDownLatch(1); // request fully received by the server + AtomicReference requestReceivedRef = new AtomicReference<>(requestReceived); + RequestReceivedHandler requestReceivedHandler = RequestReceivedHandler.INSTANCE; + + // the requestReceivedHandler is used to detect when the server has received the last client request content + // the responseSentHandler is used to detect when the server has sent the last response content disposableServer = customizeServerOptions(httpServer, serverCtx, serverProtocols) - .childObserve(observerDisconnect) + .doOnConnection(cnx -> { + responseSentHandler.register(responseSentRef, cnx.channel().pipeline()); + requestReceivedHandler.register(requestReceivedRef, cnx.channel().pipeline()); + }) .bindNow(); AtomicReference serverAddress = new AtomicReference<>(); - httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols).doAfterRequest((req, conn) -> - serverAddress.set(conn.channel().remoteAddress()) - ).observe(observerDisconnect); + CountDownLatch clientCompleted = new CountDownLatch(1); + AtomicReference clientCompletedRef = new AtomicReference<>(clientCompleted); + httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols) + .doAfterRequest((req, conn) -> serverAddress.set(conn.channel().remoteAddress())) + .doAfterResponseSuccess((rsp, conn) -> clientCompletedRef.get().countDown()); StepVerifier.create(httpClient .headers(h -> h.add("Connection", "close")) @@ -313,7 +332,9 @@ void testNonExistingEndpoint(HttpProtocol[] serverProtocols, HttpProtocol[] clie .expectComplete() .verify(Duration.ofSeconds(30)); - assertThat(latch.await(30, TimeUnit.SECONDS)).as("latch await").isTrue(); + assertThat(requestReceivedRef.get().await(30, TimeUnit.SECONDS)).as("requestReceivedRef latch await").isTrue(); + assertThat(responseSentRef.get().await(30, TimeUnit.SECONDS)).as("responseSentRef latch await").isTrue(); + assertThat(clientCompletedRef.get().await(30, TimeUnit.SECONDS)).as("clientCompletedRef latch await").isTrue(); InetSocketAddress sa = (InetSocketAddress) serverAddress.get(); @@ -343,8 +364,9 @@ else if (protocols.contains(HttpProtocol.H2) || protocols.contains(HttpProtocol. checkExpectationsNonExisting(sa.getHostString() + ":" + sa.getPort(), 1, 1, serverCtx != null, numWrites[0], numReads[0], bytesWrite[0], bytesRead[0]); - CountDownLatch latch2 = new CountDownLatch(expectedDisconnects); - latchRef.set(latch2); + requestReceivedRef.set(new CountDownLatch(1)); + responseSentRef.set(new CountDownLatch(1)); + clientCompletedRef.set(new CountDownLatch(1)); StepVerifier.create(httpClient .headers(h -> h.add("Connection", "close")) @@ -356,7 +378,9 @@ else if (protocols.contains(HttpProtocol.H2) || protocols.contains(HttpProtocol. .expectComplete() .verify(Duration.ofSeconds(30)); - assertThat(latch2.await(30, TimeUnit.SECONDS)).as("latch await").isTrue(); + assertThat(requestReceivedRef.get().await(30, TimeUnit.SECONDS)).as("requestReceivedRef latch await").isTrue(); + assertThat(responseSentRef.get().await(30, TimeUnit.SECONDS)).as("responseSentRef latch await").isTrue(); + assertThat(clientCompletedRef.get().await(30, TimeUnit.SECONDS)).as("clientCompletedRef latch await").isTrue(); sa = (InetSocketAddress) serverAddress.get(); checkExpectationsNonExisting(sa.getHostString() + ":" + sa.getPort(), connIndex, 2, serverCtx != null, @@ -367,19 +391,19 @@ else if (protocols.contains(HttpProtocol.H2) || protocols.contains(HttpProtocol. @MethodSource("httpCompatibleProtocols") void testUriTagValueFunction(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols, @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx) throws Exception { - CountDownLatch latch1 = new CountDownLatch(4); // expect to observe 2 server disconnect + 2 client disconnect events - AtomicReference latchRef = new AtomicReference<>(latch1); - ConnectionObserver observerDisconnect = observeDisconnect(latchRef); + CountDownLatch responseSent = new CountDownLatch(1); // response fully sent by the server + CountDownLatch clientCompleted = new CountDownLatch(1); // client received full response + ResponseSentHandler responseSentHandler = ResponseSentHandler.INSTANCE; disposableServer = customizeServerOptions(httpServer, serverCtx, serverProtocols) + .doOnConnection(cnx -> responseSentHandler.register(responseSent, cnx.channel().pipeline())) .metrics(true, s -> "testUriTagValueResolver") - .childObserve(observerDisconnect) .bindNow(); AtomicReference serverAddress = new AtomicReference<>(); - httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols).doAfterRequest((req, conn) -> - serverAddress.set(conn.channel().remoteAddress()) - ).observe(observerDisconnect); + httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols) + .doAfterResponseSuccess((res, conn) -> clientCompleted.countDown()) + .doAfterRequest((req, conn) -> serverAddress.set(conn.channel().remoteAddress())); StepVerifier.create(httpClient.metrics(true, s -> "testUriTagValueResolver") .post() @@ -392,7 +416,8 @@ void testUriTagValueFunction(HttpProtocol[] serverProtocols, HttpProtocol[] clie .expectComplete() .verify(Duration.ofSeconds(30)); - assertThat(latch1.await(30, TimeUnit.SECONDS)).as("latch await").isTrue(); + assertThat(responseSent.await(30, TimeUnit.SECONDS)).as("responseSent latch await").isTrue(); + assertThat(clientCompleted.await(30, TimeUnit.SECONDS)).as("clientCompleted latch await").isTrue(); InetSocketAddress sa = (InetSocketAddress) serverAddress.get(); @@ -418,27 +443,29 @@ else if (clientProtocols.length == 2 && @MethodSource("httpCompatibleProtocols") void testUriTagValueFunctionNotSharedForClient(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols, @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx) throws Exception { - CountDownLatch latch1 = new CountDownLatch(4); // expect to observe 2 server disconnect + 2 client disconnect events - AtomicReference latchRef = new AtomicReference<>(latch1); - ConnectionObserver observerDisconnect = observeDisconnect(latchRef); - + CountDownLatch responseSent = new CountDownLatch(1); // response fully sent by the server + AtomicReference responseSentRef = new AtomicReference<>(responseSent); + ResponseSentHandler responseSentHandler = ResponseSentHandler.INSTANCE; disposableServer = - customizeServerOptions(httpServer, serverCtx, serverProtocols).metrics(true, - s -> { - if ("/1".equals(s)) { - return "testUriTagValueFunctionNotShared_1"; - } - else { - return "testUriTagValueFunctionNotShared_2"; - } - }) - .childObserve(observerDisconnect) + customizeServerOptions(httpServer, serverCtx, serverProtocols) + .doOnConnection(cnx -> responseSentHandler.register(responseSentRef, cnx.channel().pipeline())) + .metrics(true, + s -> { + if ("/1".equals(s)) { + return "testUriTagValueFunctionNotShared_1"; + } + else { + return "testUriTagValueFunctionNotShared_2"; + } + }) .bindNow(); + CountDownLatch clientCompleted = new CountDownLatch(1); // client received full response + AtomicReference clientCompletedRef = new AtomicReference<>(clientCompleted); AtomicReference serverAddress = new AtomicReference<>(); - httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols).doAfterRequest((req, conn) -> - serverAddress.set(conn.channel().remoteAddress()) - ).observe(observerDisconnect); + httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols) + .doAfterRequest((req, conn) -> serverAddress.set(conn.channel().remoteAddress())) + .doAfterResponseSuccess((resp, conn) -> clientCompletedRef.get().countDown()); httpClient.metrics(true, s -> "testUriTagValueFunctionNotShared_1") .post() @@ -452,7 +479,8 @@ void testUriTagValueFunctionNotSharedForClient(HttpProtocol[] serverProtocols, H .expectComplete() .verify(Duration.ofSeconds(30)); - assertThat(latch1.await(30, TimeUnit.SECONDS)).as("latch await").isTrue(); + assertThat(responseSentRef.get().await(30, TimeUnit.SECONDS)).as("responseSentRef latch await").isTrue(); + assertThat(clientCompletedRef.get().await(30, TimeUnit.SECONDS)).as("clientCompletedRef latch await").isTrue(); InetSocketAddress sa = (InetSocketAddress) serverAddress.get(); @@ -470,8 +498,8 @@ else if (clientProtocols.length == 2 && checkExpectationsExisting("testUriTagValueFunctionNotShared_1", sa.getHostString() + ":" + sa.getPort(), 1, serverCtx != null, numWrites[0], bytesWrite[0]); - CountDownLatch latch2 = new CountDownLatch(4); // expect to observe 2 server disconnect + 2 client disconnect events - latchRef.set(latch2); + responseSentRef.set(new CountDownLatch(1)); + clientCompletedRef.set(new CountDownLatch(1)); httpClient.metrics(true, s -> "testUriTagValueFunctionNotShared_2") .post() @@ -485,7 +513,8 @@ else if (clientProtocols.length == 2 && .expectComplete() .verify(Duration.ofSeconds(30)); - assertThat(latch2.await(30, TimeUnit.SECONDS)).as("latch await").isTrue(); + assertThat(responseSentRef.get().await(30, TimeUnit.SECONDS)).as("responseSentRef latch await").isTrue(); + assertThat(clientCompletedRef.get().await(30, TimeUnit.SECONDS)).as("clientCompletedRef await").isTrue(); sa = (InetSocketAddress) serverAddress.get(); @@ -496,16 +525,13 @@ else if (clientProtocols.length == 2 && @ParameterizedTest @MethodSource("httpCompatibleProtocols") void testContextAwareRecorderOnClient(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols, - @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx) throws Exception { + @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx) { disposableServer = customizeServerOptions(httpServer, serverCtx, serverProtocols).bindNow(); ClientContextAwareRecorder recorder = ClientContextAwareRecorder.INSTANCE; - CountDownLatch latch = new CountDownLatch(1); httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols); - httpClient.doOnResponse((res, conn) -> conn.channel() - .closeFuture() - .addListener(f -> latch.countDown())) - .metrics(true, () -> recorder) + + httpClient.metrics(true, () -> recorder) .post() .uri("/1") .send(body) @@ -518,8 +544,6 @@ void testContextAwareRecorderOnClient(HttpProtocol[] serverProtocols, HttpProtoc .expectComplete() .verify(Duration.ofSeconds(30)); - assertThat(latch.await(30, TimeUnit.SECONDS)).as("latch await").isTrue(); - assertThat(recorder.onDataReceivedContextView).isTrue(); assertThat(recorder.onDataSentContextView).isTrue(); } @@ -528,18 +552,18 @@ void testContextAwareRecorderOnClient(HttpProtocol[] serverProtocols, HttpProtoc @MethodSource("httpCompatibleProtocols") void testContextAwareRecorderOnServer(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols, @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx) throws Exception { + CountDownLatch responseSent = new CountDownLatch(1); // response fully sent by the server ServerContextAwareRecorder recorder = ServerContextAwareRecorder.INSTANCE; + ResponseSentHandler responseSentHandler = ResponseSentHandler.INSTANCE; disposableServer = customizeServerOptions(httpServer, serverCtx, serverProtocols).metrics(true, () -> recorder) - .mapHandle((mono, conn) -> mono.contextWrite(Context.of("testContextAwareRecorder", "OK"))) - .bindNow(); + .doOnConnection(cnx -> responseSentHandler.register(responseSent, cnx.channel().pipeline())) + .mapHandle((mono, conn) -> mono.contextWrite(Context.of("testContextAwareRecorder", "OK"))) + .bindNow(); - CountDownLatch latch = new CountDownLatch(1); httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols); - httpClient.doOnResponse((res, conn) -> conn.channel() - .closeFuture() - .addListener(f -> latch.countDown())) - .post() + + httpClient.post() .uri("/1") .send(body) .responseContent() @@ -550,7 +574,7 @@ void testContextAwareRecorderOnServer(HttpProtocol[] serverProtocols, HttpProtoc .expectComplete() .verify(Duration.ofSeconds(30)); - assertThat(latch.await(30, TimeUnit.SECONDS)).as("latch await").isTrue(); + assertThat(responseSent.await(30, TimeUnit.SECONDS)).as("responseSent latch await").isTrue(); assertThat(recorder.onDataReceivedContextView).isTrue(); assertThat(recorder.onDataSentContextView).isTrue(); @@ -560,20 +584,22 @@ void testContextAwareRecorderOnServer(HttpProtocol[] serverProtocols, HttpProtoc @MethodSource("httpCompatibleProtocols") void testServerConnectionsMicrometer(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols, @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx) throws Exception { - CountDownLatch latch = new CountDownLatch(4); // expect to observe 2 server disconnect + 2 client disconnect events - AtomicReference latchRef = new AtomicReference<>(latch); - ConnectionObserver observerDisconnect = observeDisconnect(latchRef); - + CountDownLatch responseSent = new CountDownLatch(1); // response fully sent by the server + CountDownLatch serverClosed = new CountDownLatch(1); // socket closed on the server side + ResponseSentHandler responseSentHandler = ResponseSentHandler.INSTANCE; + ServerCloseHandler serverCloseHandler = ServerCloseHandler.INSTANCE; boolean isHttp11 = clientProtocols.length == 1 && clientProtocols[0] == HttpProtocol.HTTP11; - disposableServer = customizeServerOptions(httpServer, serverCtx, serverProtocols) + HttpServer server = customizeServerOptions(httpServer, serverCtx, serverProtocols) .metrics(true, Function.identity()) - .childObserve(observerDisconnect) - .bindNow(); + .doOnConnection(cnx -> responseSentHandler.register(responseSent, cnx.channel().pipeline())); + + server = isHttp11 ? + server.doOnChannelInit((cnxObs, ch, sockAddr) -> serverCloseHandler.register(ch, serverClosed, isHttp11)) : server; + disposableServer = server.bindNow(); AtomicReference clientAddress = new AtomicReference<>(); - httpClient = httpClient.doAfterRequest((req, conn) -> - clientAddress.set(conn.channel().localAddress()) - ).observe(observerDisconnect); + httpClient = httpClient + .doAfterRequest((req, conn) -> clientAddress.set(conn.channel().localAddress())); String uri = "/4"; String address = formatSocketAddress(disposableServer.address()); @@ -592,10 +618,11 @@ void testServerConnectionsMicrometer(HttpProtocol[] serverProtocols, HttpProtoco .expectComplete() .verify(Duration.ofSeconds(30)); - assertThat(latch.await(30, TimeUnit.SECONDS)).as("latch await").isTrue(); + assertThat(responseSent.await(30, TimeUnit.SECONDS)).as("responseSent latch await").isTrue(); // now check the server counters if (isHttp11) { + assertThat(serverClosed.await(30, TimeUnit.SECONDS)).as("serverClosed latch await").isTrue(); checkGauge(SERVER_CONNECTIONS_TOTAL, true, 0, URI, HTTP, LOCAL_ADDRESS, address); checkGauge(SERVER_CONNECTIONS_ACTIVE, true, 0, URI, HTTP, LOCAL_ADDRESS, address); } @@ -621,22 +648,18 @@ void testServerConnectionsRecorder(HttpProtocol[] serverProtocols, HttpProtocol[ // ServerRecorder.INSTANCE.reset() (AfterEach) and thus leave ServerRecorder.INSTANCE in a bad state ServerRecorder.INSTANCE.reset(); boolean isHttp11 = clientProtocols.length == 1 && clientProtocols[0] == HttpProtocol.HTTP11; + CountDownLatch serverClosed = new CountDownLatch(1); + ServerCloseHandler serverCloseHandler = ServerCloseHandler.INSTANCE; + disposableServer = customizeServerOptions(httpServer, serverCtx, serverProtocols) - .metrics(true, () -> { - ServerRecorder.INSTANCE.done = isHttp11 ? new CountDownLatch(4) : new CountDownLatch(1); - return ServerRecorder.INSTANCE; - }, - Function.identity()) + .doOnConnection(c -> serverCloseHandler.register(c.channel(), serverClosed, isHttp11)) + .metrics(true, ServerRecorder.supplier(), Function.identity()) .bindNow(); String address = formatSocketAddress(disposableServer.address()); - CountDownLatch latch = new CountDownLatch(1); - httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols); - httpClient.doOnResponse((res, conn) -> - conn.channel() - .closeFuture() - .addListener(f -> latch.countDown())) + + httpClient .metrics(true, Function.identity()) .post() .uri("/5") @@ -649,8 +672,15 @@ void testServerConnectionsRecorder(HttpProtocol[] serverProtocols, HttpProtocol[ .expectComplete() .verify(Duration.ofSeconds(30)); - assertThat(latch.await(30, TimeUnit.SECONDS)).as("latch await").isTrue(); - assertThat(ServerRecorder.INSTANCE.done.await(30, TimeUnit.SECONDS)).as("recorder latch await").isTrue(); + // dispose the client connection provider now, before asserting test expectations. + provider.disposeLater() + .block(Duration.ofSeconds(30)); + + // now the socket is closed, wait for the ServerRecorder to be called in recordServerConnectionClosed before asserting test expectations + assertThat(serverClosed.await(30, TimeUnit.SECONDS)).as("serverClosed latch await").isTrue(); + + // now we can assert test expectations + assertThat(ServerRecorder.INSTANCE.error.get()).isNull(); if (isHttp11) { assertThat(ServerRecorder.INSTANCE.onServerConnectionsAmount.get()).isEqualTo(0); assertThat(ServerRecorder.INSTANCE.onActiveConnectionsAmount.get()).isEqualTo(0); @@ -658,7 +688,7 @@ void testServerConnectionsRecorder(HttpProtocol[] serverProtocols, HttpProtocol[ assertThat(ServerRecorder.INSTANCE.onInactiveConnectionsLocalAddr.get()).isEqualTo(address); } else { - assertThat(ServerRecorder.INSTANCE.onServerConnectionsAmount.get()).isEqualTo(1); + assertThat(ServerRecorder.INSTANCE.onServerConnectionsAmount.get()).isEqualTo(0); } disposableServer.disposeNow(); @@ -669,12 +699,10 @@ void testIssue896() throws Exception { disposableServer = httpServer.noSSL() .bindNow(); - // the client will observe three DISCONNECT: one when a NotSSLRecordException is caught, - // one when DecoderException is caught, and one when the connection becomes inactive - CountDownLatch latch = new CountDownLatch(3); - AtomicReference latchRef = new AtomicReference<>(latch); + // The client should get two errors: NotSSLRecordException, and DecoderException. + CountDownLatch latch = new CountDownLatch(2); httpClient - .observe(observeDisconnect(latchRef)) + .doOnChannelInit((o, c, address) -> ClientExceptionHandler.INSTANCE.register(c, latch)) .secure(spec -> spec.sslContext(clientCtx11)) .post() .uri("/1") @@ -695,19 +723,20 @@ void testIssue896() throws Exception { @MethodSource("http11CompatibleProtocols") void testBadRequest(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols, @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx) throws Exception { - CountDownLatch latch1 = new CountDownLatch(4); // expect to observe 2 server disconnect + 2 client disconnect events - AtomicReference latchRef = new AtomicReference<>(latch1); - ConnectionObserver observerDisconnect = observeDisconnect(latchRef); + CountDownLatch serverClosed = new CountDownLatch(1); + CountDownLatch clientCompleted = new CountDownLatch(1); + boolean isHttp11 = clientProtocols.length == 1 && clientProtocols[0] == HttpProtocol.HTTP11; + ServerCloseHandler serverCloseHandler = ServerCloseHandler.INSTANCE; disposableServer = customizeServerOptions(httpServer, serverCtx, serverProtocols) + .doOnChannelInit((obs, c, s) -> serverCloseHandler.register(c, serverClosed, isHttp11)) .httpRequestDecoder(spec -> spec.maxHeaderSize(32)) - .childObserve(observerDisconnect) .bindNow(); AtomicReference serverAddress = new AtomicReference<>(); - httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols).doAfterRequest((req, conn) -> - serverAddress.set(conn.channel().remoteAddress()) - ).observe(observerDisconnect); + httpClient = customizeClientOptions(httpClient, clientCtx, clientProtocols) + .doAfterRequest((req, conn) -> serverAddress.set(conn.channel().remoteAddress())) + .doAfterResponseSuccess((resp, conn) -> clientCompleted.countDown()); httpClient.get() .uri("/max_header_size") @@ -717,21 +746,20 @@ void testBadRequest(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtoco .expectComplete() .verify(Duration.ofSeconds(30)); - assertThat(latch1.await(30, TimeUnit.SECONDS)).as("latch await").isTrue(); + // dispose the client connection provider now, before asserting test expectations. + provider.disposeLater() + .block(Duration.ofSeconds(30)); + // now the socket is closed, wait for the ServerRecorder to be called in recordServerConnectionClosed before asserting test expectations + assertThat(serverClosed.await(30, TimeUnit.SECONDS)).as("serverClosed latch await").isTrue(); + + // Ensure client has fully received the response before asserting test expectations + assertThat(clientCompleted.await(30, TimeUnit.SECONDS)).as("clientCompleted latch await").isTrue(); InetSocketAddress sa = (InetSocketAddress) serverAddress.get(); checkExpectationsBadRequest(sa.getHostString() + ":" + sa.getPort(), serverCtx != null); } - private ConnectionObserver observeDisconnect(AtomicReference latchRef) { - return (connection, state) -> { - if (state == ConnectionObserver.State.DISCONNECTING) { - latchRef.get().countDown(); - } - }; - } - private void checkServerConnectionsMicrometer(HttpServerRequest request) { String address = formatSocketAddress(request.hostAddress()); boolean isHttp2 = request.requestHeaders().contains(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text()); @@ -745,15 +773,20 @@ private void checkServerConnectionsMicrometer(HttpServerRequest request) { } private void checkServerConnectionsRecorder(HttpServerRequest request) { - String address = formatSocketAddress(request.hostAddress()); - boolean isHttp2 = request.requestHeaders().contains(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text()); - assertThat(ServerRecorder.INSTANCE.onServerConnectionsAmount.get()).isEqualTo(1); - assertThat(ServerRecorder.INSTANCE.onServerConnectionsLocalAddr.get()).isEqualTo(address); - if (!isHttp2) { - assertThat(ServerRecorder.INSTANCE.onActiveConnectionsAmount.get()).isEqualTo(1); - assertThat(ServerRecorder.INSTANCE.onActiveConnectionsLocalAddr.get()).isEqualTo(address); + try { + String address = formatSocketAddress(request.hostAddress()); + boolean isHttp2 = request.requestHeaders().contains(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text()); + assertThat(ServerRecorder.INSTANCE.onServerConnectionsAmount.get()).isEqualTo(1); + assertThat(ServerRecorder.INSTANCE.onServerConnectionsLocalAddr.get()).isEqualTo(address); + if (!isHttp2) { + assertThat(ServerRecorder.INSTANCE.onActiveConnectionsAmount.get()).isEqualTo(1); + assertThat(ServerRecorder.INSTANCE.onActiveConnectionsLocalAddr.get()).isEqualTo(address); + } + assertThat(ServerRecorder.INSTANCE.onInactiveConnectionsLocalAddr.get()).isNull(); + } + catch (Throwable error) { + ServerRecorder.INSTANCE.error.set(error); } - assertThat(ServerRecorder.INSTANCE.onInactiveConnectionsLocalAddr.get()).isNull(); } private void checkExpectationsExisting(String uri, String serverAddress, int connIndex, boolean checkTls, @@ -826,7 +859,6 @@ private void checkExpectationsNonExisting(String serverAddress, int connIndex, i checkCounter(CLIENT_ERRORS, summaryTags2, false, 0); } - private void checkExpectationsBadRequest(String serverAddress, boolean checkTls) { String uri = "/max_header_size"; String[] timerTags1 = new String[] {URI, uri, METHOD, "GET", STATUS, "413"}; @@ -1082,6 +1114,8 @@ public void recordResolveAddressTime(SocketAddress socketAddress, Duration durat static final class ServerRecorder implements HttpServerMetricsRecorder { static final ServerRecorder INSTANCE = new ServerRecorder(); + static final Supplier SUPPLIER = () -> INSTANCE; + private final AtomicReference error = new AtomicReference<>(); private final AtomicInteger onServerConnectionsAmount = new AtomicInteger(); private final AtomicReference onServerConnectionsLocalAddr = new AtomicReference<>(); private final AtomicReference onActiveConnectionsLocalAddr = new AtomicReference<>(); @@ -1089,7 +1123,12 @@ static final class ServerRecorder implements HttpServerMetricsRecorder { private final AtomicInteger onActiveConnectionsAmount = new AtomicInteger(); private volatile CountDownLatch done = new CountDownLatch(4); + static Supplier supplier() { + return SUPPLIER; + } + void reset() { + error.set(null); onServerConnectionsAmount.set(0); onServerConnectionsLocalAddr.set(null); onActiveConnectionsLocalAddr.set(null); @@ -1174,4 +1213,122 @@ public void recordConnectTime(SocketAddress socketAddress, Duration duration, St public void recordResolveAddressTime(SocketAddress socketAddress, Duration duration, String s) { } } + + /** + * Server Handler used to detect when the last http response content has been sent to the client. + * Handler placed before the HttpMetricsHandler on the Server pipeline. + * Metrics are up-to-date when the latch is counted down. + */ + static final class ResponseSentHandler extends ChannelOutboundHandlerAdapter { + final static String HANDLER_NAME = "ServerCompletedHandler.handler"; + final static ResponseSentHandler INSTANCE = new ResponseSentHandler(); + AtomicReference latchRef; + + void register(AtomicReference latchRef, ChannelPipeline pipeline) { + this.latchRef = latchRef; + pipeline.addBefore(NettyPipeline.HttpMetricsHandler, HANDLER_NAME, this); + } + + void register(CountDownLatch latch, ChannelPipeline pipeline) { + register(new AtomicReference<>(latch), pipeline); + } + + @Override + @SuppressWarnings("FutureReturnValueIgnored") + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (msg instanceof LastHttpContent) { + promise.addListener(future -> latchRef.get().countDown()); + } + + ctx.write(msg, promise); + } + + @Override + public boolean isSharable() { + return true; // A server may accept multiple connections, hence this handler must be sharable + } + } + + /** + * Server Handler used to detect when the last http client request content has been received by the server. + * Handler placed after the HttpMetricsHandler on the Server pipeline. + * Metrics are up-to-date when the latch is counted down. + */ + static final class RequestReceivedHandler extends ChannelInboundHandlerAdapter { + final static RequestReceivedHandler INSTANCE = new RequestReceivedHandler(); + final static String HANDLER_NAME = "ServerReceivedHandler.handler"; + AtomicReference latchRef; + + void register(AtomicReference latchRef, ChannelPipeline pipeline) { + this.latchRef = latchRef; + pipeline.addAfter(NettyPipeline.HttpMetricsHandler, HANDLER_NAME, this); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof LastHttpContent) { + latchRef.get().countDown(); + } + ctx.fireChannelRead(msg); + } + + @Override + public boolean isSharable() { + return true; + } + } + + /** + * Server handler used to wait until the client socket is closed on the server side. + * For HTTP1.1, the handler is placed before the ReactorBridge, so all previous handlers will see + * the close before this handler. For HTTP2, the handler is placed lastly on the pipeline. + */ + static final class ServerCloseHandler extends ChannelInboundHandlerAdapter { + static final ServerCloseHandler INSTANCE = new ServerCloseHandler(); + static final String HANDLER_NAME = "ServerCloseHandler.handler"; + private CountDownLatch latch; + + void register(Channel channel, CountDownLatch latch, boolean http11) { + this.latch = latch; + + if (http11) { + channel.pipeline().addBefore(NettyPipeline.ReactiveBridge, HANDLER_NAME, this); + } + else { + channel.parent().pipeline().addLast(HANDLER_NAME, this); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + latch.countDown(); + ctx.fireChannelInactive(); + } + + @Override + public boolean isSharable() { + return true; + } + } + + /** + * Handler used to get notified when an exception occurs on the HttpClientMetricsHandler. This handler is placed + * after the reactor.left.httpMetricsHandler. + */ + static final class ClientExceptionHandler extends ChannelDuplexHandler { + static final ClientExceptionHandler INSTANCE = new ClientExceptionHandler(); + static final String HANDLER_NAME = "ExceptionHandler.handler"; + private CountDownLatch latch; + + void register(Channel channel, CountDownLatch latch) { + this.latch = latch; + channel.pipeline().addAfter(NettyPipeline.HttpMetricsHandler, HANDLER_NAME, this); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + latch.countDown(); + ctx.fireExceptionCaught(cause); + } + } }