Skip to content

Commit

Permalink
fixes #355
Browse files Browse the repository at this point in the history
  • Loading branch information
jvmlet committed May 15, 2023
1 parent 8dd9d47 commit f0d1b9e
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 59 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package org.lognet.springboot.grpc.demo;

import io.grpc.CallOptions;
import io.grpc.Context;
import io.grpc.Status;
import io.grpc.examples.reactor.ReactiveHelloRequest;
import io.grpc.examples.reactor.ReactiveHelloResponse;
Expand All @@ -10,9 +8,8 @@
import org.lognet.springboot.grpc.GRpcService;
import org.lognet.springboot.grpc.recovery.GRpcExceptionHandler;
import org.lognet.springboot.grpc.recovery.GRpcExceptionScope;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.security.access.annotation.Secured;
import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
Expand All @@ -32,6 +29,7 @@ public ReactiveGreeterGrpcService(ReactiveGreeterService reactiveGreeterService)
}

@Override
@Secured({})
public Mono<ReactiveHelloResponse> greet(Mono<ReactiveHelloRequest> request) {
return reactiveGreeterService.greet(request);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
import org.hamcrest.collection.IsIn;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.lognet.springboot.grpc.GrpcServerTestBase;
import org.lognet.springboot.grpc.auth.JwtAuthBaseTest;
import org.lognet.springboot.grpc.demo.DemoApp;
import org.lognet.springboot.grpc.security.GrpcSecurity;
import org.lognet.springboot.grpc.security.GrpcSecurityConfigurerAdapter;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;
Expand All @@ -32,9 +35,22 @@
@Slf4j
@RunWith(SpringRunner.class)
@SpringBootTest(classes = DemoApp.class, webEnvironment = NONE)
@ActiveProfiles({"disable-security","r2dbc-test"})
@ActiveProfiles({"keycloack-test", "r2dbc-test"})
@DirtiesContext
public class ReactiveDemoTest extends GrpcServerTestBase {
public class ReactiveDemoTest extends JwtAuthBaseTest {

@TestConfiguration
static class TestCfg {
private static class DemoGrpcSecurityAdapter extends GrpcSecurityConfigurerAdapter {
@Override
public void configure(GrpcSecurity builder) throws Exception {
builder.authorizeRequests()
.withSecuredAnnotation();

}
}
}

@Test
public void grpcGreetTest() {
String shrek = "Shrek";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.MethodIntrospector;
import org.springframework.core.MethodParameter;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.function.SingletonSupplier;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.util.*;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

Expand All @@ -41,9 +44,7 @@ public static class GrpcServiceMethod {

private Supplier<Map<MethodDescriptor<?, ?>, GrpcServiceMethod>> descriptorToServiceMethod;

private Supplier< Map<Method,MethodDescriptor<?,?>>> methodToDescriptor ;


private Supplier<Map<Method, MethodDescriptor<?, ?>>> methodToDescriptor;


/**
Expand All @@ -65,11 +66,11 @@ Collection<ServerInterceptor> getGlobalInterceptors() {
return grpcGlobalInterceptors.get();
}

public GrpcServiceMethod getGrpServiceMethod(MethodDescriptor<?,?> descriptor) {
public GrpcServiceMethod getGrpServiceMethod(MethodDescriptor<?, ?> descriptor) {
return descriptorToServiceMethod.get().get(descriptor);
}

public MethodDescriptor<?,?> getMethodDescriptor( Method method) {
public MethodDescriptor<?, ?> getMethodDescriptor(Method method) {
return methodToDescriptor.get().get(method);
}

Expand All @@ -88,11 +89,11 @@ public void afterPropertiesSet() throws Exception {

descriptorToServiceMethod = SingletonSupplier.of(this::descriptorToServiceMethod);

methodToDescriptor = SingletonSupplier.of(()->
descriptorToServiceMethod.get()
.entrySet()
.stream()
.collect(Collectors.toMap(e->e.getValue().getMethod(), Map.Entry::getKey))
methodToDescriptor = SingletonSupplier.of(() ->
descriptorToServiceMethod.get()
.entrySet()
.stream()
.collect(Collectors.toMap(e -> e.getValue().getMethod(), Map.Entry::getKey))
);
beanNameToServiceBean = SingletonSupplier.of(() ->
getBeanNamesByTypeWithAnnotation(GRpcService.class, BindableService.class)
Expand All @@ -119,45 +120,74 @@ public void setApplicationContext(ApplicationContext applicationContext) throws
this.applicationContext = applicationContext;
}

private Map<MethodDescriptor<?, ?>, GrpcServiceMethod> descriptorToServiceMethod (){
private Map<MethodDescriptor<?, ?>, GrpcServiceMethod> descriptorToServiceMethod() {
final Map<MethodDescriptor<?, ?>, GrpcServiceMethod> map = new HashMap<>();

Function<String, ReflectionUtils.MethodFilter> filterFactory = name ->
method ->
method.getName().equalsIgnoreCase(name.replaceAll("_",""));
method.getName().equalsIgnoreCase(name.replaceAll("_", ""));


Predicate<Method> firstArgIsMono = m -> "reactor.core.publisher.Mono".equals(m.getParameterTypes()[0].getName());
Predicate<Method> singleArg = m -> 1 == m.getParameterCount();

for (BindableService service : getBeanNameToServiceBeanMap().values()) {
final ServerServiceDefinition serviceDefinition = service.bindService();
for (MethodDescriptor<?, ?> d : serviceDefinition.getServiceDescriptor().getMethods()) {
Class<?> abstractBaseClass = service.getClass();
while (!Modifier.isAbstract(abstractBaseClass.getModifiers())){
while (!Modifier.isAbstract(abstractBaseClass.getModifiers())) {
abstractBaseClass = abstractBaseClass.getSuperclass();
}

final Set<Method> methods = MethodIntrospector
.selectMethods(abstractBaseClass, filterFactory.apply(d.getBareMethodName()));


switch (methods.size()){
switch (methods.size()) {
case 0:
throw new IllegalStateException("Method " +d.getBareMethodName()+ "not found in service "+ serviceDefinition.getServiceDescriptor().getName());
throw new IllegalStateException("Method " + d.getBareMethodName() + "not found in service " + serviceDefinition.getServiceDescriptor().getName());
case 1:
map.put(d, GrpcServiceMethod.builder()
.service(service)
.method(methods.iterator().next())
.build());
break;
default:
throw new IllegalStateException("Ambiguous method " +d.getBareMethodName()+ " in service "+ serviceDefinition.getServiceDescriptor().getName());
}

if (2 == methods.size()) {

Optional<Method> methodWithMono = methods.stream() // grpcMethod(Mono<Payload> arg)
.filter(singleArg.and(firstArgIsMono))
.findFirst();

Optional<Method> methodPure = methods.stream() // grpcMethod(Payload arg)
.filter(singleArg.and(firstArgIsMono.negate()))
.findFirst();

Class<?> finalAbstractBaseClass = abstractBaseClass;
Boolean typesAreEqual = methodWithMono
.map(m -> ((ParameterizedType) new MethodParameter(m, 0)
.withContainingClass(finalAbstractBaseClass)
.getGenericParameterType())
.getActualTypeArguments()[0]
).map(t -> t.equals(methodPure.map(m -> m.getParameterTypes()[0]).orElse(null)))
.orElse(false);

if (typesAreEqual) {
map.put(d, GrpcServiceMethod.builder()
.service(service)
.method(methodWithMono.get())
.build());
break;
}
}
throw new IllegalStateException("Ambiguous method " + d.getBareMethodName() + " in service " + serviceDefinition.getServiceDescriptor().getName());


}


}
}
return Collections.unmodifiableMap(map);
return Collections.unmodifiableMap(map);
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
package org.lognet.springboot.grpc.security;

import io.grpc.BindableService;
import io.grpc.MethodDescriptor;
import io.grpc.ServerInterceptor;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor;
import io.grpc.*;
import org.lognet.springboot.grpc.GRpcServicesRegistry;
import org.springframework.beans.factory.BeanCreationException;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.security.access.ConfigAttribute;
import org.springframework.security.access.SecurityConfig;
Expand All @@ -15,11 +11,7 @@
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.*;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand All @@ -40,7 +32,7 @@ public Registry getRegistry() {
@Override
public void configure(GrpcSecurity builder) throws Exception {
registry.processSecuredAnnotation();
builder.setSharedObject(GrpcSecurityMetadataSource.class, new GrpcSecurityMetadataSource(registry.servicesRegistry,registry.securedMethods));
builder.setSharedObject(GrpcSecurityMetadataSource.class, new GrpcSecurityMetadataSource(registry.servicesRegistry, registry.securedMethods));
}


Expand Down Expand Up @@ -99,18 +91,18 @@ public GrpcSecurity withoutSecuredAnnotation() {
}

public AuthorizedMethod anyMethod() {
return anyMethodExcluding(s->false);
return anyMethodExcluding(s -> false);
}

public AuthorizedMethod anyMethodExcluding(MethodDescriptor<?, ?>... methodDescriptor) {
List<MethodDescriptor<?,?>> excludedMethods = Arrays.asList(methodDescriptor);
List<MethodDescriptor<?, ?>> excludedMethods = Arrays.asList(methodDescriptor);
return anyMethodExcluding(excludedMethods::contains);

}


public AuthorizedMethod anyMethodExcluding(Predicate<MethodDescriptor<?, ?>> excludePredicate) {
MethodDescriptor<?,?>[] allMethods = servicesRegistry.getBeanNameToServiceBeanMap()
MethodDescriptor<?, ?>[] allMethods = servicesRegistry.getBeanNameToServiceBeanMap()
.values()
.stream()
.map(BindableService::bindService)
Expand All @@ -124,12 +116,14 @@ public AuthorizedMethod anyMethodExcluding(Predicate<MethodDescriptor<?, ?>> exc


public AuthorizedMethod anyService() {
return anyServiceExcluding(s-> false);
return anyServiceExcluding(s -> false);
}

public AuthorizedMethod anyServiceExcluding(ServiceDescriptor... serviceDescriptor) {
List<ServiceDescriptor> excludedServices = Arrays.asList(serviceDescriptor);
return anyServiceExcluding(excludedServices::contains);
}

public AuthorizedMethod anyServiceExcluding(Predicate<ServiceDescriptor> excludePredicate) {

ServiceDescriptor[] allServices = servicesRegistry.getBeanNameToServiceBeanMap()
Expand All @@ -144,11 +138,13 @@ public AuthorizedMethod anyServiceExcluding(Predicate<ServiceDescriptor> exclude

/**
* Same as {@code withSecuredAnnotation(true)}
*
* @return GrpcSecurity configuration
*/
public GrpcSecurity withSecuredAnnotation() {
return withSecuredAnnotation(true);
}

public GrpcSecurity withSecuredAnnotation(boolean withSecuredAnnotation) {
this.withSecuredAnnotation = withSecuredAnnotation;
return and();
Expand All @@ -163,29 +159,43 @@ private void processSecuredAnnotation() {
// service level security
{
Optional.ofNullable(AnnotationUtils.findAnnotation(service.getClass(), Secured.class))
.ifPresent(secured -> {
if (secured.value().length == 0) {
new AuthorizedMethod(serverServiceDefinition.getServiceDescriptor()).authenticated();
} else {
new AuthorizedMethod(serverServiceDefinition.getServiceDescriptor()).hasAnyAuthority(secured.value());
}
});

}
// method level security
for (ServerMethodDefinition<?, ?> methodDefinition : serverServiceDefinition.getMethods()) {
Stream.of(service.getClass().getMethods()) // get method from methodDefinition
.filter(m -> m.getName().equalsIgnoreCase(methodDefinition.getMethodDescriptor().getBareMethodName()))
.findFirst()
.flatMap(m -> Optional.ofNullable(AnnotationUtils.findAnnotation(m, Secured.class)))
.ifPresent(secured -> {
if (secured.value().length == 0) {
new AuthorizedMethod(methodDefinition.getMethodDescriptor()).authenticated();
new AuthorizedMethod(serverServiceDefinition.getServiceDescriptor()).authenticated();
} else {
new AuthorizedMethod(methodDefinition.getMethodDescriptor()).hasAnyAuthority(secured.value());
new AuthorizedMethod(serverServiceDefinition.getServiceDescriptor()).hasAnyAuthority(secured.value());
}
});

}
// method level security
for (ServerMethodDefinition<?, ?> methodDefinition : serverServiceDefinition.getMethods()) {

List<Secured> secureds = Stream.of(service.getClass().getMethods()) // get method from methodDefinition
.filter(m -> m.getName().equalsIgnoreCase(methodDefinition.getMethodDescriptor().getBareMethodName()))
.map(m -> AnnotationUtils.findAnnotation(m, Secured.class))
.filter(Objects::nonNull)
.toList();
if (secureds.isEmpty()) {
continue;
}
if (1 == secureds.size()) {
Secured secured = secureds.get(0);
if (secured.value().length == 0) {
new AuthorizedMethod(methodDefinition.getMethodDescriptor()).authenticated();
} else {
new AuthorizedMethod(methodDefinition.getMethodDescriptor()).hasAnyAuthority(secured.value());
}
} else {
String errorMessage = String.format("Ambiguous 'Secured' method '%s' in service '%s'." +
"When securing reactive method, the @Secured annotation should be added to the method getting 'Mono<Request>' and not with pure 'Request' argument.",
methodDefinition.getMethodDescriptor().getBareMethodName(),
service.getClass().getName()
);
throw new BeanCreationException(errorMessage);
}


}
}
}
Expand Down

0 comments on commit f0d1b9e

Please sign in to comment.