Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api, core, services: make ProtoReflectionService interceptor compatible #6967

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 0 additions & 28 deletions api/src/main/java/io/grpc/InternalNotifyOnServerBuild.java

This file was deleted.

6 changes: 6 additions & 0 deletions api/src/main/java/io/grpc/PartialForwardingServerCall.java
Expand Up @@ -76,6 +76,12 @@ public String getAuthority() {
return delegate().getAuthority();
}

@Override
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/6989")
public Server getServer() {
return delegate().getServer();
}

@Override
public String toString() {
return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString();
Expand Down
9 changes: 9 additions & 0 deletions api/src/main/java/io/grpc/ServerCall.java
Expand Up @@ -229,4 +229,13 @@ public String getAuthority() {
* The {@link MethodDescriptor} for the call.
*/
public abstract MethodDescriptor<ReqT, RespT> getMethodDescriptor();

/**
* Returns the {@link Server} that dispatches the call. {@code null} if the implementation
* choose to not expose the server.
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/6989")
public Server getServer() {
return null;
}
}
Expand Up @@ -29,7 +29,6 @@
import io.grpc.DecompressorRegistry;
import io.grpc.HandlerRegistry;
import io.grpc.InternalChannelz;
import io.grpc.InternalNotifyOnServerBuild;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptor;
Expand Down Expand Up @@ -77,7 +76,6 @@ public static ServerBuilder<?> forPort(int port) {
new InternalHandlerRegistry.Builder();
final List<ServerTransportFilter> transportFilters = new ArrayList<>();
final List<ServerInterceptor> interceptors = new ArrayList<>();
private final List<InternalNotifyOnServerBuild> notifyOnBuildList = new ArrayList<>();
private final List<ServerStreamTracer.Factory> streamTracerFactories = new ArrayList<>();
HandlerRegistry fallbackRegistry = DEFAULT_FALLBACK_REGISTRY;
ObjectPool<? extends Executor> executorPool = DEFAULT_EXECUTOR_POOL;
Expand Down Expand Up @@ -114,9 +112,6 @@ public final T addService(ServerServiceDefinition service) {

@Override
public final T addService(BindableService bindableService) {
if (bindableService instanceof InternalNotifyOnServerBuild) {
notifyOnBuildList.add((InternalNotifyOnServerBuild) bindableService);
}
return addService(checkNotNull(bindableService, "bindableService").bindService());
}

Expand Down Expand Up @@ -222,14 +217,7 @@ protected void setDeadlineTicker(Deadline.Ticker ticker) {

@Override
public final Server build() {
ServerImpl server = new ServerImpl(
this,
buildTransportServers(getTracerFactories()),
Context.ROOT);
for (InternalNotifyOnServerBuild notifyTarget : notifyOnBuildList) {
notifyTarget.notifyOnBuild(server);
}
return server;
return new ServerImpl(this, buildTransportServers(getTracerFactories()), Context.ROOT);
}

@VisibleForTesting
Expand Down
10 changes: 9 additions & 1 deletion core/src/main/java/io/grpc/internal/ServerCallImpl.java
Expand Up @@ -35,6 +35,7 @@
import io.grpc.InternalDecompressorRegistry;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.Status;
import io.perfmark.PerfMark;
Expand All @@ -52,6 +53,7 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {
@VisibleForTesting
static final String MISSING_RESPONSE = "Completed without a response";

private final Server server;
private final ServerStream stream;
private final MethodDescriptor<ReqT, RespT> method;
private final Tag tag;
Expand All @@ -68,10 +70,11 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {
private Compressor compressor;
private boolean messageSent;

ServerCallImpl(ServerStream stream, MethodDescriptor<ReqT, RespT> method,
ServerCallImpl(Server server, ServerStream stream, MethodDescriptor<ReqT, RespT> method,
Metadata inboundHeaders, Context.CancellableContext context,
DecompressorRegistry decompressorRegistry, CompressorRegistry compressorRegistry,
CallTracer serverCallTracer, Tag tag) {
this.server = server;
this.stream = stream;
this.method = method;
this.context = context;
Expand Down Expand Up @@ -245,6 +248,11 @@ public MethodDescriptor<ReqT, RespT> getMethodDescriptor() {
return method;
}

@Override
public Server getServer() {
return server;
}

/**
* Close the {@link ServerStream} because an internal error occurred. Allow the application to
* run until completion, but silently ignore interactions with the {@link ServerStream} from now
Expand Down
1 change: 1 addition & 0 deletions core/src/main/java/io/grpc/internal/ServerImpl.java
Expand Up @@ -636,6 +636,7 @@ private <WReqT, WRespT> ServerStreamListener startWrappedCall(
Tag tag) {

ServerCallImpl<WReqT, WRespT> call = new ServerCallImpl<>(
ServerImpl.this,
stream,
methodDef.getMethodDescriptor(),
headers,
Expand Down
9 changes: 7 additions & 2 deletions core/src/test/java/io/grpc/internal/ServerCallImplTest.java
Expand Up @@ -39,6 +39,7 @@
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.Status;
import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl;
Expand All @@ -60,6 +61,7 @@
@RunWith(JUnit4.class)
public class ServerCallImplTest {
@Rule public final ExpectedException thrown = ExpectedException.none();
@Mock private Server server;
@Mock private ServerStream stream;
@Mock private ServerCall.Listener<Long> callListener;

Expand Down Expand Up @@ -89,7 +91,7 @@ public class ServerCallImplTest {
public void setUp() {
MockitoAnnotations.initMocks(this);
context = Context.ROOT.withCancellation();
call = new ServerCallImpl<>(stream, UNARY_METHOD, requestHeaders, context,
call = new ServerCallImpl<>(server, stream, UNARY_METHOD, requestHeaders, context,
DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance(),
serverCallTracer, PerfMark.createTag());
}
Expand All @@ -112,7 +114,7 @@ private void callTracer0(Status status) {
assertEquals(0, before.callsStarted);
assertEquals(0, before.lastCallStartedNanos);

call = new ServerCallImpl<>(stream, UNARY_METHOD, requestHeaders, context,
call = new ServerCallImpl<>(server, stream, UNARY_METHOD, requestHeaders, context,
DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance(),
tracer, PerfMark.createTag());

Expand Down Expand Up @@ -219,6 +221,7 @@ public void sendMessage_serverSendsOne_closeOnSecondCall_clientStreaming() {
private void sendMessage_serverSendsOne_closeOnSecondCall(
MethodDescriptor<Long, Long> method) {
ServerCallImpl<Long, Long> serverCall = new ServerCallImpl<>(
server,
stream,
method,
requestHeaders,
Expand Down Expand Up @@ -254,6 +257,7 @@ public void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion_clie
private void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion(
MethodDescriptor<Long, Long> method) {
ServerCallImpl<Long, Long> serverCall = new ServerCallImpl<>(
server,
stream,
method,
requestHeaders,
Expand Down Expand Up @@ -292,6 +296,7 @@ public void serverSendsOne_okFailsOnMissingResponse_clientStreaming() {
private void serverSendsOne_okFailsOnMissingResponse(
MethodDescriptor<Long, Long> method) {
ServerCallImpl<Long, Long> serverCall = new ServerCallImpl<>(
server,
stream,
method,
requestHeaders,
Expand Down
Expand Up @@ -24,10 +24,14 @@
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Descriptors.MethodDescriptor;
import com.google.protobuf.Descriptors.ServiceDescriptor;
import io.grpc.BindableService;
import io.grpc.ExperimentalApi;
import io.grpc.InternalNotifyOnServerBuild;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors;
import io.grpc.ServerServiceDefinition;
import io.grpc.Status;
import io.grpc.protobuf.ProtoFileDescriptorSupplier;
Expand All @@ -50,6 +54,7 @@
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;

Expand All @@ -61,26 +66,29 @@
* extension.
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/2222")
public final class ProtoReflectionService extends ServerReflectionGrpc.ServerReflectionImplBase
implements InternalNotifyOnServerBuild {
public final class ProtoReflectionService extends ServerReflectionGrpc.ServerReflectionImplBase {

private final Object lock = new Object();
ejona86 marked this conversation as resolved.
Show resolved Hide resolved

private AtomicReference<Server> serverRef;
@GuardedBy("lock")
private ServerReflectionIndex serverReflectionIndex;

private Server server;

private ProtoReflectionService() {}

public static BindableService newInstance() {
return new ProtoReflectionService();
/**
* Creates a instance of {@link ProtoReflectionService}.
*/
public static ServerServiceDefinition newInstance() {
AtomicReference<Server> serverCaptor = new AtomicReference<>();
ProtoReflectionService protoReflection = new ProtoReflectionService();
protoReflection.init(serverCaptor);
return ServerInterceptors.intercept(
protoReflection, new ServerCaptureInterceptor(serverCaptor));
}

/** Receives a reference to the server at build time. */
@Override
public void notifyOnBuild(Server server) {
this.server = checkNotNull(server);
private void init(AtomicReference<Server> serverRef) {
this.serverRef = serverRef;
}

/**
Expand All @@ -92,6 +100,7 @@ public void notifyOnBuild(Server server) {
*/
private ServerReflectionIndex updateIndexIfNecessary() {
synchronized (lock) {
Server server = serverRef.get();
if (serverReflectionIndex == null) {
serverReflectionIndex =
new ServerReflectionIndex(server.getImmutableServices(), server.getMutableServices());
Expand Down Expand Up @@ -140,6 +149,21 @@ public StreamObserver<ServerReflectionRequest> serverReflectionInfo(
return requestObserver;
}

private static final class ServerCaptureInterceptor implements ServerInterceptor {
private final AtomicReference<Server> captor;

ServerCaptureInterceptor(AtomicReference<Server> captor) {
this.captor = captor;
}

@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
Metadata headers, ServerCallHandler<ReqT, RespT> next) {
captor.set(call.getServer());
voidzcy marked this conversation as resolved.
Show resolved Hide resolved
return next.startCall(call, headers);
}
}

private static class ProtoReflectionStreamObserver
implements Runnable, StreamObserver<ServerReflectionRequest> {
private final ServerReflectionIndex serverReflectionIndex;
Expand Down
Expand Up @@ -22,7 +22,6 @@
import static org.junit.Assert.fail;

import com.google.protobuf.ByteString;
import io.grpc.BindableService;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.ServerServiceDefinition;
Expand Down Expand Up @@ -64,7 +63,7 @@
public class ProtoReflectionServiceTest {
private static final String TEST_HOST = "localhost";
private MutableHandlerRegistry handlerRegistry = new MutableHandlerRegistry();
private BindableService reflectionService;
private ServerServiceDefinition reflectionService;
private ServerServiceDefinition dynamicService =
new DynamicServiceGrpc.DynamicServiceImplBase() {}.bindService();
private ServerServiceDefinition anotherDynamicService =
Expand Down
1 change: 1 addition & 0 deletions testing/src/main/java/io/grpc/internal/NoopServerCall.java
Expand Up @@ -18,6 +18,7 @@

import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.Status;

Expand Down