diff --git a/spring-test/src/main/java/org/springframework/test/context/aot/TestContextAotGenerator.java b/spring-test/src/main/java/org/springframework/test/context/aot/TestContextAotGenerator.java index a6522477be0a..9715821e9fdf 100644 --- a/spring-test/src/main/java/org/springframework/test/context/aot/TestContextAotGenerator.java +++ b/spring-test/src/main/java/org/springframework/test/context/aot/TestContextAotGenerator.java @@ -16,7 +16,10 @@ package org.springframework.test.context.aot; +import java.util.Arrays; +import java.util.LinkedHashSet; import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Stream; @@ -30,12 +33,19 @@ import org.springframework.aot.generate.GeneratedFiles; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; import org.springframework.aot.hint.TypeReference; +import org.springframework.aot.hint.annotation.ReflectiveRuntimeHintsRegistrar; +import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.aot.AotServices; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.annotation.ImportRuntimeHints; import org.springframework.context.aot.ApplicationContextAotGenerator; import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.core.annotation.MergedAnnotations; +import org.springframework.core.annotation.MergedAnnotations.SearchStrategy; import org.springframework.core.log.LogMessage; import org.springframework.javapoet.ClassName; import org.springframework.test.context.BootstrapUtils; @@ -117,16 +127,36 @@ public void processAheadOfTime(Stream> testClasses) throws TestContextA try { resetAotFactories(); + Set runtimeHintsRegistrars = new LinkedHashSet<>(); + ReflectiveRuntimeHintsRegistrar reflectiveRuntimeHintsRegistrar = new ReflectiveRuntimeHintsRegistrar(); + MultiValueMap> mergedConfigMappings = new LinkedMultiValueMap<>(); ClassLoader classLoader = getClass().getClassLoader(); testClasses.forEach(testClass -> { MergedContextConfiguration mergedConfig = buildMergedContextConfiguration(testClass); mergedConfigMappings.add(mergedConfig, testClass); + + MergedAnnotations.from(testClass, SearchStrategy.TYPE_HIERARCHY).stream(ImportRuntimeHints.class) + .map(MergedAnnotation::synthesize) + .map(ImportRuntimeHints::value) + .flatMap(Arrays::stream) + .map(BeanUtils::instantiateClass) + .forEach(runtimeHintsRegistrars::add); + + reflectiveRuntimeHintsRegistrar.registerRuntimeHints(this.runtimeHints, testClass); this.testRuntimeHintsRegistrars.forEach(registrar -> registrar.registerHints(this.runtimeHints, testClass, classLoader)); }); MultiValueMap> initializerClassMappings = processAheadOfTime(mergedConfigMappings); + runtimeHintsRegistrars.forEach(registrar -> { + if (logger.isTraceEnabled()) { + logger.trace("Processing RuntimeHints contribution from test class [%s]" + .formatted(registrar.getClass().getCanonicalName())); + } + registrar.registerHints(this.runtimeHints, classLoader); + }); + generateAotTestContextInitializerMappings(initializerClassMappings); generateAotTestAttributeMappings(); } diff --git a/spring-test/src/test/java/org/springframework/test/context/aot/DeclarativeRuntimeHintsTests.java b/spring-test/src/test/java/org/springframework/test/context/aot/DeclarativeRuntimeHintsTests.java new file mode 100644 index 000000000000..b4a9c433ec5f --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/aot/DeclarativeRuntimeHintsTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.aot; + +import java.util.Set; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.InMemoryGeneratedFiles; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.test.context.aot.samples.hints.DeclarativeRuntimeHintsSpringJupiterTests; + +import static java.util.Comparator.comparing; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; +import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.resource; + +/** + * Tests for declarative support for registering run-time hints for tests, tested + * via the {@link TestContextAotGenerator} + * + * @author Sam Brannen + * @since 6.0 + */ +class DeclarativeRuntimeHintsTests extends AbstractAotTests { + + @Test + void declarativeRuntimeHints() { + Set> testClasses = Set.of(DeclarativeRuntimeHintsSpringJupiterTests.class); + TestContextAotGenerator generator = new TestContextAotGenerator(new InMemoryGeneratedFiles()); + RuntimeHints runtimeHints = generator.getRuntimeHints(); + + generator.processAheadOfTime(testClasses.stream().sorted(comparing(Class::getName))); + + // @Reflective + assertReflectionRegistered(runtimeHints, DeclarativeRuntimeHintsSpringJupiterTests.class); + + // @ImportRuntimeHints + assertThat(resource().forResource("org/example/config/enigma.txt")).accepts(runtimeHints); + assertThat(resource().forResource("org/example/config/level2/foo.txt")).accepts(runtimeHints); + } + + private static void assertReflectionRegistered(RuntimeHints runtimeHints, Class type) { + assertThat(reflection().onType(type)) + .as("Reflection hint for %s", type.getSimpleName()) + .accepts(runtimeHints); + } + + private static void assertReflectionRegistered(RuntimeHints runtimeHints, Class type, MemberCategory memberCategory) { + assertThat(reflection().onType(type).withMemberCategory(memberCategory)) + .as("Reflection hint for %s with category %s", type.getSimpleName(), memberCategory) + .accepts(runtimeHints); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/context/aot/samples/hints/DeclarativeRuntimeHintsSpringJupiterTests.java b/spring-test/src/test/java/org/springframework/test/context/aot/samples/hints/DeclarativeRuntimeHintsSpringJupiterTests.java new file mode 100644 index 000000000000..0688fc66a170 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/aot/samples/hints/DeclarativeRuntimeHintsSpringJupiterTests.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.aot.samples.hints; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.aot.hint.annotation.Reflective; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.ImportRuntimeHints; +import org.springframework.test.context.aot.samples.hints.DeclarativeRuntimeHintsSpringJupiterTests.DemoHints; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Sam Brannen + * @since 6.0 + */ +@SpringJUnitConfig +@Reflective +@ImportRuntimeHints(DemoHints.class) +public class DeclarativeRuntimeHintsSpringJupiterTests { + + @Test + void test(@Autowired String foo) { + assertThat(foo).isEqualTo("bar"); + } + + + @Configuration(proxyBeanMethods = false) + static class Config { + + @Bean + String foo() { + return "bar"; + } + } + + static class DemoHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(RuntimeHints hints, ClassLoader classLoader) { + hints.resources().registerPattern("org/example/config/*.txt"); + } + + } + +}