Skip to content

Commit

Permalink
AVRO-3947: Support any subclass or instance in custom LogicalType Con…
Browse files Browse the repository at this point in the history
…versions
  • Loading branch information
tmoschou committed Feb 26, 2024
1 parent 627e8d5 commit 884b0f3
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.nio.charset.StandardCharsets;
import java.time.temporal.Temporal;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
Expand All @@ -36,6 +37,7 @@
import java.util.UUID;
import java.util.concurrent.ConcurrentMap;

import com.fasterxml.jackson.databind.JsonNode;
import org.apache.avro.AvroMissingFieldException;
import org.apache.avro.AvroRuntimeException;
import org.apache.avro.AvroTypeException;
Expand All @@ -57,8 +59,6 @@
import org.apache.avro.io.FastReaderBuilder;
import org.apache.avro.util.Utf8;
import org.apache.avro.util.internal.Accessor;

import com.fasterxml.jackson.databind.JsonNode;
import org.apache.avro.util.springframework.ConcurrentReferenceHashMap;

import static org.apache.avro.util.springframework.ConcurrentReferenceHashMap.ReferenceType.WEAK;
Expand Down Expand Up @@ -141,6 +141,8 @@ private void loadConversions() {

private Map<Class<?>, Map<String, Conversion<?>>> conversionsByClass = new IdentityHashMap<>();

private Map<String, List<Conversion<?>>> conversionsByLogialTypeName = new HashMap<>();

public Collection<Conversion<?>> getConversions() {
return conversions.values();
}
Expand All @@ -153,11 +155,11 @@ public Collection<Conversion<?>> getConversions() {
* @param conversion a logical type Conversion.
*/
public void addLogicalTypeConversion(Conversion<?> conversion) {
conversions.put(conversion.getLogicalTypeName(), conversion);
String logicalTypeName = conversion.getLogicalTypeName();
conversions.put(logicalTypeName, conversion);
Class<?> type = conversion.getConvertedType();
Map<String, Conversion<?>> conversionsForClass = conversionsByClass.computeIfAbsent(type,
k -> new LinkedHashMap<>());
conversionsForClass.put(conversion.getLogicalTypeName(), conversion);
conversionsByClass.computeIfAbsent(type, k -> new LinkedHashMap<>()).put(logicalTypeName, conversion);
conversionsByLogialTypeName.computeIfAbsent(logicalTypeName, k -> new ArrayList<>()).add(conversion);
}

/**
Expand All @@ -176,17 +178,21 @@ public <T> Conversion<T> getConversionByClass(Class<T> datumClass) {
}

/**
* Returns the conversion for the given class and logical type.
* Returns the first conversion for the given datum class and logical type.
*
* @param datumClass a Class
* @param datumClass a Class of the datum
* @param logicalType a LogicalType
* @return the conversion for the class and logical type, or null
*/
@SuppressWarnings("unchecked")
public <T> Conversion<T> getConversionByClass(Class<T> datumClass, LogicalType logicalType) {
Map<String, Conversion<?>> conversions = conversionsByClass.get(datumClass);
public <T> Conversion<T> getConversionByClass(Class<? extends T> datumClass, LogicalType logicalType) {
List<Conversion<?>> conversions = conversionsByLogialTypeName.get(logicalType.getName());
if (conversions != null) {
return (Conversion<T>) conversions.get(logicalType.getName());
for (Conversion<?> conversion : conversions) {
if (conversion.getConvertedType().isAssignableFrom(datumClass)) {
return (Conversion<T>) conversion;
}
}
}
return null;
}
Expand Down Expand Up @@ -916,14 +922,13 @@ public int resolveUnion(Schema union, Object datum) {
// this allows logical type concrete classes to overlap with supported ones
// for example, a conversion could return a map
if (datum != null) {
Map<String, Conversion<?>> conversions = conversionsByClass.get(datum.getClass());
if (conversions != null) {
List<Schema> candidates = union.getTypes();
for (int i = 0; i < candidates.size(); i += 1) {
LogicalType candidateType = candidates.get(i).getLogicalType();
if (candidateType != null) {
Conversion<?> conversion = conversions.get(candidateType.getName());
if (conversion != null) {
List<Schema> candidates = union.getTypes();
for (int i = 0; i < candidates.size(); i += 1) {
LogicalType candidateType = candidates.get(i).getLogicalType();
if (candidateType != null) {
Conversion<?> conversion = conversions.get(candidateType.getName());
if (conversion != null) {
if (conversion.getConvertedType().isInstance(datum)) {
return i;
}
}
Expand Down Expand Up @@ -1375,7 +1380,7 @@ public <T> T deepCopy(Schema schema, T value) {
LogicalType logicalType = schema.getLogicalType();
if (logicalType == null) // not a logical type -- use raw copy
return (T) deepCopyRaw(schema, value);
Conversion conversion = getConversionByClass(value.getClass(), logicalType);
Conversion conversion = getConversionFor(logicalType);
if (conversion == null) // no conversion defined -- try raw copy
return (T) deepCopyRaw(schema, value);
// logical type with conversion: convert to raw, copy, then convert back to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import java.util.Objects;

public final class CustomType {
public class CustomType {
private final String name;

public CustomType(CharSequence name) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ public String getLogicalTypeName() {

@Override
public Schema getRecommendedSchema() {
return Schema.create(Schema.Type.STRING);
Schema stringSchema = Schema.create(Schema.Type.STRING);
return logicalTypeFactory.fromSchema(stringSchema).addToSchema(stringSchema);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;

public class TestGenericLogicalTypes {
Expand Down Expand Up @@ -488,4 +489,25 @@ public void testWriteAutomaticallyRegisteredUri() throws IOException {
assertEquals(expected, read(GenericData.get().createDatumReader(stringSchema), test),
"Should read CustomType as strings");
}

@Test
public void testLogicalTypeSubclassing() throws IOException {
Schema stringSchema = Schema.create(Schema.Type.STRING);
GenericData.setStringType(stringSchema, GenericData.StringType.String);
LogicalType customType = LogicalTypes.getCustomRegisteredTypes().get("custom").fromSchema(stringSchema);
Schema customTypeSchema = customType.addToSchema(Schema.create(Schema.Type.STRING));
Schema unionSchema = Schema.createUnion(Schema.create(Schema.Type.BOOLEAN), customTypeSchema);

// anonymous subclass
CustomType datum1 = new CustomType("foo") {
};
assertNotEquals(datum1.getClass(), CustomType.class);
int index = GenericData.get().resolveUnion(unionSchema, datum1);
assertEquals(1, index, "Should resolve custom type subclass correct schema");

List<Object> expected = Arrays.asList(datum1, false);
File test = write(unionSchema, datum1, false);
assertEquals(expected, read(GENERIC.createDatumReader(unionSchema), test),
"Should convert logical type subclasses");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
import java.util.UUID;
import org.apache.avro.Conversion;
import org.apache.avro.Conversions;
import org.apache.avro.CustomType;
import org.apache.avro.CustomTypeConverter;
import org.apache.avro.CustomTypeLogicalTypeFactory;
import org.apache.avro.LogicalType;
import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
Expand Down Expand Up @@ -62,10 +65,11 @@ public class TestReflectLogicalTypes {
public static final ReflectData REFLECT = new ReflectData();

@BeforeAll
public static void addUUID() {
public static void addLogicalTypeConversions() {
REFLECT.addLogicalTypeConversion(new Conversions.UUIDConversion());
REFLECT.addLogicalTypeConversion(new Conversions.DecimalConversion());
REFLECT.addLogicalTypeConversion(new TimeConversions.LocalTimestampMillisConversion());
REFLECT.addLogicalTypeConversion(new CustomTypeConverter());
}

@Test
Expand Down Expand Up @@ -595,6 +599,26 @@ void reflectedSchemaLocalDateTime() {
LogicalTypes.fromSchema(actual.getField("localDateTime").schema()), "Should have the correct logical type");
}

@Test
void customLogicalType() throws IOException {
Schema schema = REFLECT.getSchema(RecordWithCustomLogicalType.class);
Schema fieldSchema = schema.getField("customType").schema();
assertEquals(Schema.Type.STRING, fieldSchema.getType(), "Should have the correct physical type");
String actualLogicalTypeName = fieldSchema.getLogicalType().getName();
assertEquals("custom", actualLogicalTypeName, "Should have the correct logical type name");

RecordWithCustomLogicalType record1 = new RecordWithCustomLogicalType();
record1.customType = new CustomType("foo");
RecordWithCustomLogicalType record2 = new RecordWithCustomLogicalType();
// anonymous subclass
record2.customType = new CustomType("bar") {
};

File test = write(REFLECT, schema, record1, record2);
assertEquals(Arrays.asList(record1, record2), read(REFLECT.createDatumReader(schema), test),
"Should match the decimal after round trip");
}

private static <D> List<D> read(DatumReader<D> reader, File file) throws IOException {
List<D> data = new ArrayList<>();

Expand Down Expand Up @@ -731,3 +755,24 @@ public boolean equals(Object obj) {
return Objects.equals(localDateTime, that.localDateTime);
}
}

class RecordWithCustomLogicalType {
CustomType customType;

@Override
public int hashCode() {
return Objects.hash(customType);
}

@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (!(obj instanceof RecordWithCustomLogicalType)) {
return false;
}
RecordWithCustomLogicalType that = (RecordWithCustomLogicalType) obj;
return Objects.equals(customType, that.customType);
}
}

0 comments on commit 884b0f3

Please sign in to comment.