diff --git a/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/file/transform/RecordFieldExtractor.java b/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/file/transform/RecordFieldExtractor.java new file mode 100644 index 0000000000..fb5b15a064 --- /dev/null +++ b/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/file/transform/RecordFieldExtractor.java @@ -0,0 +1,101 @@ +/* + * Copyright 2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.batch.item.file.transform; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.RecordComponent; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * This is a field extractor for a Java record. By default, it will extract all record + * components, unless a subset is selected using {@link #setNames(String...)}. + * + * @author Mahmoud Ben Hassine + * @since 5.0 + */ +public class RecordFieldExtractor implements FieldExtractor { + + private List names; + + private Class targetType; + + private RecordComponent[] recordComponents; + + public RecordFieldExtractor(Class targetType) { + Assert.notNull(targetType, "target type must not be null"); + Assert.isTrue(targetType.isRecord(), "target type must be a record"); + this.targetType = targetType; + this.recordComponents = this.targetType.getRecordComponents(); + this.names = getRecordComponentNames(); + } + + /** + * Set the names of record components to extract. + * @param names of record component to be extracted. + */ + public void setNames(String... names) { + Assert.notNull(names, "Names must not be null"); + Assert.notEmpty(names, "Names must not be empty"); + validate(names); + this.names = Arrays.stream(names).toList(); + } + + /** + * @see FieldExtractor#extract(Object) + */ + @Override + public Object[] extract(T item) { + List values = new ArrayList<>(); + for (String componentName : this.names) { + RecordComponent recordComponent = getRecordComponentByName(componentName); + Object value; + try { + value = recordComponent.getAccessor().invoke(item); + values.add(value); + } + catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException("Unable to extract value for record component " + componentName, e); + } + } + return values.toArray(); + } + + private List getRecordComponentNames() { + return Arrays.stream(this.recordComponents).map(recordComponent -> recordComponent.getName()).toList(); + } + + private void validate(String[] names) { + for (String name : names) { + if (getRecordComponentByName(name) == null) { + throw new IllegalArgumentException( + "Component '" + name + "' is not defined in record " + targetType.getName()); + } + } + } + + @Nullable + private RecordComponent getRecordComponentByName(String name) { + return Arrays.stream(this.recordComponents).filter(recordComponent -> recordComponent.getName().equals(name)) + .findFirst().orElse(null); + } + +} diff --git a/spring-batch-infrastructure/src/test/java/org/springframework/batch/item/file/transform/RecordFieldExtractorTests.java b/spring-batch-infrastructure/src/test/java/org/springframework/batch/item/file/transform/RecordFieldExtractorTests.java new file mode 100644 index 0000000000..e0a920290c --- /dev/null +++ b/spring-batch-infrastructure/src/test/java/org/springframework/batch/item/file/transform/RecordFieldExtractorTests.java @@ -0,0 +1,85 @@ +/* + * Copyright 2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.batch.item.file.transform; + +import org.junit.Assert; +import org.junit.Test; + +/** + * @author Mahmoud Ben Hassine + */ +public class RecordFieldExtractorTests { + + @Test(expected = IllegalArgumentException.class) + public void testSetupWithNullTargetType() { + new RecordFieldExtractor<>(null); + } + + @Test(expected = IllegalArgumentException.class) + public void testSetupWithNonRecordTargetType() { + new RecordFieldExtractor<>(NonRecordType.class); + } + + @Test + public void testExtractFields() { + // given + RecordFieldExtractor recordFieldExtractor = new RecordFieldExtractor<>(Person.class); + Person person = new Person(1, "foo"); + + // when + Object[] fields = recordFieldExtractor.extract(person); + + // then + Assert.assertNotNull(fields); + Assert.assertArrayEquals(new Object[] { 1, "foo" }, fields); + } + + @Test + public void testExtractFieldsSubset() { + // given + RecordFieldExtractor recordFieldExtractor = new RecordFieldExtractor<>(Person.class); + recordFieldExtractor.setNames("name"); + Person person = new Person(1, "foo"); + + // when + Object[] fields = recordFieldExtractor.extract(person); + + // then + Assert.assertNotNull(fields); + Assert.assertArrayEquals(new Object[] { "foo" }, fields); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidComponentName() { + // given + RecordFieldExtractor recordFieldExtractor = new RecordFieldExtractor<>(Person.class); + recordFieldExtractor.setNames("nonExistent"); + Person person = new Person(1, "foo"); + + // when + recordFieldExtractor.extract(person); + + // then + // expected exception + } + + public record Person(int id, String name) { + } + + public class NonRecordType { + } + +} \ No newline at end of file