Skip to content

Commit

Permalink
Support deep comparison of unpacked Any messages in FieldNumberTree.
Browse files Browse the repository at this point in the history
RELNOTES=Fixed a bug that caused ProtoTruth to ignore the contents of unpacked `Any` messages. This fix may cause tests to fail, since ProtoTruth will now check whether the message contents match. If so, you may need to change the values that your tests expect, or there may be a bug in the code under test that had been hidden by the Truth bug. Sorry for the trouble.
PiperOrigin-RevId: 577171522
  • Loading branch information
java-team-github-bot authored and Google Java Core Libraries committed Oct 27, 2023
1 parent a12d848 commit 8bd3ef6
Show file tree
Hide file tree
Showing 10 changed files with 261 additions and 34 deletions.
Expand Up @@ -69,8 +69,9 @@ static ExtensionRegistry defaultExtensionRegistry() {
return DEFAULT_EXTENSION_REGISTRY;
}

/** Unpack an `Any` proto using the TypeRegistry and ExtensionRegistry on `config`. */
static Optional<Message> unpack(Message any, FluentEqualityConfig config) {
/** Unpack an `Any` proto using the given TypeRegistry and ExtensionRegistry. */
static Optional<Message> unpack(
Message any, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) {
Preconditions.checkArgument(
any.getDescriptorForType().equals(Any.getDescriptor()),
"Expected type google.protobuf.Any, but was: %s",
Expand All @@ -80,13 +81,12 @@ static Optional<Message> unpack(Message any, FluentEqualityConfig config) {
ByteString value = (ByteString) any.getField(valueFieldDescriptor());

try {
Descriptor descriptor = config.useTypeRegistry().getDescriptorForTypeUrl(typeUrl);
Descriptor descriptor = typeRegistry.getDescriptorForTypeUrl(typeUrl);
if (descriptor == null) {
return Optional.absent();
}

Message defaultMessage =
DynamicMessage.parseFrom(descriptor, value, config.useExtensionRegistry());
Message defaultMessage = DynamicMessage.parseFrom(descriptor, value, extensionRegistry);
return Optional.of(defaultMessage);
} catch (InvalidProtocolBufferException e) {
return Optional.absent();
Expand Down
Expand Up @@ -557,6 +557,9 @@ default void printFieldValue(SubScopeId subScopeId, Object o, StringBuilder sb)
case UNKNOWN_FIELD_DESCRIPTOR:
printFieldValue(subScopeId.unknownFieldDescriptor(), o, sb);
return;
case UNPACKED_ANY_VALUE_TYPE:
printFieldValue(AnyUtils.valueFieldDescriptor(), o, sb);
return;
}
throw new AssertionError(subScopeId.kind());
}
Expand Down
Expand Up @@ -16,9 +16,12 @@

package com.google.common.truth.extensions.proto;

import com.google.common.base.Optional;
import com.google.common.collect.Maps;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.TypeRegistry;
import com.google.protobuf.UnknownFieldSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -62,7 +65,8 @@ boolean hasChild(SubScopeId subScopeId) {
return children.containsKey(subScopeId);
}

static FieldNumberTree fromMessage(Message message) {
static FieldNumberTree fromMessage(
Message message, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) {
FieldNumberTree tree = new FieldNumberTree();

// Known fields.
Expand All @@ -72,15 +76,25 @@ static FieldNumberTree fromMessage(Message message) {
FieldNumberTree childTree = new FieldNumberTree();
tree.children.put(subScopeId, childTree);

Object fieldValue = knownFieldValues.get(field);
if (field.getJavaType() == FieldDescriptor.JavaType.MESSAGE) {
if (field.isRepeated()) {
List<?> valueList = (List<?>) fieldValue;
for (Object value : valueList) {
childTree.merge(fromMessage((Message) value));
if (field.equals(AnyUtils.valueFieldDescriptor())) {
// Handle Any protos specially.
Optional<Message> unpackedAny = AnyUtils.unpack(message, typeRegistry, extensionRegistry);
if (unpackedAny.isPresent()) {
tree.children.put(
SubScopeId.ofUnpackedAnyValueType(unpackedAny.get().getDescriptorForType()),
fromMessage(unpackedAny.get(), typeRegistry, extensionRegistry));
}
} else {
Object fieldValue = knownFieldValues.get(field);
if (field.getJavaType() == FieldDescriptor.JavaType.MESSAGE) {
if (field.isRepeated()) {
List<?> valueList = (List<?>) fieldValue;
for (Object value : valueList) {
childTree.merge(fromMessage((Message) value, typeRegistry, extensionRegistry));
}
} else {
childTree.merge(fromMessage((Message) fieldValue, typeRegistry, extensionRegistry));
}
} else {
childTree.merge(fromMessage((Message) fieldValue));
}
}
}
Expand All @@ -91,11 +105,14 @@ static FieldNumberTree fromMessage(Message message) {
return tree;
}

static FieldNumberTree fromMessages(Iterable<? extends Message> messages) {
static FieldNumberTree fromMessages(
Iterable<? extends Message> messages,
TypeRegistry typeRegistry,
ExtensionRegistry extensionRegistry) {
FieldNumberTree tree = new FieldNumberTree();
for (Message message : messages) {
if (message != null) {
tree.merge(fromMessage(message));
tree.merge(fromMessage(message, typeRegistry, extensionRegistry));
}
}
return tree;
Expand Down
Expand Up @@ -28,7 +28,9 @@
import com.google.common.collect.Lists;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.TypeRegistry;
import java.util.List;

/**
Expand Down Expand Up @@ -62,13 +64,17 @@ private static FieldScope create(
// Instantiation methods.
//////////////////////////////////////////////////////////////////////////////////////////////////

static FieldScope createFromSetFields(Message message) {
static FieldScope createFromSetFields(
Message message, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) {
return create(
FieldScopeLogic.partialScope(message),
FieldScopeLogic.partialScope(message, typeRegistry, extensionRegistry),
Functions.constant(String.format("FieldScopes.fromSetFields({%s})", message.toString())));
}

static FieldScope createFromSetFields(Iterable<? extends Message> messages) {
static FieldScope createFromSetFields(
Iterable<? extends Message> messages,
TypeRegistry typeRegistry,
ExtensionRegistry extensionRegistry) {
if (emptyOrAllNull(messages)) {
return create(
FieldScopeLogic.none(),
Expand All @@ -82,7 +88,8 @@ static FieldScope createFromSetFields(Iterable<? extends Message> messages) {
getDescriptors(messages));

return create(
FieldScopeLogic.partialScope(messages, optDescriptor.get()),
FieldScopeLogic.partialScope(
messages, optDescriptor.get(), typeRegistry, extensionRegistry),
Functions.constant(String.format("FieldScopes.fromSetFields(%s)", formatList(messages))));
}

Expand Down
Expand Up @@ -28,7 +28,9 @@
import com.google.errorprone.annotations.ForOverride;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.TypeRegistry;
import java.util.List;

/**
Expand Down Expand Up @@ -267,14 +269,21 @@ public String toString() {
}
}

static FieldScopeLogic partialScope(Message message) {
static FieldScopeLogic partialScope(
Message message, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) {
return new RootPartialScopeLogic(
FieldNumberTree.fromMessage(message), message.toString(), message.getDescriptorForType());
FieldNumberTree.fromMessage(message, typeRegistry, extensionRegistry),
message.toString(),
message.getDescriptorForType());
}

static FieldScopeLogic partialScope(Iterable<? extends Message> messages, Descriptor descriptor) {
static FieldScopeLogic partialScope(
Iterable<? extends Message> messages,
Descriptor descriptor,
TypeRegistry typeRegistry,
ExtensionRegistry extensionRegistry) {
return new RootPartialScopeLogic(
FieldNumberTree.fromMessages(messages),
FieldNumberTree.fromMessages(messages, typeRegistry, extensionRegistry),
Joiner.on(", ").useForNull("null").join(messages),
descriptor);
}
Expand Down Expand Up @@ -304,11 +313,18 @@ protected FieldMatcherLogicBase(boolean isRecursive) {

@Override
final FieldScopeResult policyFor(Descriptor rootDescriptor, SubScopeId subScopeId) {
if (subScopeId.kind() == SubScopeId.Kind.UNKNOWN_FIELD_DESCRIPTOR) {
return FieldScopeResult.EXCLUDED_RECURSIVELY;
FieldDescriptor fieldDescriptor = null;
switch (subScopeId.kind()) {
case FIELD_DESCRIPTOR:
fieldDescriptor = subScopeId.fieldDescriptor();
break;
case UNPACKED_ANY_VALUE_TYPE:
fieldDescriptor = AnyUtils.valueFieldDescriptor();
break;
case UNKNOWN_FIELD_DESCRIPTOR:
return FieldScopeResult.EXCLUDED_RECURSIVELY;
}

FieldDescriptor fieldDescriptor = subScopeId.fieldDescriptor();
if (matchesFieldDescriptor(rootDescriptor, fieldDescriptor)) {
return FieldScopeResult.of(/* included = */ true, isRecursive);
}
Expand Down
Expand Up @@ -19,7 +19,9 @@
import static com.google.common.truth.extensions.proto.FieldScopeUtil.asList;

import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.TypeRegistry;

/** Factory class for {@link FieldScope} instances. */
public final class FieldScopes {
Expand Down Expand Up @@ -66,7 +68,58 @@ public final class FieldScopes {
// Alternatively II, add Scope.PARTIAL support to ProtoFluentEquals, but with a different name and
// explicit documentation that it may cause issues with Proto 3.
public static FieldScope fromSetFields(Message message) {
return FieldScopeImpl.createFromSetFields(message);
return fromSetFields(
message, AnyUtils.defaultTypeRegistry(), AnyUtils.defaultExtensionRegistry());
}

/**
* Returns a {@link FieldScope} which is constrained to precisely those specific field paths that
* are explicitly set in the message. Note that, for version 3 protobufs, such a {@link
* FieldScope} will omit fields in the provided message which are set to default values.
*
* <p>This can be used limit the scope of a comparison to a complex set of fields in a very brief
* statement. Often, {@code message} is the expected half of a comparison about to be performed.
*
* <p>Example usage:
*
* <pre>{@code
* Foo actual = Foo.newBuilder().setBar(3).setBaz(4).build();
* Foo expected = Foo.newBuilder().setBar(3).setBaz(5).build();
* // Fails, because actual.getBaz() != expected.getBaz().
* assertThat(actual).isEqualTo(expected);
*
* Foo scope = Foo.newBuilder().setBar(2).build();
* // Succeeds, because only the field 'bar' is compared.
* assertThat(actual).withPartialScope(FieldScopes.fromSetFields(scope)).isEqualTo(expected);
*
* }</pre>
*
* <p>The returned {@link FieldScope} does not respect repeated field indices nor map keys. For
* example, if the provided message sets different field values for different elements of a
* repeated field, like so:
*
* <pre>{@code
* sub_message: {
* foo: "foo"
* }
* sub_message: {
* bar: "bar"
* }
* }</pre>
*
* <p>The {@link FieldScope} will contain {@code sub_message.foo} and {@code sub_message.bar} for
* *all* repeated {@code sub_messages}, including those beyond index 1.
*
* <p>If there are {@code google.protobuf.Any} protos anywhere within these messages, they will be
* unpacked using the provided {@link TypeRegistry} and {@link ExtensionRegistry} to determine
* which fields within them should be compared.
*
* @see ProtoFluentAssertion#unpackingAnyUsing
* @since 1.2
*/
public static FieldScope fromSetFields(
Message message, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) {
return FieldScopeImpl.createFromSetFields(message, typeRegistry, extensionRegistry);
}

/**
Expand All @@ -89,7 +142,29 @@ public static FieldScope fromSetFields(
* or the {@link FieldScope} for the merge of all the messages. These are equivalent.
*/
public static FieldScope fromSetFields(Iterable<? extends Message> messages) {
return FieldScopeImpl.createFromSetFields(messages);
return fromSetFields(
messages, AnyUtils.defaultTypeRegistry(), AnyUtils.defaultExtensionRegistry());
}

/**
* Creates a {@link FieldScope} covering the fields set in every message in the provided list of
* messages, with the same semantics as in {@link #fromSetFields(Message)}.
*
* <p>This can be thought of as the union of the {@link FieldScope}s for each individual message,
* or the {@link FieldScope} for the merge of all the messages. These are equivalent.
*
* <p>If there are {@code google.protobuf.Any} protos anywhere within these messages, they will be
* unpacked using the provided {@link TypeRegistry} and {@link ExtensionRegistry} to determine
* which fields within them should be compared.
*
* @see ProtoFluentAssertion#unpackingAnyUsing
* @since 1.2
*/
public static FieldScope fromSetFields(
Iterable<? extends Message> messages,
TypeRegistry typeRegistry,
ExtensionRegistry extensionRegistry) {
return FieldScopeImpl.createFromSetFields(messages, typeRegistry, extensionRegistry);
}

/**
Expand Down
Expand Up @@ -273,7 +273,11 @@ final FluentEqualityConfig withExpectedMessages(Iterable<? extends Message> mess
Builder builder = toBuilder().setHasExpectedMessages(true);
if (compareExpectedFieldsOnly()) {
builder.setCompareFieldsScope(
FieldScopeLogic.and(compareFieldsScope(), FieldScopes.fromSetFields(messages).logic()));
FieldScopeLogic.and(
compareFieldsScope(),
FieldScopeImpl.createFromSetFields(
messages, useTypeRegistry(), useExtensionRegistry())
.logic()));
}
return builder.build();
}
Expand Down
Expand Up @@ -221,8 +221,10 @@ private DiffResult diffAnyMessages(
if (shouldCompareValue == FieldScopeResult.EXCLUDED_RECURSIVELY) {
valueDiffResult = SingularField.ignored(name(AnyUtils.valueFieldDescriptor()));
} else {
Optional<Message> unpackedActual = AnyUtils.unpack(actual, config);
Optional<Message> unpackedExpected = AnyUtils.unpack(expected, config);
Optional<Message> unpackedActual =
AnyUtils.unpack(actual, config.useTypeRegistry(), config.useExtensionRegistry());
Optional<Message> unpackedExpected =
AnyUtils.unpack(expected, config.useTypeRegistry(), config.useExtensionRegistry());
if (unpackedActual.isPresent()
&& unpackedExpected.isPresent()
&& descriptorsMatch(unpackedActual.get(), unpackedExpected.get())) {
Expand All @@ -235,7 +237,10 @@ && descriptorsMatch(unpackedActual.get(), unpackedExpected.get())) {
shouldCompareValue == FieldScopeResult.EXCLUDED_NONRECURSIVELY,
AnyUtils.valueFieldDescriptor(),
name(AnyUtils.valueFieldDescriptor()),
config.subScope(rootDescriptor, AnyUtils.valueSubScopeId()));
config.subScope(
rootDescriptor,
SubScopeId.ofUnpackedAnyValueType(
unpackedActual.get().getDescriptorForType())));
} else {
valueDiffResult =
compareSingularValue(
Expand Down
Expand Up @@ -17,13 +17,15 @@
package com.google.common.truth.extensions.proto;

import com.google.auto.value.AutoOneOf;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;

@AutoOneOf(SubScopeId.Kind.class)
abstract class SubScopeId {
enum Kind {
FIELD_DESCRIPTOR,
UNKNOWN_FIELD_DESCRIPTOR;
UNKNOWN_FIELD_DESCRIPTOR,
UNPACKED_ANY_VALUE_TYPE;
}

abstract Kind kind();
Expand All @@ -32,6 +34,8 @@ enum Kind {

abstract UnknownFieldDescriptor unknownFieldDescriptor();

abstract Descriptor unpackedAnyValueType();

/** Returns a short, human-readable version of this identifier. */
final String shortName() {
switch (kind()) {
Expand All @@ -41,6 +45,8 @@ final String shortName() {
: fieldDescriptor().getName();
case UNKNOWN_FIELD_DESCRIPTOR:
return String.valueOf(unknownFieldDescriptor().fieldNumber());
case UNPACKED_ANY_VALUE_TYPE:
return AnyUtils.valueFieldDescriptor().getName();
}
throw new AssertionError(kind());
}
Expand All @@ -52,4 +58,8 @@ static SubScopeId of(FieldDescriptor fieldDescriptor) {
static SubScopeId of(UnknownFieldDescriptor unknownFieldDescriptor) {
return AutoOneOf_SubScopeId.unknownFieldDescriptor(unknownFieldDescriptor);
}

static SubScopeId ofUnpackedAnyValueType(Descriptor unpackedAnyValueType) {
return AutoOneOf_SubScopeId.unpackedAnyValueType(unpackedAnyValueType);
}
}

0 comments on commit 8bd3ef6

Please sign in to comment.