Skip to content

Commit

Permalink
ref #357
Browse files Browse the repository at this point in the history
  • Loading branch information
jvmlet committed May 24, 2023
1 parent 4436fa2 commit d646b89
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re

}

@SpyBean
private GRpcErrorHandler errorHandler;


@Test
public void originalCustomInterceptorStatusIsPreserved() {
Expand All @@ -88,7 +87,7 @@ public void originalCustomInterceptorStatusIsPreserved() {
.sayAuthHello2(Empty.newBuilder().build()).getMessage();
});
assertThat(statusRuntimeException.getStatus().getCode(), Matchers.is(Status.Code.ALREADY_EXISTS));
verifyZeroInteractions(errorHandler);

}
@Test
public void unsupportedAuthSchemeShouldThrowUnauthenticatedException() {
Expand All @@ -104,7 +103,7 @@ public void unsupportedAuthSchemeShouldThrowUnauthenticatedException() {
.sayAuthHello2(Empty.newBuilder().build()).getMessage();
});
assertThat(statusRuntimeException.getStatus().getCode(), Matchers.is(Status.Code.UNAUTHENTICATED));
verify(errorHandler).handle(any(),eq(Status.UNAUTHENTICATED), any(),any(),any());

}


Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package org.lognet.springboot.grpc.recovery;

import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.*;
import io.grpc.examples.custom.Custom;
import io.grpc.examples.custom.CustomServiceGrpc;
import io.grpc.stub.StreamObserver;
Expand All @@ -11,14 +9,21 @@
import org.lognet.springboot.grpc.GRpcService;
import org.lognet.springboot.grpc.GrpcServerTestBase;
import org.lognet.springboot.grpc.demo.DemoApp;
import org.lognet.springboot.grpc.security.*;
import org.mockito.Mockito;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.boot.test.mock.mockito.SpyBean;
import org.springframework.context.annotation.Import;
import org.springframework.security.authentication.dao.DaoAuthenticationProvider;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.crypto.password.NoOpPasswordEncoder;
import org.springframework.security.provisioning.InMemoryUserDetailsManager;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;

import java.util.Collections;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
Expand All @@ -36,13 +41,13 @@

@RunWith(SpringRunner.class)
@SpringBootTest(classes = {DemoApp.class}, webEnvironment = NONE)
@ActiveProfiles({"disable-security"})
@Import(GRpcRecoveryTest.Cfg.class)
public class GRpcRecoveryTest extends GrpcServerTestBase {

static class CheckedException extends Exception {

}

static class CheckedException1 extends Exception {

}
Expand All @@ -59,9 +64,25 @@ static class Exception1 extends RuntimeException {

}

@TestConfiguration
static class Cfg {
private static User user1 = new User("test1", "test1", Collections.EMPTY_LIST);

private AuthHeader.AuthHeaderBuilder user1AuthHeaderBuilder =
AuthHeader.builder().basic(user1.getUsername(), user1.getPassword().getBytes());

@TestConfiguration
static class Cfg extends GrpcSecurityConfigurerAdapter {
@Override
public void configure(GrpcSecurity builder) throws Exception {
DaoAuthenticationProvider provider = new DaoAuthenticationProvider();
UserDetailsService users = new InMemoryUserDetailsManager(user1);
provider.setUserDetailsService(users);
provider.setPasswordEncoder(NoOpPasswordEncoder.getInstance());

builder
.authenticationProvider(provider)
.authorizeRequests()
.anyMethod().authenticated();
}

@GRpcServiceAdvice
static class CustomErrorHandler {
Expand Down Expand Up @@ -102,6 +123,9 @@ public StreamObserver<Custom.CustomRequest> customStream(StreamObserver<Custom.C
return new StreamObserver<Custom.CustomRequest>() {
@Override
public void onNext(Custom.CustomRequest value) {
if ("onNext".equalsIgnoreCase(value.getName())) {
throw new GRpcRuntimeExceptionWrapper(new CheckedException1());
}
responseObserver.onNext(Custom.CustomReply.newBuilder().build());
}

Expand Down Expand Up @@ -151,9 +175,24 @@ public Status handleB(ExceptionB e, GRpcExceptionScope scope) {
private Cfg.CustomErrorHandler handler;


protected Channel getChannel() {
return ClientInterceptors.intercept(super.getChannel(), new AuthClientInterceptor(user1AuthHeaderBuilder));
}


@Test
public void streamingServiceErrorHandlerTest() throws ExecutionException, InterruptedException, TimeoutException {
public void parameterizedStreamingServiceErrorHandlerTest() throws ExecutionException, InterruptedException, TimeoutException {
String[] phases = new String[]{
"onNext", // exception will be thrown onNext
"onCompleted" // exception will be thrown onCompleted
};
for (String errorPhase : phases) {
streamingServiceErrorHandlerTest(errorPhase);
Mockito.clearInvocations(srv);
}
}

public void streamingServiceErrorHandlerTest(String errorName) throws ExecutionException, InterruptedException, TimeoutException {


final CompletableFuture<Throwable> errorFuture = new CompletableFuture<>();
Expand All @@ -175,20 +214,18 @@ public void onCompleted() {
}
};

final StreamObserver<Custom.CustomRequest> requests = CustomServiceGrpc.newStub(getChannel()).customStream(reply);
requests.onNext(Custom.CustomRequest.newBuilder().build());
final StreamObserver<Custom.CustomRequest> requests = CustomServiceGrpc.newStub(getChannel())
.customStream(reply);
requests.onNext(Custom.CustomRequest.newBuilder().setName(errorName).build());
requests.onCompleted();





final Throwable actual = errorFuture.get(20, TimeUnit.SECONDS);
assertThat(actual, notNullValue());
assertThat(actual, isA(StatusRuntimeException.class));
assertThat(((StatusRuntimeException)actual).getStatus(), is(Status.RESOURCE_EXHAUSTED));
assertThat(((StatusRuntimeException) actual).getStatus(), is(Status.RESOURCE_EXHAUSTED));

Mockito.verify(srv,times(1)).handle(any(CheckedException1.class),any());
Mockito.verify(srv, times(1)).handle(any(CheckedException1.class), any());

}

Expand All @@ -199,7 +236,8 @@ public void checkedExceptionHandlerTest() {
.custom(any(Custom.CustomRequest.class), any(StreamObserver.class));

final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, () ->
CustomServiceGrpc.newBlockingStub(getChannel()).custom(Custom.CustomRequest.newBuilder().build())
CustomServiceGrpc.newBlockingStub(getChannel())
.custom(Custom.CustomRequest.newBuilder().build())
);
assertThat(statusRuntimeException.getStatus(), is(Status.OUT_OF_RANGE));

Expand All @@ -221,7 +259,8 @@ public void globalHandlerTest() {
.custom(any(Custom.CustomRequest.class), any(StreamObserver.class));

final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, () ->
CustomServiceGrpc.newBlockingStub(getChannel()).custom(Custom.CustomRequest.newBuilder().build())
CustomServiceGrpc.newBlockingStub(getChannel())
.custom(Custom.CustomRequest.newBuilder().build())
);
assertThat(statusRuntimeException.getStatus(), is(Status.NOT_FOUND));

Expand All @@ -242,7 +281,8 @@ public void globalHandlerWithExceptionHierarchyTest() {
.custom(any(Custom.CustomRequest.class), any(StreamObserver.class));

final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, () ->
CustomServiceGrpc.newBlockingStub(getChannel()).custom(Custom.CustomRequest.newBuilder().build())
CustomServiceGrpc.newBlockingStub(getChannel())
.custom(Custom.CustomRequest.newBuilder().build())
);
assertThat(statusRuntimeException.getStatus(), is(Status.DATA_LOSS));

Expand All @@ -263,7 +303,8 @@ public void privateHandlerHasHigherPrecedence() {
.custom(any(Custom.CustomRequest.class), any(StreamObserver.class));

final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, () ->
CustomServiceGrpc.newBlockingStub(getChannel()).custom(Custom.CustomRequest.newBuilder().build())
CustomServiceGrpc.newBlockingStub(getChannel())
.custom(Custom.CustomRequest.newBuilder().build())
);
assertThat(statusRuntimeException.getStatus(), is(Status.FAILED_PRECONDITION));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ public Optional<Authentication> getAuthScheme(CharSequence authorization) {
.map(selector -> selector.getAuthScheme(authorization))
.filter(Optional::isPresent)
.map(Optional::get)
.collect(Collectors.toList());
.toList();
switch (auth.size()){
case 0:
throw new IllegalStateException(String.format("Authentication scheme '%s' is not supported.",
log.error(String.format("Authentication scheme '%s' is not supported.",
Optional.ofNullable(authorization)
.map(s->s.toString().split(" ",2)[0])
.orElse(null)
));
return Optional.empty();
case 1 :
return Optional.of(auth.get(0));
default:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
package org.lognet.springboot.grpc.security;

import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.ForwardingServerCall;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.*;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.lognet.springboot.grpc.FailureHandlingSupport;
import org.lognet.springboot.grpc.GRpcServicesRegistry;
import org.lognet.springboot.grpc.MessageBlockingServerCallListener;
import org.lognet.springboot.grpc.autoconfigure.GRpcServerProperties;
import org.lognet.springboot.grpc.recovery.GRpcRuntimeExceptionWrapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.core.Ordered;
Expand All @@ -36,7 +30,7 @@
public class SecurityInterceptor extends AbstractSecurityInterceptor implements ServerInterceptor, Ordered {

private static final Context.Key<InterceptorStatusToken> INTERCEPTOR_STATUS_TOKEN = Context.key("INTERCEPTOR_STATUS_TOKEN");
private static final Context.Key<GrpcMethodInvocation<?,?>> METHOD_INVOCATION = Context.key("METHOD_INVOCATION");
private static final Context.Key<GrpcMethodInvocation<?, ?>> METHOD_INVOCATION = Context.key("METHOD_INVOCATION");

private final SecurityMetadataSource securityMetadataSource;

Expand Down Expand Up @@ -75,8 +69,6 @@ ServerCall<ReqT, RespT> getCall() {
}




public SecurityInterceptor(SecurityMetadataSource securityMetadataSource, AuthenticationSchemeSelector schemeSelector) {
this.securityMetadataSource = securityMetadataSource;
this.schemeSelector = schemeSelector;
Expand Down Expand Up @@ -142,11 +134,10 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
final Context grpcSecurityContext;
try {
grpcSecurityContext = setupGRpcSecurityContext(call, headers, next, authorization);
} catch (AccessDeniedException | AuthenticationException e) {
} catch (RuntimeException e) {
return fail(next, call, headers, e);
} catch (Exception e) {
return fail(next, call, headers, new AuthenticationException("Authentication failure.", e) {
});
return fail(next, call, headers, new GRpcRuntimeExceptionWrapper(e));
}
return Contexts.interceptCall(grpcSecurityContext, call, headers, authenticationPropagatingHandler(next));
} finally {
Expand Down Expand Up @@ -176,15 +167,16 @@ public void onMessage(ReqT message) {
METHOD_INVOCATION.get().setArguments(new Object[]{message});
break;
default:
throw new AuthenticationException("Unsupported call type "+call.getMethodDescriptor().getType()) {};
log.error("Unsupported call type " + call.getMethodDescriptor().getType());
throw new StatusRuntimeException(Status.UNAUTHENTICATED) ;
}

beforeInvocation(METHOD_INVOCATION.get());
super.onMessage(message);
} catch (AccessDeniedException | AuthenticationException e) {
failureHandlingSupport.closeCall(e,call,headers);
} catch (RuntimeException e) {
failureHandlingSupport.closeCall(e, call, headers);
} catch (Exception e) {
failureHandlingSupport.closeCall( new AuthenticationException("", e) {},call, headers);
failureHandlingSupport.closeCall(new GRpcRuntimeExceptionWrapper(e), call, headers);
} finally {
METHOD_INVOCATION.get().setArguments(null);
}
Expand Down Expand Up @@ -243,15 +235,15 @@ private <RespT, ReqT> Context setupGRpcSecurityContext(ServerCall<RespT, ReqT> c
ServerCallHandler<RespT, ReqT> next, CharSequence authorization) {
final Authentication authentication = null == authorization ? null :
schemeSelector.getAuthScheme(authorization)
.orElseThrow(() -> new RuntimeException("Can't get authentication from authorization header"));
.orElseThrow(() -> new StatusRuntimeException(Status.UNAUTHENTICATED));

SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);

final GRpcServicesRegistry.GrpcServiceMethod grpcServiceMethod = registry.getGrpServiceMethod(call.getMethodDescriptor());

final GrpcMethodInvocation<RespT, ReqT> methodInvocation = new GrpcMethodInvocation<>(grpcServiceMethod , call, headers, next);
final GrpcMethodInvocation<RespT, ReqT> methodInvocation = new GrpcMethodInvocation<>(grpcServiceMethod, call, headers, next);
final InterceptorStatusToken interceptorStatusToken = beforeInvocation(methodInvocation);

return Context.current()
Expand Down

0 comments on commit d646b89

Please sign in to comment.