Skip to content

Commit

Permalink
Polish
Browse files Browse the repository at this point in the history
  • Loading branch information
philwebb committed Jun 8, 2021
1 parent 87d3525 commit be23a29
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 73 deletions.
Expand Up @@ -30,12 +30,6 @@
*/
public abstract class AbstractBeansOfTypeDatabaseInitializerDetector implements DatabaseInitializerDetector {

/**
* Returns the bean types that should be detected as being database initializers.
* @return the database initializer bean types
*/
protected abstract Set<Class<?>> getDatabaseInitializerBeanTypes();

@Override
public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
try {
Expand All @@ -47,4 +41,10 @@ public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
}
}

/**
* Returns the bean types that should be detected as being database initializers.
* @return the database initializer bean types
*/
protected abstract Set<Class<?>> getDatabaseInitializerBeanTypes();

}
Expand Up @@ -32,13 +32,6 @@
public abstract class AbstractBeansOfTypeDependsOnDatabaseInitializationDetector
implements DependsOnDatabaseInitializationDetector {

/**
* Returns the bean types that should be detected as depending on database
* initialization.
* @return the database initialization dependent bean types
*/
protected abstract Set<Class<?>> getDependsOnDatabaseInitializationBeanTypes();

@Override
public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
try {
Expand All @@ -50,4 +43,11 @@ public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
}
}

/**
* Returns the bean types that should be detected as depending on database
* initialization.
* @return the database initialization dependent bean types
*/
protected abstract Set<Class<?>> getDependsOnDatabaseInitializationBeanTypes();

}
Expand Up @@ -16,9 +16,11 @@

package org.springframework.boot.sql.init.dependency;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -65,16 +67,23 @@ public class DatabaseInitializationDependencyConfigurer implements ImportBeanDef

@Override
public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
if (registry.containsBeanDefinition(DependsOnDatabaseInitializationPostProcessor.class.getName())) {
return;
String name = DependsOnDatabaseInitializationPostProcessor.class.getName();
if (!registry.containsBeanDefinition(name)) {
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(
DependsOnDatabaseInitializationPostProcessor.class,
this::createDependsOnDatabaseInitializationPostProcessor);
registry.registerBeanDefinition(name, builder.getBeanDefinition());
}
registry.registerBeanDefinition(DependsOnDatabaseInitializationPostProcessor.class.getName(),
BeanDefinitionBuilder
.genericBeanDefinition(DependsOnDatabaseInitializationPostProcessor.class,
() -> new DependsOnDatabaseInitializationPostProcessor(this.environment))
.getBeanDefinition());
}

private DependsOnDatabaseInitializationPostProcessor createDependsOnDatabaseInitializationPostProcessor() {
return new DependsOnDatabaseInitializationPostProcessor(this.environment);
}

/**
* {@link BeanFactoryPostProcessor} used to configure database initialization
* dependency relationships.
*/
static class DependsOnDatabaseInitializationPostProcessor implements BeanFactoryPostProcessor {

private final Environment environment;
Expand All @@ -85,58 +94,55 @@ static class DependsOnDatabaseInitializationPostProcessor implements BeanFactory

@Override
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) {
Set<String> detectedDatabaseInitializers = detectDatabaseInitializers(beanFactory);
if (detectedDatabaseInitializers.isEmpty()) {
Set<String> initializerBeanNames = detectInitializerBeanNames(beanFactory);
if (initializerBeanNames.isEmpty()) {
return;
}
for (String dependentDefinitionName : detectDependsOnDatabaseInitialization(beanFactory,
this.environment)) {
BeanDefinition definition = getBeanDefinition(dependentDefinitionName, beanFactory);
String[] dependencies = definition.getDependsOn();
for (String dependencyName : detectedDatabaseInitializers) {
dependencies = StringUtils.addStringToArray(dependencies, dependencyName);
}
definition.setDependsOn(dependencies);
for (String dependsOnInitializationBeanNames : detectDependsOnInitializationBeanNames(beanFactory)) {
BeanDefinition definition = getBeanDefinition(dependsOnInitializationBeanNames, beanFactory);
definition.setDependsOn(merge(definition.getDependsOn(), initializerBeanNames));
}
}

private Set<String> detectDatabaseInitializers(ConfigurableListableBeanFactory beanFactory) {
List<DatabaseInitializerDetector> detectors = instantiateDetectors(beanFactory, this.environment,
DatabaseInitializerDetector.class);
Set<String> detected = new HashSet<>();
private String[] merge(String[] source, Set<String> additional) {
Set<String> result = new LinkedHashSet<>((source != null) ? Arrays.asList(source) : Collections.emptySet());
result.addAll(additional);
return StringUtils.toStringArray(result);
}

private Set<String> detectInitializerBeanNames(ConfigurableListableBeanFactory beanFactory) {
List<DatabaseInitializerDetector> detectors = getDetectors(beanFactory, DatabaseInitializerDetector.class);
Set<String> beanNames = new HashSet<>();
for (DatabaseInitializerDetector detector : detectors) {
for (String initializerName : detector.detect(beanFactory)) {
detected.add(initializerName);
beanFactory.getBeanDefinition(initializerName)
.setAttribute(DatabaseInitializerDetector.class.getName(), detector.getClass().getName());
for (String beanName : detector.detect(beanFactory)) {
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
beanDefinition.setAttribute(DatabaseInitializerDetector.class.getName(),
detector.getClass().getName());
beanNames.add(beanName);
}
}
detected = Collections.unmodifiableSet(detected);
beanNames = Collections.unmodifiableSet(beanNames);
for (DatabaseInitializerDetector detector : detectors) {
detector.detectionComplete(beanFactory, detected);
detector.detectionComplete(beanFactory, beanNames);
}
return detected;
return beanNames;
}

private Collection<String> detectDependsOnDatabaseInitialization(ConfigurableListableBeanFactory beanFactory,
Environment environment) {
List<DependsOnDatabaseInitializationDetector> detectors = instantiateDetectors(beanFactory, environment,
private Collection<String> detectDependsOnInitializationBeanNames(ConfigurableListableBeanFactory beanFactory) {
List<DependsOnDatabaseInitializationDetector> detectors = getDetectors(beanFactory,
DependsOnDatabaseInitializationDetector.class);
Set<String> dependentUponDatabaseInitialization = new HashSet<>();
Set<String> beanNames = new HashSet<>();
for (DependsOnDatabaseInitializationDetector detector : detectors) {
dependentUponDatabaseInitialization.addAll(detector.detect(beanFactory));
beanNames.addAll(detector.detect(beanFactory));
}
return dependentUponDatabaseInitialization;
return beanNames;
}

private <T> List<T> instantiateDetectors(ConfigurableListableBeanFactory beanFactory, Environment environment,
Class<T> detectorType) {
List<String> detectorNames = SpringFactoriesLoader.loadFactoryNames(detectorType,
beanFactory.getBeanClassLoader());
Instantiator<T> instantiator = new Instantiator<>(detectorType,
(availableParameters) -> availableParameters.add(Environment.class, environment));
List<T> detectors = instantiator.instantiate(detectorNames);
return detectors;
private <T> List<T> getDetectors(ConfigurableListableBeanFactory beanFactory, Class<T> type) {
List<String> names = SpringFactoriesLoader.loadFactoryNames(type, beanFactory.getBeanClassLoader());
Instantiator<T> instantiator = new Instantiator<>(type,
(availableParameters) -> availableParameters.add(Environment.class, this.environment));
return instantiator.instantiate(names);
}

private static BeanDefinition getBeanDefinition(String beanName, ConfigurableListableBeanFactory beanFactory) {
Expand Down
Expand Up @@ -33,7 +33,6 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.mockito.Mockito;

import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
Expand All @@ -47,6 +46,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;

Expand All @@ -59,16 +59,12 @@ class DatabaseInitializationDependencyConfigurerTests {

private final ConfigurableEnvironment environment = new MockEnvironment();

DatabaseInitializerDetector databaseInitializerDetector = MockedDatabaseInitializerDetector.mock;

DependsOnDatabaseInitializationDetector dependsOnDatabaseInitializationDetector = MockedDependsOnDatabaseInitializationDetector.mock;

@TempDir
File temp;

@BeforeEach
void resetMocks() {
reset(MockedDatabaseInitializerDetector.mock, MockedDependsOnDatabaseInitializationDetector.mock);
reset(MockDatabaseInitializerDetector.instance, MockedDependsOnDatabaseInitializationDetector.instance);
}

@Test
Expand All @@ -89,19 +85,19 @@ void whenDetectorsAreCreatedThenTheEnvironmentCanBeInjected() {
void whenDependenciesAreConfiguredThenBeansThatDependUponDatabaseInitializationDependUponDetectedDatabaseInitializers() {
BeanDefinition alpha = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition();
BeanDefinition bravo = BeanDefinitionBuilder.genericBeanDefinition(String.class).getBeanDefinition();
performDetection(Arrays.asList(MockedDatabaseInitializerDetector.class,
performDetection(Arrays.asList(MockDatabaseInitializerDetector.class,
MockedDependsOnDatabaseInitializationDetector.class), (context) -> {
context.registerBeanDefinition("alpha", alpha);
context.registerBeanDefinition("bravo", bravo);
given(this.databaseInitializerDetector.detect(context.getBeanFactory()))
given(MockDatabaseInitializerDetector.instance.detect(context.getBeanFactory()))
.willReturn(Collections.singleton("alpha"));
given(this.dependsOnDatabaseInitializationDetector.detect(context.getBeanFactory()))
given(MockedDependsOnDatabaseInitializationDetector.instance.detect(context.getBeanFactory()))
.willReturn(Collections.singleton("bravo"));
context.refresh();
assertThat(alpha.getAttribute(DatabaseInitializerDetector.class.getName()))
.isEqualTo(MockedDatabaseInitializerDetector.class.getName());
.isEqualTo(MockDatabaseInitializerDetector.class.getName());
assertThat(bravo.getAttribute(DatabaseInitializerDetector.class.getName())).isNull();
verify(this.databaseInitializerDetector).detectionComplete(context.getBeanFactory(),
verify(MockDatabaseInitializerDetector.instance).detectionComplete(context.getBeanFactory(),
Collections.singleton("alpha"));
assertThat(bravo.getDependsOn()).containsExactly("alpha");
});
Expand Down Expand Up @@ -156,31 +152,31 @@ public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {

}

static class MockedDatabaseInitializerDetector implements DatabaseInitializerDetector {
static class MockDatabaseInitializerDetector implements DatabaseInitializerDetector {

private static DatabaseInitializerDetector mock = Mockito.mock(DatabaseInitializerDetector.class);
private static DatabaseInitializerDetector instance = mock(DatabaseInitializerDetector.class);

@Override
public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
return MockedDatabaseInitializerDetector.mock.detect(beanFactory);
return MockDatabaseInitializerDetector.instance.detect(beanFactory);
}

@Override
public void detectionComplete(ConfigurableListableBeanFactory beanFactory,
Set<String> databaseInitializerNames) {
mock.detectionComplete(beanFactory, databaseInitializerNames);
instance.detectionComplete(beanFactory, databaseInitializerNames);
}

}

static class MockedDependsOnDatabaseInitializationDetector implements DependsOnDatabaseInitializationDetector {

private static DependsOnDatabaseInitializationDetector mock = Mockito
.mock(DependsOnDatabaseInitializationDetector.class);
private static DependsOnDatabaseInitializationDetector instance = mock(
DependsOnDatabaseInitializationDetector.class);

@Override
public Set<String> detect(ConfigurableListableBeanFactory beanFactory) {
return MockedDependsOnDatabaseInitializationDetector.mock.detect(beanFactory);
return instance.detect(beanFactory);
}

}
Expand Down

0 comments on commit be23a29

Please sign in to comment.