Skip to content

Commit

Permalink
fix: Actual types passed to generic POJO super classes and to generic…
Browse files Browse the repository at this point in the history
… collection fields are ignored firebase#5334
  • Loading branch information
eranl committed Sep 14, 2023
1 parent c5a894c commit 6bed512
Show file tree
Hide file tree
Showing 3 changed files with 288 additions and 45 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.firebase.firestore;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* When deserializing to generic collection fields of a generic class annotated with this
* annotation, generic type mappings will be strictly enforced. Without this annotation, such
* collection fields will take values of any type.
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface StrictCollectionTypes {}
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
import com.google.firebase.firestore.IgnoreExtraProperties;
import com.google.firebase.firestore.PropertyName;
import com.google.firebase.firestore.ServerTimestamp;
import com.google.firebase.firestore.StrictCollectionTypes;
import com.google.firebase.firestore.ThrowOnExtraProperties;
import java.lang.reflect.AccessibleObject;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.Member;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
Expand All @@ -43,7 +45,6 @@
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -185,11 +186,12 @@ private static <T> Object serialize(T o, ErrorPath path) {
}

@SuppressWarnings({"unchecked", "TypeParameterUnusedInFormals"})
private static <T> T deserializeToType(Object o, Type type, DeserializeContext context) {
private static <T> T deserializeToType(
Object o, Type type, TypeMapper typeMapper, DeserializeContext context) {
if (o == null) {
return null;
} else if (type instanceof ParameterizedType) {
return deserializeToParameterizedType(o, (ParameterizedType) type, context);
return deserializeToParameterizedType(o, (ParameterizedType) type, typeMapper, context);
} else if (type instanceof Class) {
return deserializeToClass(o, (Class<T>) type, context);
} else if (type instanceof WildcardType) {
Expand All @@ -205,12 +207,12 @@ private static <T> T deserializeToType(Object o, Type type, DeserializeContext c
// has at least an upper bound of Object.
Type[] upperBounds = ((WildcardType) type).getUpperBounds();
hardAssert(upperBounds.length > 0, "Unexpected type bounds on wildcard " + type);
return deserializeToType(o, upperBounds[0], context);
return deserializeToType(o, upperBounds[0], typeMapper, context);
} else if (type instanceof TypeVariable) {
// As above, TypeVariables always have at least one upper bound of Object.
Type[] upperBounds = ((TypeVariable<?>) type).getBounds();
hardAssert(upperBounds.length > 0, "Unexpected type bounds on type variable " + type);
return deserializeToType(o, upperBounds[0], context);
return deserializeToType(o, upperBounds[0], typeMapper, context);

} else if (type instanceof GenericArrayType) {
throw deserializeError(
Expand Down Expand Up @@ -258,11 +260,11 @@ private static <T> T deserializeToClass(Object o, Class<T> clazz, DeserializeCon

@SuppressWarnings({"unchecked", "TypeParameterUnusedInFormals"})
private static <T> T deserializeToParameterizedType(
Object o, ParameterizedType type, DeserializeContext context) {
Object o, ParameterizedType type, TypeMapper typeMapper, DeserializeContext context) {
// getRawType should always return a Class<?>
Class<?> rawType = (Class<?>) type.getRawType();
if (List.class.isAssignableFrom(rawType)) {
Type genericType = type.getActualTypeArguments()[0];
Type genericType = typeMapper.resolve(type.getActualTypeArguments()[0], true);
if (o instanceof List) {
List<Object> list = (List<Object>) o;
List<Object> result = new ArrayList<>(list.size());
Expand All @@ -271,6 +273,7 @@ private static <T> T deserializeToParameterizedType(
deserializeToType(
list.get(i),
genericType,
typeMapper,
context.newInstanceWithErrorPath(context.errorPath.child("[" + i + "]"))));
}
return (T) result;
Expand All @@ -279,7 +282,7 @@ private static <T> T deserializeToParameterizedType(
}
} else if (Map.class.isAssignableFrom(rawType)) {
Type keyType = type.getActualTypeArguments()[0];
Type valueType = type.getActualTypeArguments()[1];
Type valueType = typeMapper.resolve(type.getActualTypeArguments()[1], true);
if (!keyType.equals(String.class)) {
throw deserializeError(
context.errorPath,
Expand All @@ -293,6 +296,7 @@ private static <T> T deserializeToParameterizedType(
deserializeToType(
entry.getValue(),
valueType,
typeMapper,
context.newInstanceWithErrorPath(context.errorPath.child(entry.getKey()))));
}
return (T) result;
Expand All @@ -302,16 +306,8 @@ private static <T> T deserializeToParameterizedType(
} else {
Map<String, Object> map = expectMap(o, context);
BeanMapper<T> mapper = (BeanMapper<T>) loadOrCreateBeanMapperForClass(rawType);
HashMap<TypeVariable<Class<T>>, Type> typeMapping = new HashMap<>();
TypeVariable<Class<T>>[] typeVariables = mapper.clazz.getTypeParameters();
Type[] types = type.getActualTypeArguments();
if (types.length != typeVariables.length) {
throw new IllegalStateException("Mismatched lengths for type variables and actual types");
}
for (int i = 0; i < typeVariables.length; i++) {
typeMapping.put(typeVariables[i], types[i]);
}
return mapper.deserialize(map, typeMapping, context);
return mapper.deserialize(
map, TypeMapper.of(TypeMapper.of(type, mapper.clazz), typeMapper), context);
}
}

Expand Down Expand Up @@ -586,6 +582,7 @@ private static class BeanMapper<T> {
private final Map<String, Method> getters;
private final Map<String, Method> setters;
private final Map<String, Field> fields;
private final Map<Member, TypeMapper> _typeMappers = new HashMap();

// A set of property names that were annotated with @ServerTimestamp.
private final HashSet<String> serverTimestamps;
Expand Down Expand Up @@ -648,6 +645,7 @@ private static class BeanMapper<T> {
// getMethods/getFields only returns public methods/fields we need to traverse the
// class hierarchy to find the appropriate setter or field.
Class<? super T> currentClass = clazz;
TypeMapper typeMapper = TypeMapper.empty();
do {
// Add any setters
for (Method method : currentClass.getDeclaredMethods()) {
Expand All @@ -666,6 +664,7 @@ private static class BeanMapper<T> {
if (existingSetter == null) {
method.setAccessible(true);
setters.put(propertyName, method);
_typeMappers.put(method, typeMapper);
applySetterAnnotations(method);
} else if (!isSetterOverride(method, existingSetter)) {
// We require that setters with conflicting property names are
Expand Down Expand Up @@ -703,13 +702,22 @@ private static class BeanMapper<T> {
&& !fields.containsKey(propertyName)) {
field.setAccessible(true);
fields.put(propertyName, field);
_typeMappers.put(field, typeMapper);
applyFieldAnnotations(field);
}
}

Class<? super T> superclass = currentClass.getSuperclass();
Type genericSuperclass = currentClass.getGenericSuperclass();
if (genericSuperclass instanceof ParameterizedType) {
typeMapper = TypeMapper.of((ParameterizedType) genericSuperclass, superclass);
} else {
typeMapper = TypeMapper.empty();
}

// Traverse class hierarchy until we reach java.lang.Object which contains a bunch
// of fields/getters we don't want to serialize
currentClass = currentClass.getSuperclass();
currentClass = superclass;
} while (currentClass != null && !currentClass.equals(Object.class));

if (properties.isEmpty()) {
Expand Down Expand Up @@ -740,13 +748,10 @@ private void addProperty(String property) {
}

T deserialize(Map<String, Object> values, DeserializeContext context) {
return deserialize(values, Collections.emptyMap(), context);
return deserialize(values, TypeMapper.empty(), context);
}

T deserialize(
Map<String, Object> values,
Map<TypeVariable<Class<T>>, Type> types,
DeserializeContext context) {
T deserialize(Map<String, Object> values, TypeMapper typeMapper, DeserializeContext context) {
if (constructor == null) {
throw deserializeError(
context.errorPath,
Expand All @@ -767,18 +772,26 @@ T deserialize(
if (params.length != 1) {
throw deserializeError(childPath, "Setter does not have exactly one parameter");
}
Type resolvedType = resolveType(params[0], types);
TypeMapper setterTypeMapper = TypeMapper.of(_typeMappers.get(setter), typeMapper);
Type resolvedType = setterTypeMapper.resolve(params[0], false);
Object value =
CustomClassMapper.deserializeToType(
entry.getValue(), resolvedType, context.newInstanceWithErrorPath(childPath));
entry.getValue(),
resolvedType,
setterTypeMapper,
context.newInstanceWithErrorPath(childPath));
invoke(setter, instance, value);
deserialzedProperties.add(propertyName);
} else if (fields.containsKey(propertyName)) {
Field field = fields.get(propertyName);
Type resolvedType = resolveType(field.getGenericType(), types);
TypeMapper fieldTypeMapper = TypeMapper.of(_typeMappers.get(field), typeMapper);
Type resolvedType = fieldTypeMapper.resolve(field.getGenericType(), false);
Object value =
CustomClassMapper.deserializeToType(
entry.getValue(), resolvedType, context.newInstanceWithErrorPath(childPath));
entry.getValue(),
resolvedType,
fieldTypeMapper,
context.newInstanceWithErrorPath(childPath));
try {
field.set(instance, value);
} catch (IllegalAccessException e) {
Expand All @@ -798,7 +811,7 @@ T deserialize(
}
}
}
populateDocumentIdProperties(types, context, instance, deserialzedProperties);
populateDocumentIdProperties(typeMapper, context, instance, deserialzedProperties);

return instance;
}
Expand All @@ -807,7 +820,7 @@ T deserialize(
// applied to a property that is already deserialized from the firestore document)
// a runtime exception will be thrown.
private void populateDocumentIdProperties(
Map<TypeVariable<Class<T>>, Type> types,
TypeMapper typeMapper,
DeserializeContext context,
T instance,
HashSet<String> deserialzedProperties) {
Expand All @@ -829,16 +842,20 @@ private void populateDocumentIdProperties(
if (params.length != 1) {
throw deserializeError(childPath, "Setter does not have exactly one parameter");
}
Type resolvedType = resolveType(params[0], types);
Type resolvedType =
TypeMapper.of(_typeMappers.get(setter), typeMapper).resolve(params[0], false);
if (resolvedType == String.class) {
invoke(setter, instance, context.documentRef.getId());
} else {
invoke(setter, instance, context.documentRef);
}
} else {
Field docIdField = fields.get(docIdPropertyName);
Type resolvedType =
TypeMapper.of(_typeMappers.get(docIdField), typeMapper)
.resolve(docIdField.getType(), false);
try {
if (docIdField.getType() == String.class) {
if (resolvedType == String.class) {
docIdField.set(instance, context.documentRef.getId());
} else {
docIdField.set(instance, context.documentRef);
Expand All @@ -850,19 +867,6 @@ private void populateDocumentIdProperties(
}
}

private Type resolveType(Type type, Map<TypeVariable<Class<T>>, Type> types) {
if (type instanceof TypeVariable) {
Type resolvedType = types.get(type);
if (resolvedType == null) {
throw new IllegalStateException("Could not resolve type " + type);
} else {
return resolvedType;
}
} else {
return type;
}
}

Map<String, Object> serialize(T object, ErrorPath path) {
// TODO(wuandy): Add logic to skip @DocumentId annotated fields in serialization.
if (!clazz.isAssignableFrom(object.getClass())) {
Expand Down Expand Up @@ -1121,6 +1125,65 @@ private static String serializedName(String methodName) {
}
}


/**
* Resolves generic type variables to actual types, based on context. Special-cases collection
* types, for backward compatibility: If the containing class is annotated with {@link
* StrictCollectionTypes}, then resolution is performed normally. Otherwise, the type is not
* resolved, allowing the collection to take values of any type.
*/
private interface TypeMapper {
TypeMapper EMPTY = (typeVariable, collection) -> null;

static TypeMapper of(ParameterizedType type, Class<?> clazz) {
TypeVariable<? extends Class<?>>[] typeVariables = clazz.getTypeParameters();
Type[] types = type.getActualTypeArguments();
if (types.length != typeVariables.length) {
throw new IllegalStateException("Mismatched lengths for type variables and actual types");
}
Map<TypeVariable<?>, Type> typeMapping = new HashMap<>();
for (int i = 0; i < typeVariables.length; i++) {
typeMapping.put(typeVariables[i], types[i]);
}

boolean strictCollectionTypes = clazz.isAnnotationPresent(StrictCollectionTypes.class);
return (typeVariable, collection) -> collection && ! strictCollectionTypes? typeVariable : typeMapping.get(typeVariable);
}

static TypeMapper of(TypeMapper innerMapper, TypeMapper outerMapper) {
return (typeVariable, collection) -> map(typeVariable, innerMapper, outerMapper, collection);
}

static TypeMapper empty() {
return EMPTY;
}

default Type resolve(Type type, boolean collection) {
if (type instanceof TypeVariable) {
Type resolvedType = get((TypeVariable<?>) type, collection);
if (resolvedType == null) {
throw new IllegalStateException("Could not resolve type " + type);
}

return resolvedType;
}

return type;
}

static Type map(TypeVariable<?> typeVariable, TypeMapper innerMapper, TypeMapper outerMapper, boolean collection) {
Type type = innerMapper.get(typeVariable, collection);
if (type != null) {
return type;
}

return outerMapper.get(typeVariable, collection);
}

Type get(TypeVariable<?> typeVariable, boolean collection);
}


/**
* Immutable class representing the path to a specific field in an object. Used to provide better
* error messages.
Expand Down

0 comments on commit 6bed512

Please sign in to comment.