From 3798f5548afd0e4ec3cd88e56f37cf550310ed15 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Sun, 9 Jul 2023 22:05:21 +0800 Subject: [PATCH 01/22] Server-side timeout mechanism --- .../java/io/grpc/ServerTimeoutManager.java | 113 ++++++++++++++++++ .../io/grpc/TimeoutServerInterceptor.java | 73 +++++++++++ 2 files changed, 186 insertions(+) create mode 100644 api/src/main/java/io/grpc/ServerTimeoutManager.java create mode 100644 api/src/main/java/io/grpc/TimeoutServerInterceptor.java diff --git a/api/src/main/java/io/grpc/ServerTimeoutManager.java b/api/src/main/java/io/grpc/ServerTimeoutManager.java new file mode 100644 index 00000000000..4629863a2b5 --- /dev/null +++ b/api/src/main/java/io/grpc/ServerTimeoutManager.java @@ -0,0 +1,113 @@ +/* + * Copyright 2014 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +/** A global instance that schedules the timeout tasks. */ +public class ServerTimeoutManager { + private final int timeout; + private final TimeUnit unit; + + private final Consumer logFunction; + + private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); + + /** + * Creates a manager. Please make it a singleton and remember to shut it down. + * + * @param timeout Configurable timeout threshold. A value less than 0 (e.g. 0 or -1) means not to + * check timeout. + * @param unit The unit of the timeout. + * @param logFunction An optional function that can log (e.g. Logger::warn). Through this, + * we avoid depending on a specific logger library. + */ + public ServerTimeoutManager(int timeout, TimeUnit unit, Consumer logFunction) { + this.timeout = timeout; + this.unit = unit; + this.logFunction = logFunction; + } + + /** Please call shutdown() when the application exits. */ + public void shutdown() { + scheduler.shutdownNow(); + } + + /** + * Schedules a timeout and calls the RPC method invocation. + * Invalidates the timeout if the invocation completes in time. + * + * @param invocation The RPC method invocation that processes a request. + */ + public void intercept(Runnable invocation) { + if (timeout <= 0) { + invocation.run(); + return; + } + + TimeoutTask timeoutTask = schedule(Thread.currentThread()); + try { + invocation.run(); + } finally { + // If it completes in time, invalidate the timeout. + timeoutTask.invalidate(); + } + } + + private TimeoutTask schedule(Thread thread) { + TimeoutTask timeoutTask = new TimeoutTask(thread); + if (!scheduler.isShutdown()) { + scheduler.schedule(timeoutTask, timeout, unit); + } + return timeoutTask; + } + + private class TimeoutTask implements Runnable { + /** null thread means the task is invalid and will do nothing */ + private final AtomicReference threadReference = new AtomicReference<>(); + + private TimeoutTask(Thread thread) { + threadReference.set(thread); + } + + @Override + public void run() { + // Ensure the reference is consumed only once. + Thread thread = threadReference.getAndSet(null); + if (thread != null) { + thread.interrupt(); + if (logFunction != null) { + logFunction.accept( + "Interrupted RPC thread " + + thread.getName() + + " for timeout at " + + timeout + + " " + + unit); + } + } + } + + private void invalidate() { + threadReference.set(null); + } + } +} diff --git a/api/src/main/java/io/grpc/TimeoutServerInterceptor.java b/api/src/main/java/io/grpc/TimeoutServerInterceptor.java new file mode 100644 index 00000000000..6c9736c16ac --- /dev/null +++ b/api/src/main/java/io/grpc/TimeoutServerInterceptor.java @@ -0,0 +1,73 @@ +/* + * Copyright 2014 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +/** + * An optional ServerInterceptor that can interrupt server calls that are running for too long time. + * + *

You can add it to your server using the ServerBuilder#intercept(ServerInterceptor) method. + * + *

Limitation: it only applies the timeout to unary calls (streaming calls will run without timeout). + */ +public class TimeoutServerInterceptor implements ServerInterceptor { + + private final ServerTimeoutManager serverTimeoutManager; + + public TimeoutServerInterceptor(ServerTimeoutManager serverTimeoutManager) { + this.serverTimeoutManager = serverTimeoutManager; + } + + @Override + public ServerCall.Listener interceptCall( + ServerCall serverCall, + Metadata metadata, + ServerCallHandler serverCallHandler) { + return new TimeoutServerCallListener<>( + serverCallHandler.startCall(serverCall, metadata), serverCall, serverTimeoutManager); + } + + /** A listener that intercepts the RPC method invocation for timeout control. */ + private static class TimeoutServerCallListener + extends ForwardingServerCallListener.SimpleForwardingServerCallListener { + + private final ServerCall serverCall; + private final ServerTimeoutManager serverTimeoutManager; + + private TimeoutServerCallListener( + ServerCall.Listener delegate, + ServerCall serverCall, + ServerTimeoutManager serverTimeoutManager) { + super(delegate); + this.serverCall = serverCall; + this.serverTimeoutManager = serverTimeoutManager; + } + + /** + * Only intercepts unary calls because the timeout is inapplicable to streaming calls. + * Intercepts onHalfClose() because the RPC method is called in it. See + * io.grpc.stub.ServerCalls.UnaryServerCallHandler.UnaryServerCallListener + */ + @Override + public void onHalfClose() { + if (serverCall.getMethodDescriptor().getType().clientSendsOneMessage()) { + serverTimeoutManager.intercept(super::onHalfClose); + } else { + super.onHalfClose(); + } + } + } +} From f26f928197c6d7897185c1fc6a8a866690f8919d Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Mon, 17 Jul 2023 14:39:08 +0800 Subject: [PATCH 02/22] Move the unary call if-condition in TimeoutServerInterceptor --- .../io/grpc/TimeoutServerInterceptor.java | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/api/src/main/java/io/grpc/TimeoutServerInterceptor.java b/api/src/main/java/io/grpc/TimeoutServerInterceptor.java index 6c9736c16ac..e37f6a8557f 100644 --- a/api/src/main/java/io/grpc/TimeoutServerInterceptor.java +++ b/api/src/main/java/io/grpc/TimeoutServerInterceptor.java @@ -18,8 +18,9 @@ /** * An optional ServerInterceptor that can interrupt server calls that are running for too long time. + * In this way, it prevents problematic code from using up all threads. * - *

You can add it to your server using the ServerBuilder#intercept(ServerInterceptor) method. + *

How to use: you can add it to your server using the ServerBuilder#intercept(ServerInterceptor) method. * *

Limitation: it only applies the timeout to unary calls (streaming calls will run without timeout). */ @@ -36,38 +37,35 @@ public ServerCall.Listener interceptCall( ServerCall serverCall, Metadata metadata, ServerCallHandler serverCallHandler) { - return new TimeoutServerCallListener<>( - serverCallHandler.startCall(serverCall, metadata), serverCall, serverTimeoutManager); + // Only intercepts unary calls because the timeout is inapplicable to streaming calls. + if (serverCall.getMethodDescriptor().getType().clientSendsOneMessage()) { + return new TimeoutServerCallListener<>( + serverCallHandler.startCall(serverCall, metadata), serverTimeoutManager); + } else { + return serverCallHandler.startCall(serverCall, metadata); + } } /** A listener that intercepts the RPC method invocation for timeout control. */ private static class TimeoutServerCallListener extends ForwardingServerCallListener.SimpleForwardingServerCallListener { - private final ServerCall serverCall; private final ServerTimeoutManager serverTimeoutManager; private TimeoutServerCallListener( ServerCall.Listener delegate, - ServerCall serverCall, ServerTimeoutManager serverTimeoutManager) { super(delegate); - this.serverCall = serverCall; this.serverTimeoutManager = serverTimeoutManager; } /** - * Only intercepts unary calls because the timeout is inapplicable to streaming calls. * Intercepts onHalfClose() because the RPC method is called in it. See * io.grpc.stub.ServerCalls.UnaryServerCallHandler.UnaryServerCallListener */ @Override public void onHalfClose() { - if (serverCall.getMethodDescriptor().getType().clientSendsOneMessage()) { - serverTimeoutManager.intercept(super::onHalfClose); - } else { - super.onHalfClose(); - } + serverTimeoutManager.intercept(super::onHalfClose); } } } From 92c0ad3b148d3b96baa18a3c19d45c327d0f1c16 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Mon, 17 Jul 2023 17:39:01 +0800 Subject: [PATCH 03/22] replace TimeoutTask invalidation with Future cancelation --- .../java/io/grpc/ServerTimeoutManager.java | 22 +++++++++---------- .../io/grpc/TimeoutServerInterceptor.java | 5 +++-- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/api/src/main/java/io/grpc/ServerTimeoutManager.java b/api/src/main/java/io/grpc/ServerTimeoutManager.java index 4629863a2b5..4e70fb66133 100644 --- a/api/src/main/java/io/grpc/ServerTimeoutManager.java +++ b/api/src/main/java/io/grpc/ServerTimeoutManager.java @@ -17,6 +17,7 @@ package io.grpc; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -63,25 +64,26 @@ public void intercept(Runnable invocation) { return; } - TimeoutTask timeoutTask = schedule(Thread.currentThread()); + Future timeoutFuture = schedule(Thread.currentThread()); try { invocation.run(); } finally { - // If it completes in time, invalidate the timeout. - timeoutTask.invalidate(); + // If it completes in time, cancel the timeout. + if (timeoutFuture != null) { + timeoutFuture.cancel(false); + } } } - private TimeoutTask schedule(Thread thread) { + private Future schedule(Thread thread) { TimeoutTask timeoutTask = new TimeoutTask(thread); - if (!scheduler.isShutdown()) { - scheduler.schedule(timeoutTask, timeout, unit); + if (scheduler.isShutdown()) { + return null; } - return timeoutTask; + return scheduler.schedule(timeoutTask, timeout, unit); } private class TimeoutTask implements Runnable { - /** null thread means the task is invalid and will do nothing */ private final AtomicReference threadReference = new AtomicReference<>(); private TimeoutTask(Thread thread) { @@ -105,9 +107,5 @@ public void run() { } } } - - private void invalidate() { - threadReference.set(null); - } } } diff --git a/api/src/main/java/io/grpc/TimeoutServerInterceptor.java b/api/src/main/java/io/grpc/TimeoutServerInterceptor.java index e37f6a8557f..25f41030a5b 100644 --- a/api/src/main/java/io/grpc/TimeoutServerInterceptor.java +++ b/api/src/main/java/io/grpc/TimeoutServerInterceptor.java @@ -20,9 +20,10 @@ * An optional ServerInterceptor that can interrupt server calls that are running for too long time. * In this way, it prevents problematic code from using up all threads. * - *

How to use: you can add it to your server using the ServerBuilder#intercept(ServerInterceptor) method. + *

How to use: you can add it to your server using ServerBuilder#intercept(ServerInterceptor). * - *

Limitation: it only applies the timeout to unary calls (streaming calls will run without timeout). + *

Limitation: it only applies the timeout to unary calls + * (streaming calls will still run without timeout). */ public class TimeoutServerInterceptor implements ServerInterceptor { From e8c9894a3b2b57fbedc2dece93b18528b751786b Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Sat, 29 Jul 2023 21:32:06 +0800 Subject: [PATCH 04/22] Rename interceptor class --- ...rverInterceptor.java => ServerCallTimeoutInterceptor.java} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename api/src/main/java/io/grpc/{TimeoutServerInterceptor.java => ServerCallTimeoutInterceptor.java} (94%) diff --git a/api/src/main/java/io/grpc/TimeoutServerInterceptor.java b/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java similarity index 94% rename from api/src/main/java/io/grpc/TimeoutServerInterceptor.java rename to api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java index 25f41030a5b..4008498e3e6 100644 --- a/api/src/main/java/io/grpc/TimeoutServerInterceptor.java +++ b/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java @@ -25,11 +25,11 @@ *

Limitation: it only applies the timeout to unary calls * (streaming calls will still run without timeout). */ -public class TimeoutServerInterceptor implements ServerInterceptor { +public class ServerCallTimeoutInterceptor implements ServerInterceptor { private final ServerTimeoutManager serverTimeoutManager; - public TimeoutServerInterceptor(ServerTimeoutManager serverTimeoutManager) { + public ServerCallTimeoutInterceptor(ServerTimeoutManager serverTimeoutManager) { this.serverTimeoutManager = serverTimeoutManager; } From 4622550a9801de1a670f06c4031e408cba14e05f Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Sun, 6 Aug 2023 21:07:24 +0800 Subject: [PATCH 05/22] add unit tests --- .../java/io/grpc/stub/ServerCallsTest.java | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java index 7227d26c5b8..39e2d886b56 100644 --- a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java @@ -30,7 +30,9 @@ import io.grpc.MethodDescriptor; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; +import io.grpc.ServerCallTimeoutInterceptor; import io.grpc.ServerServiceDefinition; +import io.grpc.ServerTimeoutManager; import io.grpc.ServiceDescriptor; import io.grpc.Status; import io.grpc.StatusRuntimeException; @@ -43,6 +45,7 @@ import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Semaphore; @@ -530,6 +533,58 @@ public void invoke(Integer req, StreamObserver responseObserver) { listener.onHalfClose(); } + @Test + public void callWithinTimeout() { + ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + new ServerCalls.UnaryMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + responseObserver.onNext(42); + responseObserver.onCompleted(); + } + }); + + ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager(100, TimeUnit.MILLISECONDS, null); + ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) + .interceptCall(serverCall, new Metadata(), callHandler); + listener.onMessage(1); + listener.onHalfClose(); + + assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); + assertEquals(Status.Code.OK, serverCall.status.getCode()); + } + + @Test + public void callExceedsTimeout() { + ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + new ServerCalls.UnaryMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + try { + Thread.sleep(100); + responseObserver.onNext(42); + responseObserver.onCompleted(); + } catch (InterruptedException e) { + Status status = Status.ABORTED.withDescription(e.getMessage()); + responseObserver.onError(new StatusRuntimeException(status)); + } + } + }); + + ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager(1, TimeUnit.MILLISECONDS, null); + ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager).interceptCall(serverCall, new Metadata(), callHandler); + listener.onMessage(1); + listener.onHalfClose(); + + assertThat(serverCall.responses).isEmpty(); + assertEquals(Status.Code.ABORTED, serverCall.status.getCode()); + assertEquals("sleep interrupted", serverCall.status.getDescription()); + } + @Test public void inprocessTransportManualFlow() throws Exception { final Semaphore semaphore = new Semaphore(1); From 390592d7d804c5bd73e10ec41af3fab45f82fc29 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Sun, 6 Aug 2023 22:21:14 +0800 Subject: [PATCH 06/22] fix code style --- .../java/io/grpc/stub/ServerCallsTest.java | 59 ++++++++++--------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java index 39e2d886b56..7b89f3ea8c8 100644 --- a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java @@ -537,18 +537,19 @@ public void invoke(Integer req, StreamObserver responseObserver) { public void callWithinTimeout() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall( - new ServerCalls.UnaryMethod() { - @Override - public void invoke(Integer req, StreamObserver responseObserver) { - responseObserver.onNext(42); - responseObserver.onCompleted(); - } - }); - - ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager(100, TimeUnit.MILLISECONDS, null); + ServerCalls.asyncUnaryCall( + new ServerCalls.UnaryMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + responseObserver.onNext(42); + responseObserver.onCompleted(); + } + }); + + ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( + 100, TimeUnit.MILLISECONDS, null); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) - .interceptCall(serverCall, new Metadata(), callHandler); + .interceptCall(serverCall, new Metadata(), callHandler); listener.onMessage(1); listener.onHalfClose(); @@ -560,23 +561,25 @@ public void invoke(Integer req, StreamObserver responseObserver) { public void callExceedsTimeout() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall( - new ServerCalls.UnaryMethod() { - @Override - public void invoke(Integer req, StreamObserver responseObserver) { - try { - Thread.sleep(100); - responseObserver.onNext(42); - responseObserver.onCompleted(); - } catch (InterruptedException e) { - Status status = Status.ABORTED.withDescription(e.getMessage()); - responseObserver.onError(new StatusRuntimeException(status)); - } - } - }); - - ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager(1, TimeUnit.MILLISECONDS, null); - ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager).interceptCall(serverCall, new Metadata(), callHandler); + ServerCalls.asyncUnaryCall( + new ServerCalls.UnaryMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + try { + Thread.sleep(100); + responseObserver.onNext(42); + responseObserver.onCompleted(); + } catch (InterruptedException e) { + Status status = Status.ABORTED.withDescription(e.getMessage()); + responseObserver.onError(new StatusRuntimeException(status)); + } + } + }); + + ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( + 1, TimeUnit.MILLISECONDS, null); + ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) + .interceptCall(serverCall, new Metadata(), callHandler); listener.onMessage(1); listener.onHalfClose(); From 7d10f253c059ef0996cd9c9af4768e413bae5164 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Mon, 7 Aug 2023 11:31:38 +0800 Subject: [PATCH 07/22] move unit tests --- .../ServerCallTimeoutInterceptorTest.java | 186 ++++++++++++++++++ .../java/io/grpc/stub/ServerCallsTest.java | 58 ------ 2 files changed, 186 insertions(+), 58 deletions(-) create mode 100644 api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java diff --git a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java new file mode 100644 index 00000000000..738cdb11ae1 --- /dev/null +++ b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java @@ -0,0 +1,186 @@ +/* + * Copyright 2016 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; + +import io.grpc.stub.ServerCalls; +import io.grpc.stub.StreamObserver; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ServerCallTimeoutInterceptor}. */ +@RunWith(JUnit4.class) +public class ServerCallTimeoutInterceptorTest { + static final MethodDescriptor STREAMING_METHOD = + MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.BIDI_STREAMING) + .setFullMethodName("some/bidi_streaming") + .setRequestMarshaller(new IntegerMarshaller()) + .setResponseMarshaller(new IntegerMarshaller()) + .build(); + + static final MethodDescriptor UNARY_METHOD = + STREAMING_METHOD.toBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("some/unary") + .build(); + + @Test + public void unaryServerCallWithinTimeout() { + ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + new ServerCalls.UnaryMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + responseObserver.onNext(42); + responseObserver.onCompleted(); + } + }); + StringBuffer logBuf = new StringBuffer(); + + ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( + 10, TimeUnit.MILLISECONDS, logBuf::append); + ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) + .interceptCall(serverCall, new Metadata(), callHandler); + listener.onMessage(1); + listener.onHalfClose(); + + assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); + assertEquals(Status.Code.OK, serverCall.status.getCode()); + assertThat(logBuf.toString()).isEmpty(); + } + + @Test + public void unaryServerCallExceedsTimeout() { + ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + new ServerCalls.UnaryMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + try { + Thread.sleep(10); + responseObserver.onNext(42); + responseObserver.onCompleted(); + } catch (InterruptedException e) { + Status status = Status.ABORTED.withDescription(e.getMessage()); + responseObserver.onError(new StatusRuntimeException(status)); + } + } + }); + StringBuffer logBuf = new StringBuffer(); + + ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( + 1, TimeUnit.MILLISECONDS, logBuf::append); + ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) + .interceptCall(serverCall, new Metadata(), callHandler); + listener.onMessage(1); + listener.onHalfClose(); + + assertThat(serverCall.responses).isEmpty(); + assertEquals(Status.Code.ABORTED, serverCall.status.getCode()); + assertEquals("sleep interrupted", serverCall.status.getDescription()); + assertThat(logBuf.toString()).startsWith("Interrupted RPC thread "); + } + + @Test + public void unaryServerCallSkipsZeroTimeout() { + ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + new ServerCalls.UnaryMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + try { + Thread.sleep(1); + responseObserver.onNext(42); + responseObserver.onCompleted(); + } catch (InterruptedException e) { + Status status = Status.ABORTED.withDescription(e.getMessage()); + responseObserver.onError(new StatusRuntimeException(status)); + } + } + }); + StringBuffer logBuf = new StringBuffer(); + + ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( + 0, TimeUnit.MILLISECONDS, logBuf::append); + ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) + .interceptCall(serverCall, new Metadata(), callHandler); + listener.onMessage(1); + listener.onHalfClose(); + + assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); + assertEquals(Status.Code.OK, serverCall.status.getCode()); + assertThat(logBuf.toString()).isEmpty(); + } + + private static class ServerCallRecorder extends ServerCall { + private final MethodDescriptor methodDescriptor; + private final List requestCalls = new ArrayList<>(); + private final List responses = new ArrayList<>(); + private Status status; + private boolean isCancelled; + private boolean isReady; + + public ServerCallRecorder(MethodDescriptor methodDescriptor) { + this.methodDescriptor = methodDescriptor; + } + + @Override + public void request(int numMessages) { + requestCalls.add(numMessages); + } + + @Override + public void sendHeaders(Metadata headers) { + } + + @Override + public void sendMessage(Integer message) { + this.responses.add(message); + } + + @Override + public void close(Status status, Metadata trailers) { + this.status = status; + } + + @Override + public boolean isCancelled() { + return isCancelled; + } + + @Override + public boolean isReady() { + return isReady; + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return methodDescriptor; + } + } +} diff --git a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java index 7b89f3ea8c8..7227d26c5b8 100644 --- a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java @@ -30,9 +30,7 @@ import io.grpc.MethodDescriptor; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; -import io.grpc.ServerCallTimeoutInterceptor; import io.grpc.ServerServiceDefinition; -import io.grpc.ServerTimeoutManager; import io.grpc.ServiceDescriptor; import io.grpc.Status; import io.grpc.StatusRuntimeException; @@ -45,7 +43,6 @@ import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Semaphore; @@ -533,61 +530,6 @@ public void invoke(Integer req, StreamObserver responseObserver) { listener.onHalfClose(); } - @Test - public void callWithinTimeout() { - ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); - ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall( - new ServerCalls.UnaryMethod() { - @Override - public void invoke(Integer req, StreamObserver responseObserver) { - responseObserver.onNext(42); - responseObserver.onCompleted(); - } - }); - - ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 100, TimeUnit.MILLISECONDS, null); - ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) - .interceptCall(serverCall, new Metadata(), callHandler); - listener.onMessage(1); - listener.onHalfClose(); - - assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); - assertEquals(Status.Code.OK, serverCall.status.getCode()); - } - - @Test - public void callExceedsTimeout() { - ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); - ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall( - new ServerCalls.UnaryMethod() { - @Override - public void invoke(Integer req, StreamObserver responseObserver) { - try { - Thread.sleep(100); - responseObserver.onNext(42); - responseObserver.onCompleted(); - } catch (InterruptedException e) { - Status status = Status.ABORTED.withDescription(e.getMessage()); - responseObserver.onError(new StatusRuntimeException(status)); - } - } - }); - - ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 1, TimeUnit.MILLISECONDS, null); - ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) - .interceptCall(serverCall, new Metadata(), callHandler); - listener.onMessage(1); - listener.onHalfClose(); - - assertThat(serverCall.responses).isEmpty(); - assertEquals(Status.Code.ABORTED, serverCall.status.getCode()); - assertEquals("sleep interrupted", serverCall.status.getDescription()); - } - @Test public void inprocessTransportManualFlow() throws Exception { final Semaphore semaphore = new Semaphore(1); From 534785598e201745dabe6bb08d248e453f19e1c3 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Mon, 7 Aug 2023 12:09:24 +0800 Subject: [PATCH 08/22] test streaming method is not intercepted --- .../io/grpc/ServerCallTimeoutInterceptor.java | 2 +- .../java/io/grpc/ServerTimeoutManager.java | 6 +- .../ServerCallTimeoutInterceptorTest.java | 86 ++++++++++++++++--- 3 files changed, 81 insertions(+), 13 deletions(-) diff --git a/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java b/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java index 4008498e3e6..57127dc8f7d 100644 --- a/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java +++ b/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java @@ -66,7 +66,7 @@ private TimeoutServerCallListener( */ @Override public void onHalfClose() { - serverTimeoutManager.intercept(super::onHalfClose); + serverTimeoutManager.withTimeout(super::onHalfClose); } } } diff --git a/api/src/main/java/io/grpc/ServerTimeoutManager.java b/api/src/main/java/io/grpc/ServerTimeoutManager.java index 4e70fb66133..b89503065da 100644 --- a/api/src/main/java/io/grpc/ServerTimeoutManager.java +++ b/api/src/main/java/io/grpc/ServerTimeoutManager.java @@ -57,16 +57,18 @@ public void shutdown() { * Invalidates the timeout if the invocation completes in time. * * @param invocation The RPC method invocation that processes a request. + * @return true if a timeout is scheduled */ - public void intercept(Runnable invocation) { + public boolean withTimeout(Runnable invocation) { if (timeout <= 0) { invocation.run(); - return; + return false; } Future timeoutFuture = schedule(Thread.currentThread()); try { invocation.run(); + return true; } finally { // If it completes in time, cancel the timeout. if (timeoutFuture != null) { diff --git a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java index 738cdb11ae1..b035c403bb9 100644 --- a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java +++ b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java @@ -18,6 +18,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import io.grpc.stub.ServerCalls; import io.grpc.stub.StreamObserver; @@ -25,6 +26,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -54,7 +56,7 @@ public void unaryServerCallWithinTimeout() { new ServerCalls.UnaryMethod() { @Override public void invoke(Integer req, StreamObserver responseObserver) { - responseObserver.onNext(42); + responseObserver.onNext(req); responseObserver.onCompleted(); } }); @@ -64,8 +66,9 @@ public void invoke(Integer req, StreamObserver responseObserver) { 10, TimeUnit.MILLISECONDS, logBuf::append); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); - listener.onMessage(1); + listener.onMessage(42); listener.onHalfClose(); + listener.onComplete(); assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); assertEquals(Status.Code.OK, serverCall.status.getCode()); @@ -82,7 +85,7 @@ public void unaryServerCallExceedsTimeout() { public void invoke(Integer req, StreamObserver responseObserver) { try { Thread.sleep(10); - responseObserver.onNext(42); + responseObserver.onNext(req); responseObserver.onCompleted(); } catch (InterruptedException e) { Status status = Status.ABORTED.withDescription(e.getMessage()); @@ -96,8 +99,9 @@ public void invoke(Integer req, StreamObserver responseObserver) { 1, TimeUnit.MILLISECONDS, logBuf::append); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); - listener.onMessage(1); + listener.onMessage(42); listener.onHalfClose(); + listener.onComplete(); assertThat(serverCall.responses).isEmpty(); assertEquals(Status.Code.ABORTED, serverCall.status.getCode()); @@ -114,8 +118,8 @@ public void unaryServerCallSkipsZeroTimeout() { @Override public void invoke(Integer req, StreamObserver responseObserver) { try { - Thread.sleep(1); - responseObserver.onNext(42); + Thread.sleep(10); + responseObserver.onNext(req); responseObserver.onCompleted(); } catch (InterruptedException e) { Status status = Status.ABORTED.withDescription(e.getMessage()); @@ -123,18 +127,57 @@ public void invoke(Integer req, StreamObserver responseObserver) { } } }); - StringBuffer logBuf = new StringBuffer(); + AtomicBoolean timeoutScheduled = new AtomicBoolean(false); ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 0, TimeUnit.MILLISECONDS, logBuf::append); + 0, TimeUnit.MILLISECONDS, null) { + @Override + public boolean withTimeout(Runnable invocation) { + boolean result = super.withTimeout(invocation); + timeoutScheduled.set(result); + return result; + } + }; ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); - listener.onMessage(1); + listener.onMessage(42); listener.onHalfClose(); + listener.onComplete(); assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); assertEquals(Status.Code.OK, serverCall.status.getCode()); - assertThat(logBuf.toString()).isEmpty(); + assertFalse(timeoutScheduled.get()); + } + + @Test + public void streamingServerCallIsNotIntercepted() { + ServerCallRecorder serverCall = new ServerCallRecorder(STREAMING_METHOD); + ServerCallHandler callHandler = + ServerCalls.asyncBidiStreamingCall(new ServerCalls.BidiStreamingMethod() { + @Override + public StreamObserver invoke(StreamObserver responseObserver) { + return new EchoStreamObserver<>(responseObserver); + } + }); + AtomicBoolean interceptMethodCalled = new AtomicBoolean(false); + + ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( + 10, TimeUnit.MILLISECONDS, null) { + @Override + public boolean withTimeout(Runnable invocation) { + interceptMethodCalled.set(true); + return true; + } + }; + ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) + .interceptCall(serverCall, new Metadata(), callHandler); + listener.onMessage(42); + listener.onHalfClose(); + listener.onComplete(); + + assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); + assertEquals(Status.Code.OK, serverCall.status.getCode()); + assertFalse(interceptMethodCalled.get()); } private static class ServerCallRecorder extends ServerCall { @@ -183,4 +226,27 @@ public MethodDescriptor getMethodDescriptor() { return methodDescriptor; } } + + private static class EchoStreamObserver implements StreamObserver { + private final StreamObserver responseObserver; + + public EchoStreamObserver(StreamObserver responseObserver) { + this.responseObserver = responseObserver; + } + + @Override + public void onNext(T value) { + responseObserver.onNext(value); + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + } } From fad093b878a296db3c063e04198533bb29b39694 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Mon, 7 Aug 2023 12:13:01 +0800 Subject: [PATCH 09/22] update copyright year --- api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java | 2 +- api/src/main/java/io/grpc/ServerTimeoutManager.java | 4 ++-- .../test/java/io/grpc/ServerCallTimeoutInterceptorTest.java | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java b/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java index 57127dc8f7d..91ea7059153 100644 --- a/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java +++ b/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2014 The gRPC Authors + * Copyright 2023 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/api/src/main/java/io/grpc/ServerTimeoutManager.java b/api/src/main/java/io/grpc/ServerTimeoutManager.java index b89503065da..bba229335e0 100644 --- a/api/src/main/java/io/grpc/ServerTimeoutManager.java +++ b/api/src/main/java/io/grpc/ServerTimeoutManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2014 The gRPC Authors + * Copyright 2023 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ public void shutdown() { } /** - * Schedules a timeout and calls the RPC method invocation. + * Calls the RPC method invocation with a timeout scheduled. * Invalidates the timeout if the invocation completes in time. * * @param invocation The RPC method invocation that processes a request. diff --git a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java index b035c403bb9..edd8d746d2a 100644 --- a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java +++ b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2016 The gRPC Authors + * Copyright 2023 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From de4fce70da11614aee38faa856f80af146c0af21 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Mon, 7 Aug 2023 12:50:38 +0800 Subject: [PATCH 10/22] improve unit test --- api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java index edd8d746d2a..5014691fc20 100644 --- a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java +++ b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java @@ -96,7 +96,7 @@ public void invoke(Integer req, StreamObserver responseObserver) { StringBuffer logBuf = new StringBuffer(); ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 1, TimeUnit.MILLISECONDS, logBuf::append); + 1, TimeUnit.NANOSECONDS, logBuf::append); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); listener.onMessage(42); From 8ef5cef9adb44a32a3a20a61f8623fdc176cf5e7 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Mon, 7 Aug 2023 14:05:44 +0800 Subject: [PATCH 11/22] improve unit tests --- .../ServerCallTimeoutInterceptorTest.java | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java index 5014691fc20..ee4251cbaa0 100644 --- a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java +++ b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java @@ -56,19 +56,26 @@ public void unaryServerCallWithinTimeout() { new ServerCalls.UnaryMethod() { @Override public void invoke(Integer req, StreamObserver responseObserver) { - responseObserver.onNext(req); - responseObserver.onCompleted(); + try { + Thread.sleep(0); + responseObserver.onNext(req); + responseObserver.onCompleted(); + } catch (InterruptedException e) { + Status status = Status.ABORTED.withDescription(e.getMessage()); + responseObserver.onError(new StatusRuntimeException(status)); + } } }); StringBuffer logBuf = new StringBuffer(); ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 10, TimeUnit.MILLISECONDS, logBuf::append); + 100, TimeUnit.MILLISECONDS, logBuf::append); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); listener.onMessage(42); listener.onHalfClose(); listener.onComplete(); + serverTimeoutManager.shutdown(); assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); assertEquals(Status.Code.OK, serverCall.status.getCode()); @@ -84,7 +91,7 @@ public void unaryServerCallExceedsTimeout() { @Override public void invoke(Integer req, StreamObserver responseObserver) { try { - Thread.sleep(10); + Thread.sleep(100); responseObserver.onNext(req); responseObserver.onCompleted(); } catch (InterruptedException e) { @@ -102,6 +109,7 @@ public void invoke(Integer req, StreamObserver responseObserver) { listener.onMessage(42); listener.onHalfClose(); listener.onComplete(); + serverTimeoutManager.shutdown(); assertThat(serverCall.responses).isEmpty(); assertEquals(Status.Code.ABORTED, serverCall.status.getCode()); @@ -118,7 +126,7 @@ public void unaryServerCallSkipsZeroTimeout() { @Override public void invoke(Integer req, StreamObserver responseObserver) { try { - Thread.sleep(10); + Thread.sleep(1); responseObserver.onNext(req); responseObserver.onCompleted(); } catch (InterruptedException e) { @@ -127,14 +135,14 @@ public void invoke(Integer req, StreamObserver responseObserver) { } } }); - AtomicBoolean timeoutScheduled = new AtomicBoolean(false); + AtomicBoolean isTimeoutScheduled = new AtomicBoolean(false); ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( 0, TimeUnit.MILLISECONDS, null) { @Override public boolean withTimeout(Runnable invocation) { boolean result = super.withTimeout(invocation); - timeoutScheduled.set(result); + isTimeoutScheduled.set(result); return result; } }; @@ -143,10 +151,11 @@ public boolean withTimeout(Runnable invocation) { listener.onMessage(42); listener.onHalfClose(); listener.onComplete(); + serverTimeoutManager.shutdown(); assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); assertEquals(Status.Code.OK, serverCall.status.getCode()); - assertFalse(timeoutScheduled.get()); + assertFalse(isTimeoutScheduled.get()); } @Test @@ -159,13 +168,13 @@ public StreamObserver invoke(StreamObserver responseObserver) return new EchoStreamObserver<>(responseObserver); } }); - AtomicBoolean interceptMethodCalled = new AtomicBoolean(false); + AtomicBoolean isManagerCalled = new AtomicBoolean(false); ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 10, TimeUnit.MILLISECONDS, null) { + 100, TimeUnit.MILLISECONDS, null) { @Override public boolean withTimeout(Runnable invocation) { - interceptMethodCalled.set(true); + isManagerCalled.set(true); return true; } }; @@ -174,10 +183,11 @@ public boolean withTimeout(Runnable invocation) { listener.onMessage(42); listener.onHalfClose(); listener.onComplete(); + serverTimeoutManager.shutdown(); assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); assertEquals(Status.Code.OK, serverCall.status.getCode()); - assertFalse(interceptMethodCalled.get()); + assertFalse(isManagerCalled.get()); } private static class ServerCallRecorder extends ServerCall { From 9ff53b32cc218c94659ba5db1fea317e4074190c Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Wed, 9 Aug 2023 13:52:10 +0800 Subject: [PATCH 12/22] Change to CancellableContext and CancellationListener approach --- .../java/io/grpc/ServerTimeoutManager.java | 67 +++++++------------ 1 file changed, 24 insertions(+), 43 deletions(-) diff --git a/api/src/main/java/io/grpc/ServerTimeoutManager.java b/api/src/main/java/io/grpc/ServerTimeoutManager.java index bba229335e0..d39544d0827 100644 --- a/api/src/main/java/io/grpc/ServerTimeoutManager.java +++ b/api/src/main/java/io/grpc/ServerTimeoutManager.java @@ -16,11 +16,10 @@ package io.grpc; +import com.google.common.util.concurrent.MoreExecutors; import java.util.concurrent.Executors; -import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; /** A global instance that schedules the timeout tasks. */ @@ -47,7 +46,10 @@ public ServerTimeoutManager(int timeout, TimeUnit unit, Consumer logFunc this.logFunction = logFunction; } - /** Please call shutdown() when the application exits. */ + /** + * Please call shutdown() when the application exits. + * You can add a JVM shutdown hook. + */ public void shutdown() { scheduler.shutdownNow(); } @@ -60,54 +62,33 @@ public void shutdown() { * @return true if a timeout is scheduled */ public boolean withTimeout(Runnable invocation) { - if (timeout <= 0) { + if (timeout <= 0 || scheduler.isShutdown()) { invocation.run(); return false; } - Future timeoutFuture = schedule(Thread.currentThread()); - try { - invocation.run(); - return true; - } finally { - // If it completes in time, cancel the timeout. - if (timeoutFuture != null) { - timeoutFuture.cancel(false); - } - } - } - - private Future schedule(Thread thread) { - TimeoutTask timeoutTask = new TimeoutTask(thread); - if (scheduler.isShutdown()) { - return null; - } - return scheduler.schedule(timeoutTask, timeout, unit); - } - - private class TimeoutTask implements Runnable { - private final AtomicReference threadReference = new AtomicReference<>(); - - private TimeoutTask(Thread thread) { - threadReference.set(thread); - } - - @Override - public void run() { - // Ensure the reference is consumed only once. - Thread thread = threadReference.getAndSet(null); - if (thread != null) { + try (Context.CancellableContext context = Context.current() + .withDeadline(Deadline.after(timeout, unit), scheduler)) { + Thread thread = Thread.currentThread(); + Context.CancellationListener cancelled = c -> { + if (c.cancellationCause() == null) { + return; + } thread.interrupt(); if (logFunction != null) { logFunction.accept( - "Interrupted RPC thread " - + thread.getName() - + " for timeout at " - + timeout - + " " - + unit); + "Interrupted RPC thread " + + thread.getName() + + " for timeout at " + + timeout + + " " + + unit); } - } + }; + context.addListener(cancelled, MoreExecutors.directExecutor()); + context.run(invocation); + return true; } } + } From f791d10c4feee1c0beb319b63276cf27e05ded3b Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Sun, 3 Sep 2023 15:43:49 +0800 Subject: [PATCH 13/22] Make interruption optional --- .../io/grpc/ServerCallTimeoutInterceptor.java | 1 + .../java/io/grpc/ServerTimeoutManager.java | 48 +++++++++++++------ .../ServerCallTimeoutInterceptorTest.java | 8 ++-- 3 files changed, 39 insertions(+), 18 deletions(-) diff --git a/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java b/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java index 91ea7059153..9bcb6dcf68c 100644 --- a/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java +++ b/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java @@ -25,6 +25,7 @@ *

Limitation: it only applies the timeout to unary calls * (streaming calls will still run without timeout). */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/10361") public class ServerCallTimeoutInterceptor implements ServerInterceptor { private final ServerTimeoutManager serverTimeoutManager; diff --git a/api/src/main/java/io/grpc/ServerTimeoutManager.java b/api/src/main/java/io/grpc/ServerTimeoutManager.java index d39544d0827..51b59676571 100644 --- a/api/src/main/java/io/grpc/ServerTimeoutManager.java +++ b/api/src/main/java/io/grpc/ServerTimeoutManager.java @@ -20,29 +20,35 @@ import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; -/** A global instance that schedules the timeout tasks. */ +/** A global manager that schedules the timeout tasks for the gRPC server. */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/10361") public class ServerTimeoutManager { private final int timeout; private final TimeUnit unit; + private final boolean shouldInterrupt; private final Consumer logFunction; private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); /** - * Creates a manager. Please make it a singleton and remember to shut it down. + * Creates an instance. Please make it a singleton and remember to shut it down. * * @param timeout Configurable timeout threshold. A value less than 0 (e.g. 0 or -1) means not to * check timeout. * @param unit The unit of the timeout. + * @param shouldInterrupt If {@code true}, interrupts the RPC worker thread. * @param logFunction An optional function that can log (e.g. Logger::warn). Through this, * we avoid depending on a specific logger library. */ - public ServerTimeoutManager(int timeout, TimeUnit unit, Consumer logFunction) { + public ServerTimeoutManager(int timeout, TimeUnit unit, + boolean shouldInterrupt, Consumer logFunction) { this.timeout = timeout; this.unit = unit; + this.shouldInterrupt = shouldInterrupt; this.logFunction = logFunction; } @@ -58,6 +64,9 @@ public void shutdown() { * Calls the RPC method invocation with a timeout scheduled. * Invalidates the timeout if the invocation completes in time. * + *

When the timeout is reached: It cancels the context around the RPC method invocation. And + * if shouldInterrupt is {@code true}, it also interrupts the current worker thread. + * * @param invocation The RPC method invocation that processes a request. * @return true if a timeout is scheduled */ @@ -69,26 +78,37 @@ public boolean withTimeout(Runnable invocation) { try (Context.CancellableContext context = Context.current() .withDeadline(Deadline.after(timeout, unit), scheduler)) { - Thread thread = Thread.currentThread(); + AtomicReference threadRef = + shouldInterrupt ? new AtomicReference<>(Thread.currentThread()) : null; Context.CancellationListener cancelled = c -> { if (c.cancellationCause() == null) { return; } - thread.interrupt(); - if (logFunction != null) { - logFunction.accept( - "Interrupted RPC thread " - + thread.getName() - + " for timeout at " - + timeout - + " " - + unit); + if (threadRef != null) { + Thread thread = threadRef.getAndSet(null); + if (thread != null) { + thread.interrupt(); + if (logFunction != null) { + logFunction.accept( + "Interrupted RPC thread " + + thread.getName() + + " for timeout at " + + timeout + + " " + + unit); + } + } } }; context.addListener(cancelled, MoreExecutors.directExecutor()); context.run(invocation); + // Invocation is done, should ensure the interruption state is cleared so + // the worker thread can be safely reused for the next task. Doing this + // mainly for ForkJoinPool https://bugs.openjdk.org/browse/JDK-8223430. + if (threadRef != null && threadRef.get() == null) { + Thread.interrupted(); + } return true; } } - } diff --git a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java index ee4251cbaa0..1651fb9351f 100644 --- a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java +++ b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java @@ -69,7 +69,7 @@ public void invoke(Integer req, StreamObserver responseObserver) { StringBuffer logBuf = new StringBuffer(); ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 100, TimeUnit.MILLISECONDS, logBuf::append); + 100, TimeUnit.MILLISECONDS, true, logBuf::append); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); listener.onMessage(42); @@ -103,7 +103,7 @@ public void invoke(Integer req, StreamObserver responseObserver) { StringBuffer logBuf = new StringBuffer(); ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 1, TimeUnit.NANOSECONDS, logBuf::append); + 1, TimeUnit.NANOSECONDS, true, logBuf::append); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); listener.onMessage(42); @@ -138,7 +138,7 @@ public void invoke(Integer req, StreamObserver responseObserver) { AtomicBoolean isTimeoutScheduled = new AtomicBoolean(false); ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 0, TimeUnit.MILLISECONDS, null) { + 0, TimeUnit.MILLISECONDS, true, null) { @Override public boolean withTimeout(Runnable invocation) { boolean result = super.withTimeout(invocation); @@ -171,7 +171,7 @@ public StreamObserver invoke(StreamObserver responseObserver) AtomicBoolean isManagerCalled = new AtomicBoolean(false); ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 100, TimeUnit.MILLISECONDS, null) { + 100, TimeUnit.MILLISECONDS, true, null) { @Override public boolean withTimeout(Runnable invocation) { isManagerCalled.set(true); From 05168f2d0cdabc39de2e528263bccfe26fa4e473 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Sun, 3 Sep 2023 16:05:26 +0800 Subject: [PATCH 14/22] Builder for ServerTimeoutManager --- .../java/io/grpc/ServerTimeoutManager.java | 69 +++++++++++++++---- .../ServerCallTimeoutInterceptorTest.java | 30 +++++--- 2 files changed, 75 insertions(+), 24 deletions(-) diff --git a/api/src/main/java/io/grpc/ServerTimeoutManager.java b/api/src/main/java/io/grpc/ServerTimeoutManager.java index 51b59676571..8e6b3ad88f7 100644 --- a/api/src/main/java/io/grpc/ServerTimeoutManager.java +++ b/api/src/main/java/io/grpc/ServerTimeoutManager.java @@ -23,28 +23,31 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; -/** A global manager that schedules the timeout tasks for the gRPC server. */ +/** + * A global manager that schedules the timeout tasks for the gRPC server. + * Please make it a singleton and shut it down when the server is shutdown. + */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10361") public class ServerTimeoutManager { + /** + * Creates a builder. + * + * @param timeout Configurable timeout threshold. A value less than 0 (e.g. 0 or -1) means not to + * check timeout. + * @param unit The unit of the timeout. + */ + public static Builder newBuilder(int timeout, TimeUnit unit) { + return new Builder(timeout, unit); + } + private final int timeout; private final TimeUnit unit; private final boolean shouldInterrupt; - private final Consumer logFunction; private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); - /** - * Creates an instance. Please make it a singleton and remember to shut it down. - * - * @param timeout Configurable timeout threshold. A value less than 0 (e.g. 0 or -1) means not to - * check timeout. - * @param unit The unit of the timeout. - * @param shouldInterrupt If {@code true}, interrupts the RPC worker thread. - * @param logFunction An optional function that can log (e.g. Logger::warn). Through this, - * we avoid depending on a specific logger library. - */ - public ServerTimeoutManager(int timeout, TimeUnit unit, + protected ServerTimeoutManager(int timeout, TimeUnit unit, boolean shouldInterrupt, Consumer logFunction) { this.timeout = timeout; this.unit = unit; @@ -84,6 +87,7 @@ public boolean withTimeout(Runnable invocation) { if (c.cancellationCause() == null) { return; } + System.out.println("RPC cancelled"); if (threadRef != null) { Thread thread = threadRef.getAndSet(null); if (thread != null) { @@ -111,4 +115,43 @@ public boolean withTimeout(Runnable invocation) { return true; } } + + /** Builder for constructing ServerTimeoutManager instances. */ + public static class Builder { + private final int timeout; + private final TimeUnit unit; + + private boolean shouldInterrupt; + private Consumer logFunction; + + private Builder(int timeout, TimeUnit unit) { + this.timeout = timeout; + this.unit = unit; + } + + /** + * Sets shouldInterrupt. Defaults to {@code false}. + * + * @param shouldInterrupt If {@code true}, interrupts the RPC worker thread. + */ + public Builder setShouldInterrupt(boolean shouldInterrupt) { + this.shouldInterrupt = shouldInterrupt; + return this; + } + + /** + * Sets the logFunction. Through this, we avoid depending on a specific logger library. + * + * @param logFunction An optional function that can make server logs (e.g. Logger::warn). + */ + public Builder setLogFunction(Consumer logFunction) { + this.logFunction = logFunction; + return this; + } + + /** Construct new ServerTimeoutManager. */ + public ServerTimeoutManager build() { + return new ServerTimeoutManager(timeout, unit, shouldInterrupt, logFunction); + } + } } diff --git a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java index 1651fb9351f..ab992f9e743 100644 --- a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java +++ b/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java @@ -49,7 +49,7 @@ public class ServerCallTimeoutInterceptorTest { .build(); @Test - public void unaryServerCallWithinTimeout() { + public void unaryServerCall_setShouldInterrupt_withinTimeout_isNotInterrupted() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = ServerCalls.asyncUnaryCall( @@ -61,15 +61,19 @@ public void invoke(Integer req, StreamObserver responseObserver) { responseObserver.onNext(req); responseObserver.onCompleted(); } catch (InterruptedException e) { - Status status = Status.ABORTED.withDescription(e.getMessage()); + Status status = + Context.current().isCancelled() ? Status.CANCELLED : Status.INTERNAL; responseObserver.onError(new StatusRuntimeException(status)); } } }); StringBuffer logBuf = new StringBuffer(); - ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 100, TimeUnit.MILLISECONDS, true, logBuf::append); + ServerTimeoutManager serverTimeoutManager = + ServerTimeoutManager.newBuilder(100, TimeUnit.MILLISECONDS) + .setShouldInterrupt(true) + .setLogFunction(logBuf::append) + .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); listener.onMessage(42); @@ -83,7 +87,7 @@ public void invoke(Integer req, StreamObserver responseObserver) { } @Test - public void unaryServerCallExceedsTimeout() { + public void unaryServerCall_setShouldInterrupt_exceedingTimeout_isInterrupted() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = ServerCalls.asyncUnaryCall( @@ -95,15 +99,19 @@ public void invoke(Integer req, StreamObserver responseObserver) { responseObserver.onNext(req); responseObserver.onCompleted(); } catch (InterruptedException e) { - Status status = Status.ABORTED.withDescription(e.getMessage()); + Status status = + Context.current().isCancelled() ? Status.CANCELLED : Status.INTERNAL; responseObserver.onError(new StatusRuntimeException(status)); } } }); StringBuffer logBuf = new StringBuffer(); - ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 1, TimeUnit.NANOSECONDS, true, logBuf::append); + ServerTimeoutManager serverTimeoutManager = + ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS) + .setShouldInterrupt(true) + .setLogFunction(logBuf::append) + .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); listener.onMessage(42); @@ -112,8 +120,7 @@ public void invoke(Integer req, StreamObserver responseObserver) { serverTimeoutManager.shutdown(); assertThat(serverCall.responses).isEmpty(); - assertEquals(Status.Code.ABORTED, serverCall.status.getCode()); - assertEquals("sleep interrupted", serverCall.status.getDescription()); + assertEquals(Status.Code.CANCELLED, serverCall.status.getCode()); assertThat(logBuf.toString()).startsWith("Interrupted RPC thread "); } @@ -130,7 +137,8 @@ public void invoke(Integer req, StreamObserver responseObserver) { responseObserver.onNext(req); responseObserver.onCompleted(); } catch (InterruptedException e) { - Status status = Status.ABORTED.withDescription(e.getMessage()); + Status status = + Context.current().isCancelled() ? Status.CANCELLED : Status.INTERNAL; responseObserver.onError(new StatusRuntimeException(status)); } } From a7f5cc16d055b3fdddeb6dc12a89da4bce2ce785 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Sun, 3 Sep 2023 18:03:01 +0800 Subject: [PATCH 15/22] Clear interruption in a finally block --- .../main/java/io/grpc/ServerTimeoutManager.java | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/api/src/main/java/io/grpc/ServerTimeoutManager.java b/api/src/main/java/io/grpc/ServerTimeoutManager.java index 8e6b3ad88f7..a3bc32680dd 100644 --- a/api/src/main/java/io/grpc/ServerTimeoutManager.java +++ b/api/src/main/java/io/grpc/ServerTimeoutManager.java @@ -105,13 +105,17 @@ public boolean withTimeout(Runnable invocation) { } }; context.addListener(cancelled, MoreExecutors.directExecutor()); - context.run(invocation); - // Invocation is done, should ensure the interruption state is cleared so - // the worker thread can be safely reused for the next task. Doing this - // mainly for ForkJoinPool https://bugs.openjdk.org/browse/JDK-8223430. - if (threadRef != null && threadRef.get() == null) { - Thread.interrupted(); + try { + context.run(invocation); + } finally { + // Clear the interruption state if this context previously caused an interruption, + // allowing the worker thread to be safely reused for the next task in a ForkJoinPool. + // For more information, refer to https://bugs.openjdk.org/browse/JDK-8223430 + if (threadRef != null && threadRef.get() == null) { + Thread.interrupted(); + } } + return true; } } From 1200f355400e9007edabc751b1d9f2403dcfe270 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Sun, 17 Sep 2023 11:26:15 +0800 Subject: [PATCH 16/22] Intercept all stages and close server call using serializing execution --- .../io/grpc/util/SerializingServerCall.java | 167 +++++++++++++++ .../util}/ServerCallTimeoutInterceptor.java | 75 ++++++- .../io/grpc/util}/ServerTimeoutManager.java | 115 +++++++---- ...smitStatusRuntimeExceptionInterceptor.java | 160 -------------- .../ServerCallTimeoutInterceptorTest.java | 195 +++++++++--------- 5 files changed, 401 insertions(+), 311 deletions(-) create mode 100644 core/src/main/java/io/grpc/util/SerializingServerCall.java rename {api/src/main/java/io/grpc => core/src/main/java/io/grpc/util}/ServerCallTimeoutInterceptor.java (53%) rename {api/src/main/java/io/grpc => core/src/main/java/io/grpc/util}/ServerTimeoutManager.java (55%) rename {api/src/test/java/io/grpc => core/src/test/java/io/grpc/util}/ServerCallTimeoutInterceptorTest.java (60%) diff --git a/core/src/main/java/io/grpc/util/SerializingServerCall.java b/core/src/main/java/io/grpc/util/SerializingServerCall.java new file mode 100644 index 00000000000..d06606250e4 --- /dev/null +++ b/core/src/main/java/io/grpc/util/SerializingServerCall.java @@ -0,0 +1,167 @@ +package io.grpc.util; + +import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Attributes; +import io.grpc.ExperimentalApi; +import io.grpc.ForwardingServerCall; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.Status; +import io.grpc.internal.SerializingExecutor; +import java.util.concurrent.ExecutionException; +import javax.annotation.Nullable; + +/** + * A {@link ServerCall} that wraps around a non thread safe delegate and provides thread safe + * access by serializing everything on an executor. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/2189") +class SerializingServerCall extends + ForwardingServerCall.SimpleForwardingServerCall { + private static final String ERROR_MSG = "Encountered error during serialized access"; + private final SerializingExecutor serializingExecutor = + new SerializingExecutor(MoreExecutors.directExecutor()); + private boolean closeCalled = false; + + SerializingServerCall(ServerCall delegate) { + super(delegate); + } + + @Override + public void sendMessage(final RespT message) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + SerializingServerCall.super.sendMessage(message); + } + }); + } + + @Override + public void request(final int numMessages) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + SerializingServerCall.super.request(numMessages); + } + }); + } + + @Override + public void sendHeaders(final Metadata headers) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + SerializingServerCall.super.sendHeaders(headers); + } + }); + } + + @Override + public void close(final Status status, final Metadata trailers) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + if (!closeCalled) { + closeCalled = true; + + SerializingServerCall.super.close(status, trailers); + } + } + }); + } + + @Override + public boolean isReady() { + final SettableFuture retVal = SettableFuture.create(); + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + retVal.set(SerializingServerCall.super.isReady()); + } + }); + try { + return retVal.get(); + } catch (InterruptedException e) { + throw new RuntimeException(ERROR_MSG, e); + } catch (ExecutionException e) { + throw new RuntimeException(ERROR_MSG, e); + } + } + + @Override + public boolean isCancelled() { + final SettableFuture retVal = SettableFuture.create(); + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + retVal.set(SerializingServerCall.super.isCancelled()); + } + }); + try { + return retVal.get(); + } catch (InterruptedException e) { + throw new RuntimeException(ERROR_MSG, e); + } catch (ExecutionException e) { + throw new RuntimeException(ERROR_MSG, e); + } + } + + @Override + public void setMessageCompression(final boolean enabled) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + SerializingServerCall.super.setMessageCompression(enabled); + } + }); + } + + @Override + public void setCompression(final String compressor) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + SerializingServerCall.super.setCompression(compressor); + } + }); + } + + @Override + public Attributes getAttributes() { + final SettableFuture retVal = SettableFuture.create(); + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + retVal.set(SerializingServerCall.super.getAttributes()); + } + }); + try { + return retVal.get(); + } catch (InterruptedException e) { + throw new RuntimeException(ERROR_MSG, e); + } catch (ExecutionException e) { + throw new RuntimeException(ERROR_MSG, e); + } + } + + @Nullable + @Override + public String getAuthority() { + final SettableFuture retVal = SettableFuture.create(); + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + retVal.set(SerializingServerCall.super.getAuthority()); + } + }); + try { + return retVal.get(); + } catch (InterruptedException e) { + throw new RuntimeException(ERROR_MSG, e); + } catch (ExecutionException e) { + throw new RuntimeException(ERROR_MSG, e); + } + } +} diff --git a/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java b/core/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java similarity index 53% rename from api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java rename to core/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java index 9bcb6dcf68c..897bd77dc70 100644 --- a/api/src/main/java/io/grpc/ServerCallTimeoutInterceptor.java +++ b/core/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java @@ -14,7 +14,15 @@ * limitations under the License. */ -package io.grpc; +package io.grpc.util; + +import io.grpc.Context; +import io.grpc.ExperimentalApi; +import io.grpc.ForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; /** * An optional ServerInterceptor that can interrupt server calls that are running for too long time. @@ -41,33 +49,82 @@ public ServerCall.Listener interceptCall( ServerCallHandler serverCallHandler) { // Only intercepts unary calls because the timeout is inapplicable to streaming calls. if (serverCall.getMethodDescriptor().getType().clientSendsOneMessage()) { - return new TimeoutServerCallListener<>( - serverCallHandler.startCall(serverCall, metadata), serverTimeoutManager); - } else { - return serverCallHandler.startCall(serverCall, metadata); + ServerCall serializingServerCall = new SerializingServerCall<>(serverCall); + Context.CancellableContext timeoutContext = + serverTimeoutManager.startTimeoutContext(serializingServerCall); + if (timeoutContext != null) { + return new TimeoutServerCallListener<>( + serverCallHandler.startCall(serializingServerCall, metadata), + timeoutContext, + serverTimeoutManager); + } } + return serverCallHandler.startCall(serverCall, metadata); } - /** A listener that intercepts the RPC method invocation for timeout control. */ - private static class TimeoutServerCallListener + /** A listener that intercepts RPC callbacks for timeout control. */ + static class TimeoutServerCallListener extends ForwardingServerCallListener.SimpleForwardingServerCallListener { + private final Context.CancellableContext context; private final ServerTimeoutManager serverTimeoutManager; private TimeoutServerCallListener( ServerCall.Listener delegate, + Context.CancellableContext context, ServerTimeoutManager serverTimeoutManager) { super(delegate); + this.context = context; this.serverTimeoutManager = serverTimeoutManager; } + @Override + public void onMessage(ReqT message) { + Context previous = context.attach(); + try { + super.onMessage(message); + } finally { + context.detach(previous); + } + } + /** - * Intercepts onHalfClose() because the RPC method is called in it. See + * Intercepts onHalfClose() because the application RPC method is called in it. See * io.grpc.stub.ServerCalls.UnaryServerCallHandler.UnaryServerCallListener */ @Override public void onHalfClose() { - serverTimeoutManager.withTimeout(super::onHalfClose); + serverTimeoutManager.withInterruption(context, super::onHalfClose); + } + + @Override + public void onCancel() { + Context previous = context.attach(); + try { + super.onCancel(); + } finally { + context.detach(previous); + } + } + + @Override + public void onComplete() { + Context previous = context.attach(); + try { + super.onComplete(); + } finally { + context.detach(previous); + } + } + + @Override + public void onReady() { + Context previous = context.attach(); + try { + super.onReady(); + } finally { + context.detach(previous); + } } } } diff --git a/api/src/main/java/io/grpc/ServerTimeoutManager.java b/core/src/main/java/io/grpc/util/ServerTimeoutManager.java similarity index 55% rename from api/src/main/java/io/grpc/ServerTimeoutManager.java rename to core/src/main/java/io/grpc/util/ServerTimeoutManager.java index a3bc32680dd..c3ebf384a8e 100644 --- a/api/src/main/java/io/grpc/ServerTimeoutManager.java +++ b/core/src/main/java/io/grpc/util/ServerTimeoutManager.java @@ -14,14 +14,21 @@ * limitations under the License. */ -package io.grpc; +package io.grpc.util; import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.Context; +import io.grpc.Deadline; +import io.grpc.ExperimentalApi; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.Status; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import javax.annotation.Nullable; /** * A global manager that schedules the timeout tasks for the gRPC server. @@ -45,9 +52,9 @@ public static Builder newBuilder(int timeout, TimeUnit unit) { private final boolean shouldInterrupt; private final Consumer logFunction; - private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); + private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(); - protected ServerTimeoutManager(int timeout, TimeUnit unit, + private ServerTimeoutManager(int timeout, TimeUnit unit, boolean shouldInterrupt, Consumer logFunction) { this.timeout = timeout; this.unit = unit; @@ -56,68 +63,86 @@ protected ServerTimeoutManager(int timeout, TimeUnit unit, } /** - * Please call shutdown() when the application exits. - * You can add a JVM shutdown hook. + * Please call shutdown() when the application exits. You can add a JVM shutdown hook. */ public void shutdown() { scheduler.shutdownNow(); } /** - * Calls the RPC method invocation with a timeout scheduled. - * Invalidates the timeout if the invocation completes in time. + * Creates a context with the timeout limit. + * @param serverCall Should pass in a SerializingServerCall that can be closed thread-safely. + * @return null if not to set a timeout for it + */ + @Nullable + public Context.CancellableContext startTimeoutContext(ServerCall serverCall) { + if (timeout <= 0 || scheduler.isShutdown()) { + return null; + } + Context.CancellationListener callCloser = c -> { + if (c.cancellationCause() == null) { + return; + } + serverCall.close(Status.CANCELLED.withDescription("server call timeout"), new Metadata()); + }; + Context.CancellableContext context = Context.current().withDeadline( + Deadline.after(timeout, unit), scheduler); + context.addListener(callCloser, MoreExecutors.directExecutor()); + return context; + } + + /** + * Executes the application RPC invocation in the timeout context. * - *

When the timeout is reached: It cancels the context around the RPC method invocation. And + *

When the timeout is reached: It cancels the context around the RPC invocation. And * if shouldInterrupt is {@code true}, it also interrupts the current worker thread. * - * @param invocation The RPC method invocation that processes a request. + * @param context The timeout context. + * @param invocation The application RPC invocation that processes a request. * @return true if a timeout is scheduled */ - public boolean withTimeout(Runnable invocation) { + public boolean withInterruption(Context.CancellableContext context, Runnable invocation) { if (timeout <= 0 || scheduler.isShutdown()) { invocation.run(); return false; } - try (Context.CancellableContext context = Context.current() - .withDeadline(Deadline.after(timeout, unit), scheduler)) { - AtomicReference threadRef = - shouldInterrupt ? new AtomicReference<>(Thread.currentThread()) : null; - Context.CancellationListener cancelled = c -> { - if (c.cancellationCause() == null) { - return; - } - System.out.println("RPC cancelled"); - if (threadRef != null) { - Thread thread = threadRef.getAndSet(null); - if (thread != null) { - thread.interrupt(); - if (logFunction != null) { - logFunction.accept( - "Interrupted RPC thread " - + thread.getName() - + " for timeout at " - + timeout - + " " - + unit); - } + AtomicReference threadRef = + shouldInterrupt ? new AtomicReference<>(Thread.currentThread()) : null; + Context.CancellationListener interruption = c -> { + if (c.cancellationCause() == null) { + return; + } + if (threadRef != null) { + Thread thread = threadRef.getAndSet(null); + if (thread != null) { + thread.interrupt(); + if (logFunction != null) { + logFunction.accept( + "Interrupted RPC thread " + + thread.getName() + + " for timeout at " + + timeout + + " " + + unit); } } - }; - context.addListener(cancelled, MoreExecutors.directExecutor()); - try { - context.run(invocation); - } finally { - // Clear the interruption state if this context previously caused an interruption, - // allowing the worker thread to be safely reused for the next task in a ForkJoinPool. - // For more information, refer to https://bugs.openjdk.org/browse/JDK-8223430 - if (threadRef != null && threadRef.get() == null) { - Thread.interrupted(); - } } - - return true; + }; + context.addListener(interruption, MoreExecutors.directExecutor()); + try { + context.run(invocation); + } finally { + context.removeListener(interruption); + // Clear the interruption state if this context previously caused an interruption, + // allowing the worker thread to be safely reused for the next task in a ForkJoinPool. + // For more information, refer to https://bugs.openjdk.org/browse/JDK-8223430 + if (threadRef != null && threadRef.get() == null) { + Thread.interrupted(); + } } + + return true; } /** Builder for constructing ServerTimeoutManager instances. */ diff --git a/core/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java b/core/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java index bead2be4e9e..438c43134c1 100644 --- a/core/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java +++ b/core/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java @@ -16,11 +16,7 @@ package io.grpc.util; -import com.google.common.util.concurrent.MoreExecutors; -import com.google.common.util.concurrent.SettableFuture; -import io.grpc.Attributes; import io.grpc.ExperimentalApi; -import io.grpc.ForwardingServerCall; import io.grpc.ForwardingServerCallListener; import io.grpc.Metadata; import io.grpc.ServerCall; @@ -28,9 +24,6 @@ import io.grpc.ServerInterceptor; import io.grpc.Status; import io.grpc.StatusRuntimeException; -import io.grpc.internal.SerializingExecutor; -import java.util.concurrent.ExecutionException; -import javax.annotation.Nullable; /** * A class that intercepts uncaught exceptions of type {@link StatusRuntimeException} and handles @@ -113,157 +106,4 @@ private void closeWithException(StatusRuntimeException t) { } }; } - - /** - * A {@link ServerCall} that wraps around a non thread safe delegate and provides thread safe - * access by serializing everything on an executor. - */ - private static class SerializingServerCall extends - ForwardingServerCall.SimpleForwardingServerCall { - private static final String ERROR_MSG = "Encountered error during serialized access"; - private final SerializingExecutor serializingExecutor = - new SerializingExecutor(MoreExecutors.directExecutor()); - private boolean closeCalled = false; - - SerializingServerCall(ServerCall delegate) { - super(delegate); - } - - @Override - public void sendMessage(final RespT message) { - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - SerializingServerCall.super.sendMessage(message); - } - }); - } - - @Override - public void request(final int numMessages) { - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - SerializingServerCall.super.request(numMessages); - } - }); - } - - @Override - public void sendHeaders(final Metadata headers) { - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - SerializingServerCall.super.sendHeaders(headers); - } - }); - } - - @Override - public void close(final Status status, final Metadata trailers) { - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - if (!closeCalled) { - closeCalled = true; - - SerializingServerCall.super.close(status, trailers); - } - } - }); - } - - @Override - public boolean isReady() { - final SettableFuture retVal = SettableFuture.create(); - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - retVal.set(SerializingServerCall.super.isReady()); - } - }); - try { - return retVal.get(); - } catch (InterruptedException e) { - throw new RuntimeException(ERROR_MSG, e); - } catch (ExecutionException e) { - throw new RuntimeException(ERROR_MSG, e); - } - } - - @Override - public boolean isCancelled() { - final SettableFuture retVal = SettableFuture.create(); - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - retVal.set(SerializingServerCall.super.isCancelled()); - } - }); - try { - return retVal.get(); - } catch (InterruptedException e) { - throw new RuntimeException(ERROR_MSG, e); - } catch (ExecutionException e) { - throw new RuntimeException(ERROR_MSG, e); - } - } - - @Override - public void setMessageCompression(final boolean enabled) { - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - SerializingServerCall.super.setMessageCompression(enabled); - } - }); - } - - @Override - public void setCompression(final String compressor) { - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - SerializingServerCall.super.setCompression(compressor); - } - }); - } - - @Override - public Attributes getAttributes() { - final SettableFuture retVal = SettableFuture.create(); - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - retVal.set(SerializingServerCall.super.getAttributes()); - } - }); - try { - return retVal.get(); - } catch (InterruptedException e) { - throw new RuntimeException(ERROR_MSG, e); - } catch (ExecutionException e) { - throw new RuntimeException(ERROR_MSG, e); - } - } - - @Nullable - @Override - public String getAuthority() { - final SettableFuture retVal = SettableFuture.create(); - serializingExecutor.execute(new Runnable() { - @Override - public void run() { - retVal.set(SerializingServerCall.super.getAuthority()); - } - }); - try { - return retVal.get(); - } catch (InterruptedException e) { - throw new RuntimeException(ERROR_MSG, e); - } catch (ExecutionException e) { - throw new RuntimeException(ERROR_MSG, e); - } - } - } } diff --git a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java b/core/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java similarity index 60% rename from api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java rename to core/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java index ab992f9e743..b7ee473672f 100644 --- a/api/src/test/java/io/grpc/ServerCallTimeoutInterceptorTest.java +++ b/core/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java @@ -14,27 +14,36 @@ * limitations under the License. */ -package io.grpc; +package io.grpc.util; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import io.grpc.Context; +import io.grpc.IntegerMarshaller; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; import io.grpc.stub.ServerCalls; import io.grpc.stub.StreamObserver; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for {@link ServerCallTimeoutInterceptor}. */ +/** + * Unit tests for {@link ServerCallTimeoutInterceptor}. + */ @RunWith(JUnit4.class) public class ServerCallTimeoutInterceptorTest { - static final MethodDescriptor STREAMING_METHOD = + private static final MethodDescriptor STREAMING_METHOD = MethodDescriptor.newBuilder() .setType(MethodDescriptor.MethodType.BIDI_STREAMING) .setFullMethodName("some/bidi_streaming") @@ -42,38 +51,46 @@ public class ServerCallTimeoutInterceptorTest { .setResponseMarshaller(new IntegerMarshaller()) .build(); - static final MethodDescriptor UNARY_METHOD = + private static final MethodDescriptor UNARY_METHOD = STREAMING_METHOD.toBuilder() .setType(MethodDescriptor.MethodType.UNARY) .setFullMethodName("some/unary") .build(); + private static ServerCalls.UnaryMethod sleepingUnaryMethod(int sleepMillis) { + return new ServerCalls.UnaryMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + try { + Thread.sleep(sleepMillis); + if (Context.current().isCancelled()) { + responseObserver.onError(new StatusRuntimeException(Status.CANCELLED)); + return; + } + responseObserver.onNext(req); + responseObserver.onCompleted(); + } catch (InterruptedException e) { + Status status = Context.current().isCancelled() ? + Status.CANCELLED : Status.INTERNAL; + responseObserver.onError( + new StatusRuntimeException(status.withDescription(e.getMessage()))); + } + } + }; + } + @Test - public void unaryServerCall_setShouldInterrupt_withinTimeout_isNotInterrupted() { + public void unary_setShouldInterrupt_exceedingTimeout_isInterrupted() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall( - new ServerCalls.UnaryMethod() { - @Override - public void invoke(Integer req, StreamObserver responseObserver) { - try { - Thread.sleep(0); - responseObserver.onNext(req); - responseObserver.onCompleted(); - } catch (InterruptedException e) { - Status status = - Context.current().isCancelled() ? Status.CANCELLED : Status.INTERNAL; - responseObserver.onError(new StatusRuntimeException(status)); - } - } - }); + ServerCalls.asyncUnaryCall(sleepingUnaryMethod(100)); StringBuffer logBuf = new StringBuffer(); ServerTimeoutManager serverTimeoutManager = - ServerTimeoutManager.newBuilder(100, TimeUnit.MILLISECONDS) - .setShouldInterrupt(true) - .setLogFunction(logBuf::append) - .build(); + ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS) + .setShouldInterrupt(true) + .setLogFunction(logBuf::append) + .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); listener.onMessage(42); @@ -81,37 +98,23 @@ public void invoke(Integer req, StreamObserver responseObserver) { listener.onComplete(); serverTimeoutManager.shutdown(); - assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); - assertEquals(Status.Code.OK, serverCall.status.getCode()); - assertThat(logBuf.toString()).isEmpty(); + assertThat(serverCall.responses).isEmpty(); + assertEquals(Status.Code.CANCELLED, serverCall.status.getCode()); + assertEquals("server call timeout", serverCall.status.getDescription()); + assertThat(logBuf.toString()).startsWith("Interrupted RPC thread "); } @Test - public void unaryServerCall_setShouldInterrupt_exceedingTimeout_isInterrupted() { + public void unary_byDefault_exceedingTimeout_isCancelledButNotInterrupted() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall( - new ServerCalls.UnaryMethod() { - @Override - public void invoke(Integer req, StreamObserver responseObserver) { - try { - Thread.sleep(100); - responseObserver.onNext(req); - responseObserver.onCompleted(); - } catch (InterruptedException e) { - Status status = - Context.current().isCancelled() ? Status.CANCELLED : Status.INTERNAL; - responseObserver.onError(new StatusRuntimeException(status)); - } - } - }); + ServerCalls.asyncUnaryCall(sleepingUnaryMethod(100)); StringBuffer logBuf = new StringBuffer(); ServerTimeoutManager serverTimeoutManager = ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS) - .setShouldInterrupt(true) - .setLogFunction(logBuf::append) - .build(); + .setLogFunction(logBuf::append) + .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); listener.onMessage(42); @@ -121,81 +124,79 @@ public void invoke(Integer req, StreamObserver responseObserver) { assertThat(serverCall.responses).isEmpty(); assertEquals(Status.Code.CANCELLED, serverCall.status.getCode()); - assertThat(logBuf.toString()).startsWith("Interrupted RPC thread "); + assertEquals("server call timeout", serverCall.status.getDescription()); + assertThat(logBuf.toString()).isEmpty(); + } + + @Test + public void unary_setShouldInterrupt_withinTimeout_isNotCancelledOrInterrupted() { + ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall(sleepingUnaryMethod(0)); + StringBuffer logBuf = new StringBuffer(); + + ServerTimeoutManager serverTimeoutManager = + ServerTimeoutManager.newBuilder(100, TimeUnit.MILLISECONDS) + .setShouldInterrupt(true) + .setLogFunction(logBuf::append) + .build(); + ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) + .interceptCall(serverCall, new Metadata(), callHandler); + listener.onMessage(42); + listener.onHalfClose(); + listener.onComplete(); + serverTimeoutManager.shutdown(); + + assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); + assertEquals(Status.Code.OK, serverCall.status.getCode()); + assertThat(logBuf.toString()).isEmpty(); } @Test - public void unaryServerCallSkipsZeroTimeout() { + public void unary_setZeroTimeout_isNotIntercepted() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = ServerCalls.asyncUnaryCall( new ServerCalls.UnaryMethod() { @Override public void invoke(Integer req, StreamObserver responseObserver) { - try { - Thread.sleep(1); - responseObserver.onNext(req); - responseObserver.onCompleted(); - } catch (InterruptedException e) { - Status status = - Context.current().isCancelled() ? Status.CANCELLED : Status.INTERNAL; - responseObserver.onError(new StatusRuntimeException(status)); - } + responseObserver.onNext(req); + responseObserver.onCompleted(); } }); - AtomicBoolean isTimeoutScheduled = new AtomicBoolean(false); - ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 0, TimeUnit.MILLISECONDS, true, null) { - @Override - public boolean withTimeout(Runnable invocation) { - boolean result = super.withTimeout(invocation); - isTimeoutScheduled.set(result); - return result; - } - }; + ServerTimeoutManager serverTimeoutManager = + ServerTimeoutManager.newBuilder(0, TimeUnit.NANOSECONDS) + .setShouldInterrupt(true) + .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); - listener.onMessage(42); - listener.onHalfClose(); - listener.onComplete(); serverTimeoutManager.shutdown(); - assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); - assertEquals(Status.Code.OK, serverCall.status.getCode()); - assertFalse(isTimeoutScheduled.get()); + assertNotEquals( + ServerCallTimeoutInterceptor.TimeoutServerCallListener.class, listener.getClass()); } @Test - public void streamingServerCallIsNotIntercepted() { + public void streaming_isNotIntercepted() { ServerCallRecorder serverCall = new ServerCallRecorder(STREAMING_METHOD); ServerCallHandler callHandler = - ServerCalls.asyncBidiStreamingCall(new ServerCalls.BidiStreamingMethod() { - @Override - public StreamObserver invoke(StreamObserver responseObserver) { - return new EchoStreamObserver<>(responseObserver); - } - }); - AtomicBoolean isManagerCalled = new AtomicBoolean(false); + ServerCalls.asyncBidiStreamingCall( + new ServerCalls.BidiStreamingMethod() { + @Override + public StreamObserver invoke(StreamObserver responseObserver) { + return new EchoStreamObserver<>(responseObserver); + } + }); - ServerTimeoutManager serverTimeoutManager = new ServerTimeoutManager( - 100, TimeUnit.MILLISECONDS, true, null) { - @Override - public boolean withTimeout(Runnable invocation) { - isManagerCalled.set(true); - return true; - } - }; + ServerTimeoutManager serverTimeoutManager = + ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS).build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); - listener.onMessage(42); - listener.onHalfClose(); - listener.onComplete(); serverTimeoutManager.shutdown(); - assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); - assertEquals(Status.Code.OK, serverCall.status.getCode()); - assertFalse(isManagerCalled.get()); + assertNotEquals( + ServerCallTimeoutInterceptor.TimeoutServerCallListener.class, listener.getClass()); } private static class ServerCallRecorder extends ServerCall { From c20a372e3b7f2b1312922297aa1293e19ef0f058 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Sun, 17 Sep 2023 13:27:17 +0800 Subject: [PATCH 17/22] maintain after merge --- .../java/io/grpc/util/SerializingServerCall.java | 16 ++++++++++++++++ .../grpc/util/ServerCallTimeoutInterceptor.java | 0 .../java/io/grpc/util/ServerTimeoutManager.java | 0 .../util/ServerCallTimeoutInterceptorTest.java | 4 ++-- 4 files changed, 18 insertions(+), 2 deletions(-) rename {core => util}/src/main/java/io/grpc/util/SerializingServerCall.java (88%) rename {core => util}/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java (100%) rename {core => util}/src/main/java/io/grpc/util/ServerTimeoutManager.java (100%) rename {core => util}/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java (98%) diff --git a/core/src/main/java/io/grpc/util/SerializingServerCall.java b/util/src/main/java/io/grpc/util/SerializingServerCall.java similarity index 88% rename from core/src/main/java/io/grpc/util/SerializingServerCall.java rename to util/src/main/java/io/grpc/util/SerializingServerCall.java index d06606250e4..2bffb57654c 100644 --- a/core/src/main/java/io/grpc/util/SerializingServerCall.java +++ b/util/src/main/java/io/grpc/util/SerializingServerCall.java @@ -1,3 +1,19 @@ +/* + * Copyright 2017 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.grpc.util; import com.google.common.util.concurrent.MoreExecutors; diff --git a/core/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java b/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java similarity index 100% rename from core/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java rename to util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java diff --git a/core/src/main/java/io/grpc/util/ServerTimeoutManager.java b/util/src/main/java/io/grpc/util/ServerTimeoutManager.java similarity index 100% rename from core/src/main/java/io/grpc/util/ServerTimeoutManager.java rename to util/src/main/java/io/grpc/util/ServerTimeoutManager.java diff --git a/core/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java b/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java similarity index 98% rename from core/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java rename to util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java index b7ee473672f..9bc6df1a3df 100644 --- a/core/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java +++ b/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java @@ -70,8 +70,8 @@ public void invoke(Integer req, StreamObserver responseObserver) { responseObserver.onNext(req); responseObserver.onCompleted(); } catch (InterruptedException e) { - Status status = Context.current().isCancelled() ? - Status.CANCELLED : Status.INTERNAL; + Status status = Context.current().isCancelled() + ? Status.CANCELLED : Status.INTERNAL; responseObserver.onError( new StatusRuntimeException(status.withDescription(e.getMessage()))); } From 5b78a539b9330523e0631243e55971862ba906f3 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Sun, 17 Sep 2023 14:06:21 +0800 Subject: [PATCH 18/22] Improve javadoc comments --- .../io/grpc/util/ServerCallTimeoutInterceptor.java | 10 +++++----- .../main/java/io/grpc/util/ServerTimeoutManager.java | 5 +++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java b/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java index 897bd77dc70..d19fa6ddc6b 100644 --- a/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java +++ b/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java @@ -25,13 +25,13 @@ import io.grpc.ServerInterceptor; /** - * An optional ServerInterceptor that can interrupt server calls that are running for too long time. - * In this way, it prevents problematic code from using up all threads. + * An optional ServerInterceptor to stop server calls at best effort when the timeout is reached. + * In this way, it prevents problematic code from excessively using up all threads in the pool. * - *

How to use: you can add it to your server using ServerBuilder#intercept(ServerInterceptor). + *

How to use: install it to your server using ServerBuilder#intercept(ServerInterceptor). * *

Limitation: it only applies the timeout to unary calls - * (streaming calls will still run without timeout). + * (long-running streaming calls are allowed, so they can run without this timeout limit). */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10361") public class ServerCallTimeoutInterceptor implements ServerInterceptor { @@ -89,7 +89,7 @@ public void onMessage(ReqT message) { } /** - * Intercepts onHalfClose() because the application RPC method is called in it. See + * Adds interruption here because the application RPC method is called in halfClose(). See * io.grpc.stub.ServerCalls.UnaryServerCallHandler.UnaryServerCallListener */ @Override diff --git a/util/src/main/java/io/grpc/util/ServerTimeoutManager.java b/util/src/main/java/io/grpc/util/ServerTimeoutManager.java index c3ebf384a8e..6d767edd61d 100644 --- a/util/src/main/java/io/grpc/util/ServerTimeoutManager.java +++ b/util/src/main/java/io/grpc/util/ServerTimeoutManager.java @@ -39,7 +39,7 @@ public class ServerTimeoutManager { /** * Creates a builder. * - * @param timeout Configurable timeout threshold. A value less than 0 (e.g. 0 or -1) means not to + * @param timeout Configurable timeout threshold. A non-positive value (e.g. 0 or -1) means not to * check timeout. * @param unit The unit of the timeout. */ @@ -63,7 +63,8 @@ private ServerTimeoutManager(int timeout, TimeUnit unit, } /** - * Please call shutdown() when the application exits. You can add a JVM shutdown hook. + * Please call shutdown() when the application exits. + * You can add a JVM shutdown hook to call it. */ public void shutdown() { scheduler.shutdownNow(); From 06d771cf01b648b9861e53044dcd6a84e19df861 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Sun, 17 Sep 2023 15:17:49 +0800 Subject: [PATCH 19/22] Close the context --- .../util/ServerCallTimeoutInterceptor.java | 4 + .../io/grpc/util/ServerTimeoutManager.java | 4 + .../ServerCallTimeoutInterceptorTest.java | 107 +++++++++++++++--- 3 files changed, 101 insertions(+), 14 deletions(-) diff --git a/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java b/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java index d19fa6ddc6b..f63ae364773 100644 --- a/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java +++ b/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java @@ -104,6 +104,8 @@ public void onCancel() { super.onCancel(); } finally { context.detach(previous); + // Cancel the timeout when the call is finished. + context.close(); } } @@ -114,6 +116,8 @@ public void onComplete() { super.onComplete(); } finally { context.detach(previous); + // Cancel the timeout when the call is finished. + context.close(); } } diff --git a/util/src/main/java/io/grpc/util/ServerTimeoutManager.java b/util/src/main/java/io/grpc/util/ServerTimeoutManager.java index 6d767edd61d..9fbbedb34be 100644 --- a/util/src/main/java/io/grpc/util/ServerTimeoutManager.java +++ b/util/src/main/java/io/grpc/util/ServerTimeoutManager.java @@ -84,6 +84,10 @@ public Context.CancellableContext startTimeoutContext(ServerCall serverCal if (c.cancellationCause() == null) { return; } + if (logFunction != null) { + logFunction.accept("server call timeout for " + + serverCall.getMethodDescriptor().getFullMethodName()); + } serverCall.close(Status.CANCELLED.withDescription("server call timeout"), new Metadata()); }; Context.CancellableContext context = Context.current().withDeadline( diff --git a/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java b/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java index 9bc6df1a3df..5cc405a87ae 100644 --- a/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java +++ b/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java @@ -30,7 +30,10 @@ import io.grpc.StatusRuntimeException; import io.grpc.stub.ServerCalls; import io.grpc.stub.StreamObserver; +import java.io.PrintWriter; +import java.io.StringWriter; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.concurrent.TimeUnit; @@ -83,13 +86,13 @@ public void invoke(Integer req, StreamObserver responseObserver) { public void unary_setShouldInterrupt_exceedingTimeout_isInterrupted() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall(sleepingUnaryMethod(100)); - StringBuffer logBuf = new StringBuffer(); + ServerCalls.asyncUnaryCall(sleepingUnaryMethod(20)); + StringWriter logWriter = new StringWriter(); ServerTimeoutManager serverTimeoutManager = ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS) .setShouldInterrupt(true) - .setLogFunction(logBuf::append) + .setLogFunction(new PrintWriter(logWriter)::println) .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); @@ -101,19 +104,20 @@ public void unary_setShouldInterrupt_exceedingTimeout_isInterrupted() { assertThat(serverCall.responses).isEmpty(); assertEquals(Status.Code.CANCELLED, serverCall.status.getCode()); assertEquals("server call timeout", serverCall.status.getDescription()); - assertThat(logBuf.toString()).startsWith("Interrupted RPC thread "); + assertThat(logWriter.toString()) + .startsWith("server call timeout for some/unary\nInterrupted RPC thread "); } @Test public void unary_byDefault_exceedingTimeout_isCancelledButNotInterrupted() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall(sleepingUnaryMethod(100)); - StringBuffer logBuf = new StringBuffer(); + ServerCalls.asyncUnaryCall(sleepingUnaryMethod(20)); + StringWriter logWriter = new StringWriter(); ServerTimeoutManager serverTimeoutManager = ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS) - .setLogFunction(logBuf::append) + .setLogFunction(new PrintWriter(logWriter)::println) .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); @@ -125,7 +129,7 @@ public void unary_byDefault_exceedingTimeout_isCancelledButNotInterrupted() { assertThat(serverCall.responses).isEmpty(); assertEquals(Status.Code.CANCELLED, serverCall.status.getCode()); assertEquals("server call timeout", serverCall.status.getDescription()); - assertThat(logBuf.toString()).isEmpty(); + assertEquals("server call timeout for some/unary\n", logWriter.toString()); } @Test @@ -133,12 +137,12 @@ public void unary_setShouldInterrupt_withinTimeout_isNotCancelledOrInterrupted() ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = ServerCalls.asyncUnaryCall(sleepingUnaryMethod(0)); - StringBuffer logBuf = new StringBuffer(); + StringWriter logWriter = new StringWriter(); ServerTimeoutManager serverTimeoutManager = ServerTimeoutManager.newBuilder(100, TimeUnit.MILLISECONDS) .setShouldInterrupt(true) - .setLogFunction(logBuf::append) + .setLogFunction(new PrintWriter(logWriter)::println) .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); @@ -149,7 +153,7 @@ public void unary_setShouldInterrupt_withinTimeout_isNotCancelledOrInterrupted() assertThat(serverCall.responses).isEqualTo(Collections.singletonList(42)); assertEquals(Status.Code.OK, serverCall.status.getCode()); - assertThat(logBuf.toString()).isEmpty(); + assertThat(logWriter.toString()).isEmpty(); } @Test @@ -166,9 +170,7 @@ public void invoke(Integer req, StreamObserver responseObserver) { }); ServerTimeoutManager serverTimeoutManager = - ServerTimeoutManager.newBuilder(0, TimeUnit.NANOSECONDS) - .setShouldInterrupt(true) - .build(); + ServerTimeoutManager.newBuilder(0, TimeUnit.NANOSECONDS).build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); serverTimeoutManager.shutdown(); @@ -199,6 +201,83 @@ public StreamObserver invoke(StreamObserver responseObserver) ServerCallTimeoutInterceptor.TimeoutServerCallListener.class, listener.getClass()); } + @Test + public void allStagesCanKnowCancellation() throws Exception { + List cancelledStages = Collections.synchronizedList(new ArrayList<>()); + ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); + ServerCallHandler callHandler = new ServerCallHandler() { + private final ServerCallHandler innerHandler = + ServerCalls.asyncUnaryCall(sleepingUnaryMethod(0)); + + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + ServerCall.Listener delegate = innerHandler.startCall(call, headers); + return new ServerCall.Listener() { + @Override + public void onMessage(Integer message) { + if (Context.current().isCancelled()) { + cancelledStages.add("onMessage"); + } + delegate.onMessage(message); + } + + @Override + public void onHalfClose() { + if (Context.current().isCancelled()) { + cancelledStages.add("onHalfClose"); + } + delegate.onHalfClose(); + } + + @Override + public void onCancel() { + if (Context.current().isCancelled()) { + cancelledStages.add("onCancel"); + } + delegate.onCancel(); + } + + @Override + public void onComplete() { + if (Context.current().isCancelled()) { + cancelledStages.add("onComplete"); + } + delegate.onComplete(); + } + + @Override + public void onReady() { + if (Context.current().isCancelled()) { + cancelledStages.add("onReady"); + } + delegate.onReady(); + } + }; + } + }; + + ServerTimeoutManager serverTimeoutManager = + ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS).build(); + ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) + .interceptCall(serverCall, new Metadata(), callHandler); + // Let it timeout + Thread.sleep(20); + listener.onMessage(42); + listener.onHalfClose(); + listener.onReady(); + listener.onComplete(); + listener.onCancel(); + serverTimeoutManager.shutdown(); + + assertThat(serverCall.responses).isEmpty(); + assertEquals(Status.Code.CANCELLED, serverCall.status.getCode()); + assertEquals("server call timeout", serverCall.status.getDescription()); + assertEquals( + Arrays.asList("onMessage", "onHalfClose", "onReady", "onComplete", "onCancel"), + cancelledStages); + } + private static class ServerCallRecorder extends ServerCall { private final MethodDescriptor methodDescriptor; private final List requestCalls = new ArrayList<>(); From df83e54a3b9cbe09067c466a080f23681b4343ed Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Wed, 20 Sep 2023 14:39:19 +0800 Subject: [PATCH 20/22] Skip listener callback execution if context has been cancelled (server call should have been closed) --- .../util/ServerCallTimeoutInterceptor.java | 24 +-- .../io/grpc/util/ServerTimeoutManager.java | 57 ++++--- .../ForwardingScheduledExecutorService.java | 54 +++++++ .../ServerCallTimeoutInterceptorTest.java | 143 +++++++++++++----- 4 files changed, 200 insertions(+), 78 deletions(-) create mode 100644 util/src/test/java/io/grpc/util/ForwardingScheduledExecutorService.java diff --git a/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java b/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java index f63ae364773..bdae2e70494 100644 --- a/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java +++ b/util/src/main/java/io/grpc/util/ServerCallTimeoutInterceptor.java @@ -80,12 +80,7 @@ private TimeoutServerCallListener( @Override public void onMessage(ReqT message) { - Context previous = context.attach(); - try { - super.onMessage(message); - } finally { - context.detach(previous); - } + serverTimeoutManager.runWithContext(context, () -> super.onMessage(message)); } /** @@ -94,16 +89,14 @@ public void onMessage(ReqT message) { */ @Override public void onHalfClose() { - serverTimeoutManager.withInterruption(context, super::onHalfClose); + serverTimeoutManager.runWithContextInterruptibly(context, super::onHalfClose); } @Override public void onCancel() { - Context previous = context.attach(); try { - super.onCancel(); + serverTimeoutManager.runWithContext(context, super::onCancel); } finally { - context.detach(previous); // Cancel the timeout when the call is finished. context.close(); } @@ -111,11 +104,9 @@ public void onCancel() { @Override public void onComplete() { - Context previous = context.attach(); try { - super.onComplete(); + serverTimeoutManager.runWithContext(context, super::onComplete); } finally { - context.detach(previous); // Cancel the timeout when the call is finished. context.close(); } @@ -123,12 +114,7 @@ public void onComplete() { @Override public void onReady() { - Context previous = context.attach(); - try { - super.onReady(); - } finally { - context.detach(previous); - } + serverTimeoutManager.runWithContext(context, super::onReady); } } } diff --git a/util/src/main/java/io/grpc/util/ServerTimeoutManager.java b/util/src/main/java/io/grpc/util/ServerTimeoutManager.java index 9fbbedb34be..4904894dd80 100644 --- a/util/src/main/java/io/grpc/util/ServerTimeoutManager.java +++ b/util/src/main/java/io/grpc/util/ServerTimeoutManager.java @@ -18,7 +18,6 @@ import com.google.common.util.concurrent.MoreExecutors; import io.grpc.Context; -import io.grpc.Deadline; import io.grpc.ExperimentalApi; import io.grpc.Metadata; import io.grpc.ServerCall; @@ -50,16 +49,19 @@ public static Builder newBuilder(int timeout, TimeUnit unit) { private final int timeout; private final TimeUnit unit; private final boolean shouldInterrupt; + @Nullable private final Consumer logFunction; - - private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(); + private final ScheduledExecutorService scheduler; private ServerTimeoutManager(int timeout, TimeUnit unit, - boolean shouldInterrupt, Consumer logFunction) { + boolean shouldInterrupt, + @Nullable Consumer logFunction, + @Nullable ScheduledExecutorService scheduler) { this.timeout = timeout; this.unit = unit; this.shouldInterrupt = shouldInterrupt; this.logFunction = logFunction; + this.scheduler = scheduler != null ? scheduler : Executors.newSingleThreadScheduledExecutor(); } /** @@ -90,28 +92,40 @@ public Context.CancellableContext startTimeoutContext(ServerCall serverCal } serverCall.close(Status.CANCELLED.withDescription("server call timeout"), new Metadata()); }; - Context.CancellableContext context = Context.current().withDeadline( - Deadline.after(timeout, unit), scheduler); + Context.CancellableContext context = Context.current() + .withDeadlineAfter(timeout, unit, scheduler); context.addListener(callCloser, MoreExecutors.directExecutor()); return context; } /** - * Executes the application RPC invocation in the timeout context. + * Executes the application invocation in the timeout context. + * Skips execution if context has been cancelled. + * + * @param context The timeout context. + * @param invocation The application invocation that processes a request. + */ + public void runWithContext(Context.CancellableContext context, Runnable invocation) { + if (context.isCancelled()) { + return; + } + context.run(invocation); + } + + /** + * Executes the application invocation in the timeout context, may interrupt the current thread. + * Skips execution if context has been cancelled. * *

When the timeout is reached: It cancels the context around the RPC invocation. And * if shouldInterrupt is {@code true}, it also interrupts the current worker thread. * * @param context The timeout context. - * @param invocation The application RPC invocation that processes a request. - * @return true if a timeout is scheduled + * @param invocation The application invocation that processes a request. */ - public boolean withInterruption(Context.CancellableContext context, Runnable invocation) { - if (timeout <= 0 || scheduler.isShutdown()) { - invocation.run(); - return false; + public void runWithContextInterruptibly(Context.CancellableContext context, Runnable invocation) { + if (context.isCancelled()) { + return; } - AtomicReference threadRef = shouldInterrupt ? new AtomicReference<>(Thread.currentThread()) : null; Context.CancellationListener interruption = c -> { @@ -146,8 +160,6 @@ public boolean withInterruption(Context.CancellableContext context, Runnable inv Thread.interrupted(); } } - - return true; } /** Builder for constructing ServerTimeoutManager instances. */ @@ -157,6 +169,7 @@ public static class Builder { private boolean shouldInterrupt; private Consumer logFunction; + private ScheduledExecutorService scheduler; private Builder(int timeout, TimeUnit unit) { this.timeout = timeout; @@ -183,9 +196,19 @@ public Builder setLogFunction(Consumer logFunction) { return this; } + /** + * Sets a custom scheduler instance. If not set, a default scheduler is used. + * + * @param scheduler An custom scheduler. + */ + public Builder setScheduler(ScheduledExecutorService scheduler) { + this.scheduler = scheduler; + return this; + } + /** Construct new ServerTimeoutManager. */ public ServerTimeoutManager build() { - return new ServerTimeoutManager(timeout, unit, shouldInterrupt, logFunction); + return new ServerTimeoutManager(timeout, unit, shouldInterrupt, logFunction, scheduler); } } } diff --git a/util/src/test/java/io/grpc/util/ForwardingScheduledExecutorService.java b/util/src/test/java/io/grpc/util/ForwardingScheduledExecutorService.java new file mode 100644 index 00000000000..271aa518719 --- /dev/null +++ b/util/src/test/java/io/grpc/util/ForwardingScheduledExecutorService.java @@ -0,0 +1,54 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import com.google.common.util.concurrent.ForwardingExecutorService; +import java.util.concurrent.Callable; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +/** + * Forwards all methods to delegate. + */ +abstract class ForwardingScheduledExecutorService extends ForwardingExecutorService + implements ScheduledExecutorService { + @Override + protected abstract ScheduledExecutorService delegate(); + + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + return delegate().schedule(callable, delay, unit); + } + + @Override + public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { + return delegate().schedule(command, delay, unit); + } + + @Override + public ScheduledFuture scheduleAtFixedRate( + Runnable command, long initialDelay, long period, TimeUnit unit) { + return delegate().scheduleAtFixedRate(command, initialDelay, period, unit); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay( + Runnable command, long initialDelay, long delay, TimeUnit unit) { + return delegate().scheduleWithFixedDelay(command, initialDelay, delay, unit); + } +} diff --git a/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java b/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java index 5cc405a87ae..6d0fa43bf36 100644 --- a/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java +++ b/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java @@ -20,6 +20,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import com.google.common.util.concurrent.testing.TestingExecutors; import io.grpc.Context; import io.grpc.IntegerMarshaller; import io.grpc.Metadata; @@ -33,9 +34,16 @@ import java.io.PrintWriter; import java.io.StringWriter; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Queue; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Delayed; +import java.util.concurrent.FutureTask; +import java.util.concurrent.RunnableScheduledFuture; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import org.junit.Test; import org.junit.runner.RunWith; @@ -60,11 +68,16 @@ public class ServerCallTimeoutInterceptorTest { .setFullMethodName("some/unary") .build(); - private static ServerCalls.UnaryMethod sleepingUnaryMethod(int sleepMillis) { + private static ServerCalls.UnaryMethod sleepingUnaryMethod( + int sleepMillis, + MockTimeoutScheduler scheduler) { return new ServerCalls.UnaryMethod() { @Override public void invoke(Integer req, StreamObserver responseObserver) { try { + if (sleepMillis > 0) { + scheduler.timeoutImmediately(); + } Thread.sleep(sleepMillis); if (Context.current().isCancelled()) { responseObserver.onError(new StatusRuntimeException(Status.CANCELLED)); @@ -84,15 +97,17 @@ public void invoke(Integer req, StreamObserver responseObserver) { @Test public void unary_setShouldInterrupt_exceedingTimeout_isInterrupted() { + StringWriter logWriter = new StringWriter(); + MockTimeoutScheduler scheduler = new MockTimeoutScheduler(); ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall(sleepingUnaryMethod(20)); - StringWriter logWriter = new StringWriter(); + ServerCalls.asyncUnaryCall(sleepingUnaryMethod(20, scheduler)); ServerTimeoutManager serverTimeoutManager = - ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS) + ServerTimeoutManager.newBuilder(10, TimeUnit.MILLISECONDS) .setShouldInterrupt(true) .setLogFunction(new PrintWriter(logWriter)::println) + .setScheduler(scheduler) .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); @@ -110,14 +125,16 @@ public void unary_setShouldInterrupt_exceedingTimeout_isInterrupted() { @Test public void unary_byDefault_exceedingTimeout_isCancelledButNotInterrupted() { + StringWriter logWriter = new StringWriter(); + MockTimeoutScheduler scheduler = new MockTimeoutScheduler(); ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall(sleepingUnaryMethod(20)); - StringWriter logWriter = new StringWriter(); + ServerCalls.asyncUnaryCall(sleepingUnaryMethod(20, scheduler)); ServerTimeoutManager serverTimeoutManager = - ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS) + ServerTimeoutManager.newBuilder(10, TimeUnit.MILLISECONDS) .setLogFunction(new PrintWriter(logWriter)::println) + .setScheduler(scheduler) .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); @@ -134,15 +151,17 @@ public void unary_byDefault_exceedingTimeout_isCancelledButNotInterrupted() { @Test public void unary_setShouldInterrupt_withinTimeout_isNotCancelledOrInterrupted() { + StringWriter logWriter = new StringWriter(); + MockTimeoutScheduler scheduler = new MockTimeoutScheduler(); ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall(sleepingUnaryMethod(0)); - StringWriter logWriter = new StringWriter(); + ServerCalls.asyncUnaryCall(sleepingUnaryMethod(0, scheduler)); ServerTimeoutManager serverTimeoutManager = - ServerTimeoutManager.newBuilder(100, TimeUnit.MILLISECONDS) + ServerTimeoutManager.newBuilder(10, TimeUnit.MILLISECONDS) .setShouldInterrupt(true) .setLogFunction(new PrintWriter(logWriter)::println) + .setScheduler(scheduler) .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); @@ -202,67 +221,51 @@ public StreamObserver invoke(StreamObserver responseObserver) } @Test - public void allStagesCanKnowCancellation() throws Exception { - List cancelledStages = Collections.synchronizedList(new ArrayList<>()); + public void canSkipEachStageUponCancellation() { + List notSkippedStages = Collections.synchronizedList(new ArrayList<>()); + MockTimeoutScheduler scheduler = new MockTimeoutScheduler(); ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler callHandler = new ServerCallHandler() { - private final ServerCallHandler innerHandler = - ServerCalls.asyncUnaryCall(sleepingUnaryMethod(0)); @Override public ServerCall.Listener startCall( ServerCall call, Metadata headers) { - ServerCall.Listener delegate = innerHandler.startCall(call, headers); return new ServerCall.Listener() { @Override public void onMessage(Integer message) { - if (Context.current().isCancelled()) { - cancelledStages.add("onMessage"); - } - delegate.onMessage(message); + notSkippedStages.add("onMessage"); } @Override public void onHalfClose() { - if (Context.current().isCancelled()) { - cancelledStages.add("onHalfClose"); - } - delegate.onHalfClose(); + notSkippedStages.add("onHalfClose"); } @Override public void onCancel() { - if (Context.current().isCancelled()) { - cancelledStages.add("onCancel"); - } - delegate.onCancel(); + notSkippedStages.add("onCancel"); } @Override public void onComplete() { - if (Context.current().isCancelled()) { - cancelledStages.add("onComplete"); - } - delegate.onComplete(); + notSkippedStages.add("onComplete"); } @Override public void onReady() { - if (Context.current().isCancelled()) { - cancelledStages.add("onReady"); - } - delegate.onReady(); + notSkippedStages.add("onReady"); } }; } }; ServerTimeoutManager serverTimeoutManager = - ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS).build(); + ServerTimeoutManager.newBuilder(1, TimeUnit.NANOSECONDS) + .setScheduler(scheduler) + .build(); ServerCall.Listener listener = new ServerCallTimeoutInterceptor(serverTimeoutManager) .interceptCall(serverCall, new Metadata(), callHandler); - // Let it timeout - Thread.sleep(20); + scheduler.timeoutImmediately(); listener.onMessage(42); listener.onHalfClose(); listener.onReady(); @@ -273,9 +276,7 @@ public void onReady() { assertThat(serverCall.responses).isEmpty(); assertEquals(Status.Code.CANCELLED, serverCall.status.getCode()); assertEquals("server call timeout", serverCall.status.getDescription()); - assertEquals( - Arrays.asList("onMessage", "onHalfClose", "onReady", "onComplete", "onCancel"), - cancelledStages); + assertEquals(Collections.emptyList(), notSkippedStages); } private static class ServerCallRecorder extends ServerCall { @@ -347,4 +348,62 @@ public void onCompleted() { responseObserver.onCompleted(); } } + + // Enables manually controlling the schedule + private static class MockTimeoutScheduler extends ForwardingScheduledExecutorService { + private final ScheduledExecutorService delegate = TestingExecutors.noOpScheduledExecutor(); + private final Queue> queue = new ConcurrentLinkedQueue<>(); + + private void timeoutImmediately() { + while (!queue.isEmpty()) { + queue.poll().run(); + } + } + + @Override + public ScheduledExecutorService delegate() { + return delegate; + } + + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + ScheduledFutureTask futureTask = new ScheduledFutureTask<>(callable); + queue.add(futureTask); + return futureTask; + } + + @Override + public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { + ScheduledFutureTask futureTask = new ScheduledFutureTask<>(command, null); + queue.add(futureTask); + return futureTask; + } + + private static class ScheduledFutureTask + extends FutureTask implements RunnableScheduledFuture { + + private ScheduledFutureTask(Callable callable) { + super(callable); + } + + private ScheduledFutureTask(Runnable runnable, V result) { + super(runnable, result); + } + + @Override + public boolean isPeriodic() { + return false; + } + + @Override + public long getDelay(TimeUnit unit) { + return 0; + } + + @Override + public int compareTo(Delayed o) { + return 0; + } + } + } } From 1a9a9bcc58e9ca6378884ae8ae905c100e15f275 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Sun, 15 Oct 2023 21:19:07 +0800 Subject: [PATCH 21/22] improve code coverage --- .../io/grpc/util/SerializingServerCall.java | 16 ++++--------- .../grpc/util/SerializingServerCallTest.java | 24 +++++++++++++++++++ .../ServerCallTimeoutInterceptorTest.java | 4 +++- 3 files changed, 31 insertions(+), 13 deletions(-) create mode 100644 util/src/test/java/io/grpc/util/SerializingServerCallTest.java diff --git a/util/src/main/java/io/grpc/util/SerializingServerCall.java b/util/src/main/java/io/grpc/util/SerializingServerCall.java index 2bffb57654c..f96f4abf7f4 100644 --- a/util/src/main/java/io/grpc/util/SerializingServerCall.java +++ b/util/src/main/java/io/grpc/util/SerializingServerCall.java @@ -99,9 +99,7 @@ public void run() { }); try { return retVal.get(); - } catch (InterruptedException e) { - throw new RuntimeException(ERROR_MSG, e); - } catch (ExecutionException e) { + } catch (InterruptedException | ExecutionException e) { throw new RuntimeException(ERROR_MSG, e); } } @@ -117,9 +115,7 @@ public void run() { }); try { return retVal.get(); - } catch (InterruptedException e) { - throw new RuntimeException(ERROR_MSG, e); - } catch (ExecutionException e) { + } catch (InterruptedException | ExecutionException e) { throw new RuntimeException(ERROR_MSG, e); } } @@ -155,9 +151,7 @@ public void run() { }); try { return retVal.get(); - } catch (InterruptedException e) { - throw new RuntimeException(ERROR_MSG, e); - } catch (ExecutionException e) { + } catch (InterruptedException | ExecutionException e) { throw new RuntimeException(ERROR_MSG, e); } } @@ -174,9 +168,7 @@ public void run() { }); try { return retVal.get(); - } catch (InterruptedException e) { - throw new RuntimeException(ERROR_MSG, e); - } catch (ExecutionException e) { + } catch (InterruptedException | ExecutionException e) { throw new RuntimeException(ERROR_MSG, e); } } diff --git a/util/src/test/java/io/grpc/util/SerializingServerCallTest.java b/util/src/test/java/io/grpc/util/SerializingServerCallTest.java new file mode 100644 index 00000000000..efd6af014f0 --- /dev/null +++ b/util/src/test/java/io/grpc/util/SerializingServerCallTest.java @@ -0,0 +1,24 @@ +package io.grpc.util; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import io.grpc.ServerCall; +import io.grpc.internal.NoopServerCall; +import org.junit.Test; + +public class SerializingServerCallTest { + + @Test + public void testMethods() { + ServerCall testCall = new SerializingServerCall<>(new NoopServerCall<>()); + testCall.setCompression("gzip"); + testCall.setMessageCompression(true); + assertTrue(testCall.isReady()); + assertFalse(testCall.isCancelled()); + assertNull(testCall.getAuthority()); + assertNotNull(testCall.getAttributes()); + } +} diff --git a/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java b/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java index 6d0fa43bf36..c57ca459996 100644 --- a/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java +++ b/util/src/test/java/io/grpc/util/ServerCallTimeoutInterceptorTest.java @@ -63,9 +63,11 @@ public class ServerCallTimeoutInterceptorTest { .build(); private static final MethodDescriptor UNARY_METHOD = - STREAMING_METHOD.toBuilder() + MethodDescriptor.newBuilder() .setType(MethodDescriptor.MethodType.UNARY) .setFullMethodName("some/unary") + .setRequestMarshaller(new IntegerMarshaller()) + .setResponseMarshaller(new IntegerMarshaller()) .build(); private static ServerCalls.UnaryMethod sleepingUnaryMethod( From e6bb04b09774658a3c25e4794e7e887eade609c3 Mon Sep 17 00:00:00 2001 From: Dongqing Hu Date: Fri, 20 Oct 2023 15:50:25 +0800 Subject: [PATCH 22/22] add copyright --- .../io/grpc/util/SerializingServerCallTest.java | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/util/src/test/java/io/grpc/util/SerializingServerCallTest.java b/util/src/test/java/io/grpc/util/SerializingServerCallTest.java index efd6af014f0..c9ffffa56ac 100644 --- a/util/src/test/java/io/grpc/util/SerializingServerCallTest.java +++ b/util/src/test/java/io/grpc/util/SerializingServerCallTest.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.grpc.util; import static org.junit.Assert.assertFalse;