diff --git a/stub/src/main/java/io/grpc/stub/ClientCalls.java b/stub/src/main/java/io/grpc/stub/ClientCalls.java index 12adecbd84d..cf770a991e1 100644 --- a/stub/src/main/java/io/grpc/stub/ClientCalls.java +++ b/stub/src/main/java/io/grpc/stub/ClientCalls.java @@ -124,6 +124,7 @@ public static RespT blockingUnaryCall(ClientCall call public static RespT blockingUnaryCall( Channel channel, MethodDescriptor method, CallOptions callOptions, ReqT req) { ThreadlessExecutor executor = new ThreadlessExecutor(); + boolean interrupt = false; ClientCall call = channel.newCall(method, callOptions.withExecutor(executor)); try { ListenableFuture responseFuture = futureUnaryCall(call, req); @@ -131,18 +132,22 @@ public static RespT blockingUnaryCall( try { executor.waitAndDrain(); } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw Status.CANCELLED - .withDescription("Call was interrupted") - .withCause(e) - .asRuntimeException(); + interrupt = true; + call.cancel("Thread interrupted", e); + // Now wait for onClose() to be called, so interceptors can clean up } } return getUnchecked(responseFuture); } catch (RuntimeException e) { + // Something very bad happened. All bets are off; it may be dangerous to wait for onClose(). throw cancelThrow(call, e); } catch (Error e) { + // Something very bad happened. All bets are off; it may be dangerous to wait for onClose(). throw cancelThrow(call, e); + } finally { + if (interrupt) { + Thread.currentThread().interrupt(); + } } } @@ -209,7 +214,7 @@ private static V getUnchecked(Future future) { } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw Status.CANCELLED - .withDescription("Call was interrupted") + .withDescription("Thread interrupted") .withCause(e) .asRuntimeException(); } catch (ExecutionException e) { @@ -553,30 +558,45 @@ ClientCall.Listener listener() { return listener; } - private Object waitForNext() throws InterruptedException { - if (threadless == null) { - return buffer.take(); - } else { - Object next = buffer.poll(); - while (next == null) { - threadless.waitAndDrain(); - next = buffer.poll(); + private Object waitForNext() { + boolean interrupt = false; + try { + if (threadless == null) { + while (true) { + try { + return buffer.take(); + } catch (InterruptedException ie) { + interrupt = true; + call.cancel("Thread interrupted", ie); + // Now wait for onClose() to be called, to guarantee BlockingQueue doesn't fill + } + } + } else { + Object next; + while ((next = buffer.poll()) == null) { + try { + threadless.waitAndDrain(); + } catch (InterruptedException ie) { + interrupt = true; + call.cancel("Thread interrupted", ie); + // Now wait for onClose() to be called, so interceptors can clean up + } + } + return next; + } + } finally { + if (interrupt) { + Thread.currentThread().interrupt(); } - return next; } } @Override public boolean hasNext() { - if (last == null) { - try { - // Will block here indefinitely waiting for content. RPC timeouts defend against permanent - // hangs here as the call will become closed. - last = waitForNext(); - } catch (InterruptedException ie) { - Thread.currentThread().interrupt(); - throw Status.CANCELLED.withDescription("interrupted").withCause(ie).asRuntimeException(); - } + while (last == null) { + // Will block here indefinitely waiting for content. RPC timeouts defend against permanent + // hangs here as the call will become closed. + last = waitForNext(); } if (last instanceof StatusRuntimeException) { // Rethrow the exception with a new stacktrace. @@ -650,15 +670,14 @@ private static final class ThreadlessExecutor extends ConcurrentLinkedQueue STREAMING_METHOD = + private static final MethodDescriptor UNARY_METHOD = MethodDescriptor.newBuilder() - .setType(MethodDescriptor.MethodType.BIDI_STREAMING) + .setType(MethodDescriptor.MethodType.UNARY) .setFullMethodName("some/method") .setRequestMarshaller(new IntegerMarshaller()) .setResponseMarshaller(new IntegerMarshaller()) .build(); + private static final MethodDescriptor SERVER_STREAMING_METHOD = + UNARY_METHOD.toBuilder().setType(MethodDescriptor.MethodType.SERVER_STREAMING).build(); + private static final MethodDescriptor BIDI_STREAMING_METHOD = + UNARY_METHOD.toBuilder().setType(MethodDescriptor.MethodType.BIDI_STREAMING).build(); private Server server; private ManagedChannel channel; @@ -130,6 +140,69 @@ public void start(io.grpc.ClientCall.Listener listener, Metadata headers } } + @Test + public void blockingUnaryCall2_success() throws Exception { + Integer req = 2; + final Integer resp = 3; + + class BasicUnaryResponse implements UnaryMethod { + Integer request; + + @Override public void invoke(Integer request, StreamObserver responseObserver) { + this.request = request; + responseObserver.onNext(resp); + responseObserver.onCompleted(); + } + } + + BasicUnaryResponse service = new BasicUnaryResponse(); + server = InProcessServerBuilder.forName("simple-reply").directExecutor() + .addService(ServerServiceDefinition.builder("some") + .addMethod(UNARY_METHOD, ServerCalls.asyncUnaryCall(service)) + .build()) + .build().start(); + channel = InProcessChannelBuilder.forName("simple-reply").directExecutor().build(); + Integer actualResponse = + ClientCalls.blockingUnaryCall(channel, UNARY_METHOD, CallOptions.DEFAULT, req); + assertEquals(resp, actualResponse); + assertEquals(req, service.request); + } + + @Test + public void blockingUnaryCall2_interruptedWaitsForOnClose() throws Exception { + Integer req = 2; + + class NoopUnaryMethod implements UnaryMethod { + ServerCallStreamObserver observer; + + @Override public void invoke(Integer request, StreamObserver responseObserver) { + observer = (ServerCallStreamObserver) responseObserver; + } + } + + NoopUnaryMethod methodImpl = new NoopUnaryMethod(); + server = InProcessServerBuilder.forName("noop").directExecutor() + .addService(ServerServiceDefinition.builder("some") + .addMethod(UNARY_METHOD, ServerCalls.asyncUnaryCall(methodImpl)) + .build()) + .build().start(); + + InterruptInterceptor interceptor = new InterruptInterceptor(); + channel = InProcessChannelBuilder.forName("noop") + .directExecutor() + .intercept(interceptor) + .build(); + try { + ClientCalls.blockingUnaryCall(channel, UNARY_METHOD, CallOptions.DEFAULT, req); + fail(); + } catch (StatusRuntimeException ex) { + assertTrue(Thread.interrupted()); + assertTrue("interrupted", ex.getCause() instanceof InterruptedException); + } + assertTrue("onCloseCalled", interceptor.onCloseCalled); + assertTrue("context not cancelled", methodImpl.observer.isCancelled()); + } + @Test public void unaryFutureCallSuccess() throws Exception { final AtomicReference> listener = @@ -372,8 +445,8 @@ public void request(int numMessages) { public void inprocessTransportInboundFlowControl() throws Exception { final Semaphore semaphore = new Semaphore(0); ServerServiceDefinition service = ServerServiceDefinition.builder( - new ServiceDescriptor("some", STREAMING_METHOD)) - .addMethod(STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall( + new ServiceDescriptor("some", BIDI_STREAMING_METHOD)) + .addMethod(BIDI_STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall( new ServerCalls.BidiStreamingMethod() { int iteration; @@ -404,7 +477,7 @@ public void onCompleted() { server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor() .addService(service).build().start(); channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build(); - final ClientCall clientCall = channel.newCall(STREAMING_METHOD, + final ClientCall clientCall = channel.newCall(BIDI_STREAMING_METHOD, CallOptions.DEFAULT); final CountDownLatch latch = new CountDownLatch(1); final List receivedMessages = new ArrayList<>(6); @@ -453,8 +526,8 @@ public void inprocessTransportOutboundFlowControl() throws Exception { final SettableFuture> observerFuture = SettableFuture.create(); ServerServiceDefinition service = ServerServiceDefinition.builder( - new ServiceDescriptor("some", STREAMING_METHOD)) - .addMethod(STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall( + new ServiceDescriptor("some", BIDI_STREAMING_METHOD)) + .addMethod(BIDI_STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall( new ServerCalls.BidiStreamingMethod() { @Override public StreamObserver invoke(StreamObserver responseObserver) { @@ -485,7 +558,7 @@ public void onCompleted() { server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor() .addService(service).build().start(); channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build(); - final ClientCall clientCall = channel.newCall(STREAMING_METHOD, + final ClientCall clientCall = channel.newCall(BIDI_STREAMING_METHOD, CallOptions.DEFAULT); final SettableFuture future = SettableFuture.create(); @@ -564,4 +637,136 @@ public void start(io.grpc.ClientCall.Listener responseListener, Metadata assertSame(trailers, metadata); } } + + @Test + public void blockingServerStreamingCall_interruptedWaitsForOnClose() throws Exception { + Integer req = 2; + + class NoopServerStreamingMethod implements ServerStreamingMethod { + ServerCallStreamObserver observer; + + @Override public void invoke(Integer request, StreamObserver responseObserver) { + observer = (ServerCallStreamObserver) responseObserver; + } + } + + NoopServerStreamingMethod methodImpl = new NoopServerStreamingMethod(); + server = InProcessServerBuilder.forName("noop").directExecutor() + .addService(ServerServiceDefinition.builder("some") + .addMethod(SERVER_STREAMING_METHOD, ServerCalls.asyncServerStreamingCall(methodImpl)) + .build()) + .build().start(); + + InterruptInterceptor interceptor = new InterruptInterceptor(); + channel = InProcessChannelBuilder.forName("noop") + .directExecutor() + .intercept(interceptor) + .build(); + Iterator iter = ClientCalls.blockingServerStreamingCall( + channel.newCall(SERVER_STREAMING_METHOD, CallOptions.DEFAULT), req); + try { + iter.next(); + fail(); + } catch (StatusRuntimeException ex) { + assertTrue(Thread.interrupted()); + assertTrue("interrupted", ex.getCause() instanceof InterruptedException); + } + assertTrue("onCloseCalled", interceptor.onCloseCalled); + assertTrue("context not cancelled", methodImpl.observer.isCancelled()); + } + + @Test + public void blockingServerStreamingCall2_success() throws Exception { + Integer req = 2; + final Integer resp1 = 3; + final Integer resp2 = 4; + + class BasicServerStreamingResponse implements ServerStreamingMethod { + Integer request; + + @Override public void invoke(Integer request, StreamObserver responseObserver) { + this.request = request; + responseObserver.onNext(resp1); + responseObserver.onNext(resp2); + responseObserver.onCompleted(); + } + } + + BasicServerStreamingResponse service = new BasicServerStreamingResponse(); + server = InProcessServerBuilder.forName("simple-reply").directExecutor() + .addService(ServerServiceDefinition.builder("some") + .addMethod(SERVER_STREAMING_METHOD, ServerCalls.asyncServerStreamingCall(service)) + .build()) + .build().start(); + channel = InProcessChannelBuilder.forName("simple-reply").directExecutor().build(); + Iterator iter = ClientCalls.blockingServerStreamingCall( + channel, SERVER_STREAMING_METHOD, CallOptions.DEFAULT, req); + assertEquals(resp1, iter.next()); + assertTrue(iter.hasNext()); + assertEquals(resp2, iter.next()); + assertFalse(iter.hasNext()); + assertEquals(req, service.request); + } + + @Test + public void blockingServerStreamingCall2_interruptedWaitsForOnClose() throws Exception { + Integer req = 2; + + class NoopServerStreamingMethod implements ServerStreamingMethod { + ServerCallStreamObserver observer; + + @Override public void invoke(Integer request, StreamObserver responseObserver) { + observer = (ServerCallStreamObserver) responseObserver; + } + } + + NoopServerStreamingMethod methodImpl = new NoopServerStreamingMethod(); + server = InProcessServerBuilder.forName("noop").directExecutor() + .addService(ServerServiceDefinition.builder("some") + .addMethod(SERVER_STREAMING_METHOD, ServerCalls.asyncServerStreamingCall(methodImpl)) + .build()) + .build().start(); + + InterruptInterceptor interceptor = new InterruptInterceptor(); + channel = InProcessChannelBuilder.forName("noop") + .directExecutor() + .intercept(interceptor) + .build(); + Iterator iter = ClientCalls.blockingServerStreamingCall( + channel, SERVER_STREAMING_METHOD, CallOptions.DEFAULT, req); + try { + iter.next(); + fail(); + } catch (StatusRuntimeException ex) { + assertTrue(Thread.interrupted()); + assertTrue("interrupted", ex.getCause() instanceof InterruptedException); + } + assertTrue("onCloseCalled", interceptor.onCloseCalled); + assertTrue("context not cancelled", methodImpl.observer.isCancelled()); + } + + // Used for blocking tests to check interrupt behavior and make sure onClose is still called. + class InterruptInterceptor implements ClientInterceptor { + boolean onCloseCalled; + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return new SimpleForwardingClientCall(next.newCall(method, callOptions)) { + @Override public void start(ClientCall.Listener listener, Metadata headers) { + super.start(new SimpleForwardingClientCallListener(listener) { + @Override public void onClose(Status status, Metadata trailers) { + onCloseCalled = true; + super.onClose(status, trailers); + } + }, headers); + } + + @Override public void halfClose() { + Thread.currentThread().interrupt(); + super.halfClose(); + } + }; + } + } }