diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 782ae21d3b3..0bb83a56d75 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -65,7 +65,6 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -785,15 +784,22 @@ static ClientStreamTracer newClientStreamTracer( } else { streamTracer = new ForwardingClientStreamTracer() { final ClientStreamTracer noop = new ClientStreamTracer() {}; - AtomicReference delegate = new AtomicReference<>(noop); + volatile ClientStreamTracer delegate = noop; void maybeInit(StreamInfo info, Metadata headers) { - delegate.compareAndSet(noop, streamTracerFactory.newClientStreamTracer(info, headers)); + if (delegate != noop) { + return; + } + synchronized (this) { + if (delegate == noop) { + delegate = streamTracerFactory.newClientStreamTracer(info, headers); + } + } } @Override protected ClientStreamTracer delegate() { - return delegate.get(); + return delegate; } @SuppressWarnings("deprecation") diff --git a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java index 95d1c448f4f..6d2c21ddab8 100644 --- a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java +++ b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java @@ -38,6 +38,7 @@ import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.GrpcUtil.Http2Error; import io.grpc.testing.TestMethodDescriptors; +import java.util.ArrayDeque; import java.util.concurrent.atomic.AtomicReference; import org.junit.Rule; import org.junit.Test; @@ -301,12 +302,14 @@ public void clientStreamTracerFactoryBackwardCompatibility() { final AtomicReference transportAttrsRef = new AtomicReference<>(); final ClientStreamTracer mockTracer = mock(ClientStreamTracer.class); final Metadata.Key key = Metadata.Key.of("fake-key", Metadata.ASCII_STRING_MARSHALLER); + final ArrayDeque tracers = new ArrayDeque<>(); ClientStreamTracer.Factory oldFactoryImpl = new ClientStreamTracer.Factory() { @SuppressWarnings("deprecation") @Override public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { transportAttrsRef.set(info.getTransportAttrs()); headers.put(key, "fake-value"); + tracers.offer(mockTracer); return mockTracer; } }; @@ -318,8 +321,12 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata header Attributes.newBuilder().set(Attributes.Key.create("foo"), "bar").build(); ClientStreamTracer tracer = GrpcUtil.newClientStreamTracer(oldFactoryImpl, info, metadata); tracer.streamCreated(transAttrs, metadata); - + assertThat(tracers.poll()).isSameInstanceAs(mockTracer); assertThat(transportAttrsRef.get()).isEqualTo(transAttrs); assertThat(metadata.get(key)).isEqualTo("fake-value"); + + tracer.streamClosed(Status.UNAVAILABLE); + // verify that newClientStreamTracer() is called no more than once + assertThat(tracers).isEmpty(); } }