From ac2ead70b4e6b77268a0affd2656503e130e3d37 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Tue, 26 Jan 2021 12:01:16 -0800 Subject: [PATCH] core: delay CallCredentialsApplyingTransport shutdown until metadataApplier finalized (#7813) Improve the CallCredentialsApplyingTransport shutdown lifecycle management. Right now CallCredentialsApplyingTransport shutdown the delegated real transport too early. It should be waiting for the metadataAppliers to finish because they may execute asynchronously. In addition, there is no shutdown check on CallCredentialsApplyingTransport for newStream(). The degraded lifecycle implementation may cause RejectionExecutionException, or accepting new RPCs after the underlying transport is already closed during channel shutdown. We added listener on metadataApplier to notify completion, a magic counter to track the pending metadataApplier for delaying shutdown, also added shutdown check for newStream(). --- ...llCredentialsApplyingTransportFactory.java | 85 ++++++++++++++- .../io/grpc/internal/MetadataApplierImpl.java | 19 +++- .../CallCredentials2ApplyingTest.java | 27 +++++ .../internal/CallCredentialsApplyingTest.java | 101 ++++++++++++++++++ 4 files changed, 229 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java index ab2a36d9492..de96a3306bd 100644 --- a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java @@ -29,9 +29,12 @@ import io.grpc.MethodDescriptor; import io.grpc.SecurityLevel; import io.grpc.Status; +import io.grpc.internal.MetadataApplierImpl.MetadataApplierListener; import java.net.SocketAddress; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.concurrent.GuardedBy; final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory { private final ClientTransportFactory delegate; @@ -66,6 +69,21 @@ public void close() { private class CallCredentialsApplyingTransport extends ForwardingConnectionClientTransport { private final ConnectionClientTransport delegate; private final String authority; + // Negative value means transport active, non-negative value indicates shutdown invoked. + private final AtomicInteger pendingApplier = new AtomicInteger(Integer.MIN_VALUE + 1); + private volatile Status shutdownStatus; + @GuardedBy("this") + private Status savedShutdownStatus; + @GuardedBy("this") + private Status savedShutdownNowStatus; + private final MetadataApplierListener applierListener = new MetadataApplierListener() { + @Override + public void onComplete() { + if (pendingApplier.decrementAndGet() == 0) { + maybeShutdown(); + } + } + }; CallCredentialsApplyingTransport(ConnectionClientTransport delegate, String authority) { this.delegate = checkNotNull(delegate, "delegate"); @@ -89,7 +107,11 @@ public ClientStream newStream( } if (creds != null) { MetadataApplierImpl applier = new MetadataApplierImpl( - delegate, method, headers, callOptions); + delegate, method, headers, callOptions, applierListener); + if (pendingApplier.incrementAndGet() > 0) { + applierListener.onComplete(); + return new FailingClientStream(shutdownStatus); + } RequestInfo requestInfo = new RequestInfo() { @Override public MethodDescriptor getMethodDescriptor() { @@ -123,8 +145,69 @@ public Attributes getTransportAttrs() { } return applier.returnStream(); } else { + if (pendingApplier.get() >= 0) { + return new FailingClientStream(shutdownStatus); + } return delegate.newStream(method, headers, callOptions); } } + + @Override + public void shutdown(Status status) { + checkNotNull(status, "status"); + synchronized (this) { + if (pendingApplier.get() < 0) { + shutdownStatus = status; + pendingApplier.addAndGet(Integer.MAX_VALUE); + } else { + return; + } + if (pendingApplier.get() != 0) { + savedShutdownStatus = status; + return; + } + } + super.shutdown(status); + } + + // TODO(zivy): cancel pending applier here. + @Override + public void shutdownNow(Status status) { + checkNotNull(status, "status"); + synchronized (this) { + if (pendingApplier.get() < 0) { + shutdownStatus = status; + pendingApplier.addAndGet(Integer.MAX_VALUE); + } else if (savedShutdownNowStatus != null) { + return; + } + if (pendingApplier.get() != 0) { + savedShutdownNowStatus = status; + // TODO(zivy): propagate shutdownNow to the delegate immediately. + return; + } + } + super.shutdownNow(status); + } + + private void maybeShutdown() { + Status maybeShutdown; + Status maybeShutdownNow; + synchronized (this) { + if (pendingApplier.get() != 0) { + return; + } + maybeShutdown = savedShutdownStatus; + maybeShutdownNow = savedShutdownNowStatus; + savedShutdownStatus = null; + savedShutdownNowStatus = null; + } + if (maybeShutdown != null) { + super.shutdown(maybeShutdown); + } + if (maybeShutdownNow != null) { + super.shutdownNow(maybeShutdownNow); + } + } } } diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index 4c49a14a06b..76d280b2d00 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -35,6 +35,7 @@ final class MetadataApplierImpl extends MetadataApplier { private final Metadata origHeaders; private final CallOptions callOptions; private final Context ctx; + private final MetadataApplierListener listener; private final Object lock = new Object(); @@ -51,12 +52,13 @@ final class MetadataApplierImpl extends MetadataApplier { MetadataApplierImpl( ClientTransport transport, MethodDescriptor method, Metadata origHeaders, - CallOptions callOptions) { + CallOptions callOptions, MetadataApplierListener listener) { this.transport = transport; this.method = method; this.origHeaders = origHeaders; this.callOptions = callOptions; this.ctx = Context.current(); + this.listener = listener; } @Override @@ -84,14 +86,19 @@ public void fail(Status status) { private void finalizeWith(ClientStream stream) { checkState(!finalized, "already finalized"); finalized = true; + boolean directStream = false; synchronized (lock) { if (returnedStream == null) { // Fast path: returnStream() hasn't been called, the call will use the // real stream directly. returnedStream = stream; - return; + directStream = true; } } + if (directStream) { + listener.onComplete(); + return; + } // returnStream() has been called before me, thus delayedStream must have been // created. checkState(delayedStream != null, "delayedStream is null"); @@ -100,6 +107,7 @@ private void finalizeWith(ClientStream stream) { // TODO(ejona): run this on a separate thread slow.run(); } + listener.onComplete(); } /** @@ -116,4 +124,11 @@ ClientStream returnStream() { } } } + + public interface MetadataApplierListener { + /** + * Notify that the metadata has been successfully applied, or failed. + * */ + void onComplete(); + } } diff --git a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java index c26944c16b2..7725c46726b 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; @@ -203,6 +204,10 @@ public void credentialThrows() { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); assertSame(ex, stream.getError().getCause()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -227,6 +232,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable { assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -249,6 +258,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); assertSame(error, stream.getError()); + transport.shutdownNow(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } @Test @@ -263,6 +276,9 @@ public void applyMetadata_delayed() { any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + transport.shutdown(Status.UNAVAILABLE); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); @@ -271,6 +287,9 @@ public void applyMetadata_delayed() { assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -290,6 +309,10 @@ public void fail_delayed() { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -301,5 +324,9 @@ public void noCreds() { assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } } diff --git a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java index 6949ab7c310..61a221f73de 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java @@ -19,12 +19,14 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -179,6 +181,11 @@ public void credentialThrows() { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); assertSame(ex, stream.getError().getCause()); + + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -192,6 +199,10 @@ public void applyMetadata_inline() { assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -214,6 +225,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); assertSame(error, stream.getError()); + transport.shutdownNow(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } @Test @@ -228,6 +243,11 @@ public void applyMetadata_delayed() { same(mockExecutor), applierCaptor.capture()); verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + transport.shutdown(Status.UNAVAILABLE); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); @@ -236,6 +256,79 @@ public void applyMetadata_delayed() { assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + } + + @Test + public void delayedShutdown_shutdownShutdownNowThenApply() { + transport.newStream(method, origHeaders, callOptions); + ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), + same(mockExecutor), applierCaptor.capture()); + transport.shutdown(Status.UNAVAILABLE); + transport.shutdownNow(Status.ABORTED); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport, never()).shutdown(any(Status.class)); + verify(mockTransport, never()).shutdownNow(any(Status.class)); + Metadata headers = new Metadata(); + headers.put(CREDS_KEY, CREDS_VALUE); + applierCaptor.getValue().apply(headers); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + verify(mockTransport).shutdownNow(Status.ABORTED); + } + + @Test + public void delayedShutdown_shutdownThenApplyThenShutdownNow() { + transport.newStream(method, origHeaders, callOptions); + ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), + same(mockExecutor), applierCaptor.capture()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport, never()).shutdown(any(Status.class)); + Metadata headers = new Metadata(); + headers.put(CREDS_KEY, CREDS_VALUE); + applierCaptor.getValue().apply(headers); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + transport.shutdownNow(Status.ABORTED); + verify(mockTransport).shutdownNow(Status.ABORTED); + + transport.shutdown(Status.UNAVAILABLE); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + transport.shutdownNow(Status.ABORTED); + verify(mockTransport, times(2)).shutdownNow(Status.ABORTED); + } + + @Test + public void delayedShutdown_shutdownMulti() { + Metadata headers = new Metadata(); + headers.put(CREDS_KEY, CREDS_VALUE); + + transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions); + ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds, times(3)).applyRequestMetadata(any(RequestInfo.class), + same(mockExecutor), applierCaptor.capture()); + applierCaptor.getAllValues().get(1).apply(headers); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + + applierCaptor.getAllValues().get(0).apply(headers); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + + applierCaptor.getAllValues().get(2).apply(headers); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -255,6 +348,10 @@ public void fail_delayed() { verify(mockTransport, never()).newStream(method, origHeaders, callOptions); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test @@ -266,6 +363,10 @@ public void noCreds() { assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); } @Test