Skip to content

Commit

Permalink
Close the context
Browse files Browse the repository at this point in the history
  • Loading branch information
sorra committed Sep 17, 2023
1 parent 5b78a53 commit d7a917e
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 14 deletions.
Expand Up @@ -104,6 +104,8 @@ public void onCancel() {
super.onCancel();
} finally {
context.detach(previous);
// Cancel the timeout when the call is finished.
context.close();
}
}

Expand All @@ -114,6 +116,8 @@ public void onComplete() {
super.onComplete();
} finally {
context.detach(previous);
// Cancel the timeout when the call is finished.
context.close();
}
}

Expand Down
4 changes: 4 additions & 0 deletions util/src/main/java/io/grpc/util/ServerTimeoutManager.java
Expand Up @@ -84,6 +84,10 @@ public Context.CancellableContext startTimeoutContext(ServerCall<?, ?> serverCal
if (c.cancellationCause() == null) {
return;
}
if (logFunction != null) {
logFunction.accept("server call timeout for " +
serverCall.getMethodDescriptor().getFullMethodName());
}
serverCall.close(Status.CANCELLED.withDescription("server call timeout"), new Metadata());
};
Context.CancellableContext context = Context.current().withDeadline(
Expand Down
106 changes: 92 additions & 14 deletions util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java
Expand Up @@ -30,7 +30,10 @@
import io.grpc.StatusRuntimeException;
import io.grpc.stub.ServerCalls;
import io.grpc.stub.StreamObserver;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -83,13 +86,13 @@ public void invoke(Integer req, StreamObserver<Integer> responseObserver) {
public void unary_setShouldInterrupt_exceedingTimeout_isInterrupted() {
ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD);
ServerCallHandler<Integer, Integer> callHandler =
ServerCalls.asyncUnaryCall(sleepingUnaryMethod(100));
StringBuffer logBuf = new StringBuffer();
ServerCalls.asyncUnaryCall(sleepingUnaryMethod(20));
StringWriter logWriter = new StringWriter();

ServerTimeoutManager serverTimeoutManager =
ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS)
.setShouldInterrupt(true)
.setLogFunction(logBuf::append)
.setLogFunction(new PrintWriter(logWriter)::println)
.build();
ServerCall.Listener<Integer> listener = new ServerCallTimeoutInterceptor(serverTimeoutManager)
.interceptCall(serverCall, new Metadata(), callHandler);
Expand All @@ -101,19 +104,20 @@ public void unary_setShouldInterrupt_exceedingTimeout_isInterrupted() {
assertThat(serverCall.responses).isEmpty();
assertEquals(Status.Code.CANCELLED, serverCall.status.getCode());
assertEquals("server call timeout", serverCall.status.getDescription());
assertThat(logBuf.toString()).startsWith("Interrupted RPC thread ");
assertThat(logWriter.toString())
.startsWith("server call timeout for some/unary\nInterrupted RPC thread ");
}

@Test
public void unary_byDefault_exceedingTimeout_isCancelledButNotInterrupted() {
ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD);
ServerCallHandler<Integer, Integer> callHandler =
ServerCalls.asyncUnaryCall(sleepingUnaryMethod(100));
StringBuffer logBuf = new StringBuffer();
ServerCalls.asyncUnaryCall(sleepingUnaryMethod(20));
StringWriter logWriter = new StringWriter();

ServerTimeoutManager serverTimeoutManager =
ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS)
.setLogFunction(logBuf::append)
.setLogFunction(new PrintWriter(logWriter)::println)
.build();
ServerCall.Listener<Integer> listener = new ServerCallTimeoutInterceptor(serverTimeoutManager)
.interceptCall(serverCall, new Metadata(), callHandler);
Expand All @@ -125,20 +129,20 @@ public void unary_byDefault_exceedingTimeout_isCancelledButNotInterrupted() {
assertThat(serverCall.responses).isEmpty();
assertEquals(Status.Code.CANCELLED, serverCall.status.getCode());
assertEquals("server call timeout", serverCall.status.getDescription());
assertThat(logBuf.toString()).isEmpty();
assertEquals("server call timeout for some/unary\n", logWriter.toString());
}

@Test
public void unary_setShouldInterrupt_withinTimeout_isNotCancelledOrInterrupted() {
ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD);
ServerCallHandler<Integer, Integer> callHandler =
ServerCalls.asyncUnaryCall(sleepingUnaryMethod(0));
StringBuffer logBuf = new StringBuffer();
StringWriter logWriter = new StringWriter();

ServerTimeoutManager serverTimeoutManager =
ServerTimeoutManager.newBuilder(100, TimeUnit.MILLISECONDS)
.setShouldInterrupt(true)
.setLogFunction(logBuf::append)
.setLogFunction(new PrintWriter(logWriter)::println)
.build();
ServerCall.Listener<Integer> listener = new ServerCallTimeoutInterceptor(serverTimeoutManager)
.interceptCall(serverCall, new Metadata(), callHandler);
Expand All @@ -149,7 +153,7 @@ public void unary_setShouldInterrupt_withinTimeout_isNotCancelledOrInterrupted()

assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42));
assertEquals(Status.Code.OK, serverCall.status.getCode());
assertThat(logBuf.toString()).isEmpty();
assertThat(logWriter.toString()).isEmpty();
}

@Test
Expand All @@ -166,9 +170,7 @@ public void invoke(Integer req, StreamObserver<Integer> responseObserver) {
});

ServerTimeoutManager serverTimeoutManager =
ServerTimeoutManager.newBuilder(0, TimeUnit.NANOSECONDS)
.setShouldInterrupt(true)
.build();
ServerTimeoutManager.newBuilder(0, TimeUnit.NANOSECONDS).build();
ServerCall.Listener<Integer> listener = new ServerCallTimeoutInterceptor(serverTimeoutManager)
.interceptCall(serverCall, new Metadata(), callHandler);
serverTimeoutManager.shutdown();
Expand Down Expand Up @@ -199,6 +201,82 @@ public StreamObserver<Integer> invoke(StreamObserver<Integer> responseObserver)
ServerCallTimeoutInterceptor.TimeoutServerCallListener.class, listener.getClass());
}

@Test
public void allStagesCanKnowCancellation() throws Exception {
List<String> cancelledStages = Collections.synchronizedList(new ArrayList<>());
ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD);
ServerCallHandler<Integer, Integer> callHandler = new ServerCallHandler<Integer, Integer>() {
private final ServerCallHandler<Integer, Integer> innerHandler =
ServerCalls.asyncUnaryCall(sleepingUnaryMethod(0));

@Override
public ServerCall.Listener<Integer> startCall(ServerCall<Integer, Integer> call, Metadata headers) {
ServerCall.Listener<Integer> delegate = innerHandler.startCall(call, headers);
return new ServerCall.Listener<Integer>() {
@Override
public void onMessage(Integer message) {
if (Context.current().isCancelled()) {
cancelledStages.add("onMessage");
}
delegate.onMessage(message);
}

@Override
public void onHalfClose() {
if (Context.current().isCancelled()) {
cancelledStages.add("onHalfClose");
}
delegate.onHalfClose();
}

@Override
public void onCancel() {
if (Context.current().isCancelled()) {
cancelledStages.add("onCancel");
}
delegate.onCancel();
}

@Override
public void onComplete() {
if (Context.current().isCancelled()) {
cancelledStages.add("onComplete");
}
delegate.onComplete();
}

@Override
public void onReady() {
if (Context.current().isCancelled()) {
cancelledStages.add("onReady");
}
delegate.onReady();
}
};
}
};

ServerTimeoutManager serverTimeoutManager =
ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS).build();
ServerCall.Listener<Integer> listener = new ServerCallTimeoutInterceptor(serverTimeoutManager)
.interceptCall(serverCall, new Metadata(), callHandler);
// Let it timeout
Thread.sleep(20);
listener.onMessage(42);
listener.onHalfClose();
listener.onReady();
listener.onComplete();
listener.onCancel();
serverTimeoutManager.shutdown();

assertThat(serverCall.responses).isEmpty();
assertEquals(Status.Code.CANCELLED, serverCall.status.getCode());
assertEquals("server call timeout", serverCall.status.getDescription());
assertEquals(
Arrays.asList("onMessage", "onHalfClose", "onReady", "onComplete", "onCancel"),
cancelledStages);
}

private static class ServerCallRecorder extends ServerCall<Integer, Integer> {
private final MethodDescriptor<Integer, Integer> methodDescriptor;
private final List<Integer> requestCalls = new ArrayList<>();
Expand Down

0 comments on commit d7a917e

Please sign in to comment.