Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AVRO-3947: [Java] Support subclasses in custom LogicalType Conversions #2766

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.nio.charset.StandardCharsets;
import java.time.temporal.Temporal;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
Expand All @@ -36,6 +37,7 @@
import java.util.UUID;
import java.util.concurrent.ConcurrentMap;

import com.fasterxml.jackson.databind.JsonNode;
import org.apache.avro.AvroMissingFieldException;
import org.apache.avro.AvroRuntimeException;
import org.apache.avro.AvroTypeException;
Expand All @@ -57,8 +59,6 @@
import org.apache.avro.io.FastReaderBuilder;
import org.apache.avro.util.Utf8;
import org.apache.avro.util.internal.Accessor;

import com.fasterxml.jackson.databind.JsonNode;
import org.apache.avro.util.springframework.ConcurrentReferenceHashMap;

import static org.apache.avro.util.springframework.ConcurrentReferenceHashMap.ReferenceType.WEAK;
Expand Down Expand Up @@ -141,6 +141,8 @@ private void loadConversions() {

private Map<Class<?>, Map<String, Conversion<?>>> conversionsByClass = new IdentityHashMap<>();

private Map<String, List<Conversion<?>>> conversionsByLogialTypeName = new HashMap<>();

public Collection<Conversion<?>> getConversions() {
return conversions.values();
}
Expand All @@ -153,11 +155,11 @@ public Collection<Conversion<?>> getConversions() {
* @param conversion a logical type Conversion.
*/
public void addLogicalTypeConversion(Conversion<?> conversion) {
conversions.put(conversion.getLogicalTypeName(), conversion);
String logicalTypeName = conversion.getLogicalTypeName();
conversions.put(logicalTypeName, conversion);
Class<?> type = conversion.getConvertedType();
Map<String, Conversion<?>> conversionsForClass = conversionsByClass.computeIfAbsent(type,
k -> new LinkedHashMap<>());
conversionsForClass.put(conversion.getLogicalTypeName(), conversion);
conversionsByClass.computeIfAbsent(type, k -> new LinkedHashMap<>()).put(logicalTypeName, conversion);
conversionsByLogialTypeName.computeIfAbsent(logicalTypeName, k -> new ArrayList<>()).add(conversion);
}

/**
Expand All @@ -176,17 +178,23 @@ public <T> Conversion<T> getConversionByClass(Class<T> datumClass) {
}

/**
* Returns the conversion for the given class and logical type.
* Returns the first conversion for the given datum class and logical type.
*
* @param datumClass a Class
* @param datumClass a Class of the datum
* @param logicalType a LogicalType
* @return the conversion for the class and logical type, or null
*/
@SuppressWarnings("unchecked")
public <T> Conversion<T> getConversionByClass(Class<T> datumClass, LogicalType logicalType) {
Map<String, Conversion<?>> conversions = conversionsByClass.get(datumClass);
public <T> Conversion<T> getConversionByClass(Class<? extends T> datumClass, LogicalType logicalType) {
// note this does not use conversionsByClass anymore - which assume instances are of
// type Class<T> and not Class<? extends T>
List<Conversion<?>> conversions = conversionsByLogialTypeName.get(logicalType.getName());
if (conversions != null) {
return (Conversion<T>) conversions.get(logicalType.getName());
for (Conversion<?> conversion : conversions) {
if (conversion.getConvertedType().isAssignableFrom(datumClass)) {
return (Conversion<T>) conversion;
}
}
}
return null;
}
Expand Down Expand Up @@ -916,14 +924,13 @@ public int resolveUnion(Schema union, Object datum) {
// this allows logical type concrete classes to overlap with supported ones
// for example, a conversion could return a map
if (datum != null) {
Map<String, Conversion<?>> conversions = conversionsByClass.get(datum.getClass());
if (conversions != null) {
List<Schema> candidates = union.getTypes();
for (int i = 0; i < candidates.size(); i += 1) {
LogicalType candidateType = candidates.get(i).getLogicalType();
if (candidateType != null) {
Conversion<?> conversion = conversions.get(candidateType.getName());
if (conversion != null) {
List<Schema> candidates = union.getTypes();
for (int i = 0; i < candidates.size(); i += 1) {
LogicalType candidateType = candidates.get(i).getLogicalType();
if (candidateType != null) {
Conversion<?> conversion = conversions.get(candidateType.getName());
if (conversion != null) {
if (conversion.getConvertedType().isInstance(datum)) {
return i;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import java.util.Objects;

public final class CustomType {
public class CustomType {
private final String name;

public CustomType(CharSequence name) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ public String getLogicalTypeName() {

@Override
public Schema getRecommendedSchema() {
return Schema.create(Schema.Type.STRING);
Schema stringSchema = Schema.create(Schema.Type.STRING);
return logicalTypeFactory.fromSchema(stringSchema).addToSchema(stringSchema);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;

public class TestGenericLogicalTypes {
Expand Down Expand Up @@ -488,4 +489,25 @@ public void testWriteAutomaticallyRegisteredUri() throws IOException {
assertEquals(expected, read(GenericData.get().createDatumReader(stringSchema), test),
"Should read CustomType as strings");
}

@Test
public void testLogicalTypeSubclassing() throws IOException {
Schema stringSchema = Schema.create(Schema.Type.STRING);
GenericData.setStringType(stringSchema, GenericData.StringType.String);
LogicalType customType = LogicalTypes.getCustomRegisteredTypes().get("custom").fromSchema(stringSchema);
Schema customTypeSchema = customType.addToSchema(Schema.create(Schema.Type.STRING));
Schema unionSchema = Schema.createUnion(Schema.create(Schema.Type.BOOLEAN), customTypeSchema);

// anonymous subclass
CustomType datum1 = new CustomType("foo") {
};
assertNotEquals(datum1.getClass(), CustomType.class);
int index = GenericData.get().resolveUnion(unionSchema, datum1);
assertEquals(1, index, "Should resolve custom type subclass correct schema");

List<Object> expected = Arrays.asList(datum1, false);
File test = write(unionSchema, datum1, false);
assertEquals(expected, read(GENERIC.createDatumReader(unionSchema), test),
"Should convert logical type subclasses");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
import java.util.UUID;
import org.apache.avro.Conversion;
import org.apache.avro.Conversions;
import org.apache.avro.CustomType;
import org.apache.avro.CustomTypeConverter;
import org.apache.avro.CustomTypeLogicalTypeFactory;
import org.apache.avro.LogicalType;
import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
Expand Down Expand Up @@ -62,10 +65,11 @@ public class TestReflectLogicalTypes {
public static final ReflectData REFLECT = new ReflectData();

@BeforeAll
public static void addUUID() {
public static void addLogicalTypeConversions() {
REFLECT.addLogicalTypeConversion(new Conversions.UUIDConversion());
REFLECT.addLogicalTypeConversion(new Conversions.DecimalConversion());
REFLECT.addLogicalTypeConversion(new TimeConversions.LocalTimestampMillisConversion());
REFLECT.addLogicalTypeConversion(new CustomTypeConverter());
}

@Test
Expand Down Expand Up @@ -595,6 +599,26 @@ void reflectedSchemaLocalDateTime() {
LogicalTypes.fromSchema(actual.getField("localDateTime").schema()), "Should have the correct logical type");
}

@Test
void customLogicalType() throws IOException {
Schema schema = REFLECT.getSchema(RecordWithCustomLogicalType.class);
Schema fieldSchema = schema.getField("customType").schema();
assertEquals(Schema.Type.STRING, fieldSchema.getType(), "Should have the correct physical type");
String actualLogicalTypeName = fieldSchema.getLogicalType().getName();
assertEquals("custom", actualLogicalTypeName, "Should have the correct logical type name");

RecordWithCustomLogicalType record1 = new RecordWithCustomLogicalType();
record1.customType = new CustomType("foo");
RecordWithCustomLogicalType record2 = new RecordWithCustomLogicalType();
// anonymous subclass
record2.customType = new CustomType("bar") {
};

File test = write(REFLECT, schema, record1, record2);
assertEquals(Arrays.asList(record1, record2), read(REFLECT.createDatumReader(schema), test),
"Should match the decimal after round trip");
}

private static <D> List<D> read(DatumReader<D> reader, File file) throws IOException {
List<D> data = new ArrayList<>();

Expand Down Expand Up @@ -731,3 +755,24 @@ public boolean equals(Object obj) {
return Objects.equals(localDateTime, that.localDateTime);
}
}

class RecordWithCustomLogicalType {
CustomType customType;

@Override
public int hashCode() {
return Objects.hash(customType);
}

@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (!(obj instanceof RecordWithCustomLogicalType)) {
return false;
}
RecordWithCustomLogicalType that = (RecordWithCustomLogicalType) obj;
return Objects.equals(customType, that.customType);
}
}