diff --git a/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java b/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java index d19fa6ddc6bb..f63ae3647732 100644 --- a/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java +++ b/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java @@ -104,6 +104,8 @@ public void onCancel() { super.onCancel(); } finally { context.detach(previous); + // Cancel the timeout when the call is finished. + context.close(); } } @@ -114,6 +116,8 @@ public void onComplete() { super.onComplete(); } finally { context.detach(previous); + // Cancel the timeout when the call is finished. + context.close(); } } diff --git a/util/src/main/java/io/grpc/util/ServerTimeoutManager.java b/util/src/main/java/io/grpc/util/ServerTimeoutManager.java index 6d767edd61d2..9fbbedb34be3 100644 --- a/util/src/main/java/io/grpc/util/ServerTimeoutManager.java +++ b/util/src/main/java/io/grpc/util/ServerTimeoutManager.java @@ -84,6 +84,10 @@ public Context.CancellableContext startTimeoutContext(ServerCall serverCal if (c.cancellationCause() == null) { return; } + if (logFunction != null) { + logFunction.accept("server call timeout for " + + serverCall.getMethodDescriptor().getFullMethodName()); + } serverCall.close(Status.CANCELLED.withDescription("server call timeout"), new Metadata()); }; Context.CancellableContext context = Context.current().withDeadline( diff --git a/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java b/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java index 9bc6df1a3df9..74805268eb6d 100644 --- a/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java +++ b/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java @@ -30,7 +30,10 @@ import io.grpc.StatusRuntimeException; import io.grpc.stub.ServerCalls; import io.grpc.stub.StreamObserver; +import java.io.PrintWriter; +import java.io.StringWriter; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.concurrent.TimeUnit; @@ -83,13 +86,13 @@ public void invoke(Integer req, StreamObserver responseObserver) { public void unary_setShouldInterrupt_exceedingTimeout_isInterrupted() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall(sleepingUnaryMethod(100)); - StringBuffer logBuf = new StringBuffer(); + ServerCalls.asyncUnaryCall(sleepingUnaryMethod(20)); + StringWriter logWriter = new StringWriter(); ServerTimeoutManager serverTimeoutManager = ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS) .setShouldInterrupt(true) - .setLogFunction(logBuf::append) + .setLogFunction(new PrintWriter(logWriter)::println) .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); @@ -101,19 +104,20 @@ public void unary_setShouldInterrupt_exceedingTimeout_isInterrupted() { assertThat(serverCall.responses).isEmpty(); assertEquals(Status.Code.CANCELLED, serverCall.status.getCode()); assertEquals("server call timeout", serverCall.status.getDescription()); - assertThat(logBuf.toString()).startsWith("Interrupted RPC thread "); + assertThat(logWriter.toString()) + .startsWith("server call timeout for some/unary\nInterrupted RPC thread "); } @Test public void unary_byDefault_exceedingTimeout_isCancelledButNotInterrupted() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall(sleepingUnaryMethod(100)); - StringBuffer logBuf = new StringBuffer(); + ServerCalls.asyncUnaryCall(sleepingUnaryMethod(20)); + StringWriter logWriter = new StringWriter(); ServerTimeoutManager serverTimeoutManager = ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS) - .setLogFunction(logBuf::append) + .setLogFunction(new PrintWriter(logWriter)::println) .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); @@ -125,7 +129,7 @@ public void unary_byDefault_exceedingTimeout_isCancelledButNotInterrupted() { assertThat(serverCall.responses).isEmpty(); assertEquals(Status.Code.CANCELLED, serverCall.status.getCode()); assertEquals("server call timeout", serverCall.status.getDescription()); - assertThat(logBuf.toString()).isEmpty(); + assertEquals("server call timeout for some/unary\n", logWriter.toString()); } @Test @@ -133,12 +137,12 @@ public void unary_setShouldInterrupt_withinTimeout_isNotCancelledOrInterrupted() ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = ServerCalls.asyncUnaryCall(sleepingUnaryMethod(0)); - StringBuffer logBuf = new StringBuffer(); + StringWriter logWriter = new StringWriter(); ServerTimeoutManager serverTimeoutManager = ServerTimeoutManager.newBuilder(100, TimeUnit.MILLISECONDS) .setShouldInterrupt(true) - .setLogFunction(logBuf::append) + .setLogFunction(new PrintWriter(logWriter)::println) .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); @@ -149,7 +153,7 @@ public void unary_setShouldInterrupt_withinTimeout_isNotCancelledOrInterrupted() assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); assertEquals(Status.Code.OK, serverCall.status.getCode()); - assertThat(logBuf.toString()).isEmpty(); + assertThat(logWriter.toString()).isEmpty(); } @Test @@ -166,9 +170,7 @@ public void invoke(Integer req, StreamObserver responseObserver) { }); ServerTimeoutManager serverTimeoutManager = - ServerTimeoutManager.newBuilder(0, TimeUnit.NANOSECONDS) - .setShouldInterrupt(true) - .build(); + ServerTimeoutManager.newBuilder(0, TimeUnit.NANOSECONDS).build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); serverTimeoutManager.shutdown(); @@ -199,6 +201,82 @@ public StreamObserver invoke(StreamObserver responseObserver) ServerCallTimeoutInterceptor.TimeoutServerCallListener.class, listener.getClass()); } + @Test + public void allStagesCanKnowCancellation() throws Exception { + List cancelledStages = Collections.synchronizedList(new ArrayList<>()); + ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); + ServerCallHandler callHandler = new ServerCallHandler() { + private final ServerCallHandler innerHandler = + ServerCalls.asyncUnaryCall(sleepingUnaryMethod(0)); + + @Override + public ServerCall.Listener startCall(ServerCall call, Metadata headers) { + ServerCall.Listener delegate = innerHandler.startCall(call, headers); + return new ServerCall.Listener() { + @Override + public void onMessage(Integer message) { + if (Context.current().isCancelled()) { + cancelledStages.add("onMessage"); + } + delegate.onMessage(message); + } + + @Override + public void onHalfClose() { + if (Context.current().isCancelled()) { + cancelledStages.add("onHalfClose"); + } + delegate.onHalfClose(); + } + + @Override + public void onCancel() { + if (Context.current().isCancelled()) { + cancelledStages.add("onCancel"); + } + delegate.onCancel(); + } + + @Override + public void onComplete() { + if (Context.current().isCancelled()) { + cancelledStages.add("onComplete"); + } + delegate.onComplete(); + } + + @Override + public void onReady() { + if (Context.current().isCancelled()) { + cancelledStages.add("onReady"); + } + delegate.onReady(); + } + }; + } + }; + + ServerTimeoutManager serverTimeoutManager = + ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS).build(); + ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) + .interceptCall(serverCall, new Metadata(), callHandler); + // Let it timeout + Thread.sleep(20); + listener.onMessage(42); + listener.onHalfClose(); + listener.onReady(); + listener.onComplete(); + listener.onCancel(); + serverTimeoutManager.shutdown(); + + assertThat(serverCall.responses).isEmpty(); + assertEquals(Status.Code.CANCELLED, serverCall.status.getCode()); + assertEquals("server call timeout", serverCall.status.getDescription()); + assertEquals( + Arrays.asList("onMessage", "onHalfClose", "onReady", "onComplete", "onCancel"), + cancelledStages); + } + private static class ServerCallRecorder extends ServerCall { private final MethodDescriptor methodDescriptor; private final List requestCalls = new ArrayList<>();