Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve vararg handling #2805

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/main/java/org/mockito/ArgumentCaptor.java
Expand Up @@ -62,11 +62,12 @@
@CheckReturnValue
public class ArgumentCaptor<T> {

private final CapturingMatcher<T> capturingMatcher = new CapturingMatcher<T>();
private final CapturingMatcher<T> capturingMatcher;
private final Class<? extends T> clazz;

private ArgumentCaptor(Class<? extends T> clazz) {
this.clazz = clazz;
this.capturingMatcher = new CapturingMatcher<T>(clazz);
}

/**
Expand Down
Expand Up @@ -9,6 +9,8 @@
import org.mockito.ArgumentMatcher;
import org.mockito.internal.matchers.VarargMatcher;

import java.util.Optional;

public class HamcrestArgumentMatcher<T> implements ArgumentMatcher<T> {

private final Matcher matcher;
Expand All @@ -26,6 +28,10 @@ public boolean isVarargMatcher() {
return matcher instanceof VarargMatcher;
}

public Optional<VarargMatcher> varargMatcher() {
return isVarargMatcher() ? Optional.of((VarargMatcher) matcher) : Optional.empty();
}

@Override
public String toString() {
// TODO SF add unit tests and integ test coverage for toString()
Expand Down
Expand Up @@ -6,6 +6,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import org.mockito.ArgumentMatcher;
import org.mockito.internal.hamcrest.HamcrestArgumentMatcher;
Expand Down Expand Up @@ -58,14 +59,25 @@ public static MatcherApplicationStrategy getMatcherApplicationStrategyFor(
* </ul>
*/
public boolean forEachMatcherAndArgument(ArgumentMatcherAction action) {
if (invocation.getArguments().length == matchers.size()) {
return argsMatch(invocation.getArguments(), matchers, action);
}

final boolean isVararg =
invocation.getMethod().isVarArgs()
&& invocation.getRawArguments().length == matchers.size()
&& isLastMatcherVarargMatcher(matchers);
&& getLastMatcherVarargMatcherType(matchers).isPresent();

if (isVararg) {
final Class<?> matcherType = getLastMatcherVarargMatcherType(matchers).get();
final Class<?> paramType =
invocation.getMethod()
.getParameterTypes()[
invocation.getMethod().getParameterTypes().length - 1];
if (paramType.isAssignableFrom(matcherType)) {
return argsMatch(invocation.getRawArguments(), matchers, action);
}
}

if (invocation.getArguments().length == matchers.size()) {
return argsMatch(invocation.getArguments(), matchers, action);
}

if (isVararg) {
int times = varargLength(invocation);
Expand All @@ -91,12 +103,17 @@ private boolean argsMatch(
return true;
}

private static boolean isLastMatcherVarargMatcher(List<? extends ArgumentMatcher<?>> matchers) {
private static Optional<Class<?>> getLastMatcherVarargMatcherType(
final List<? extends ArgumentMatcher<?>> matchers) {
ArgumentMatcher<?> argumentMatcher = lastMatcher(matchers);
if (argumentMatcher instanceof HamcrestArgumentMatcher<?>) {
return ((HamcrestArgumentMatcher<?>) argumentMatcher).isVarargMatcher();
return ((HamcrestArgumentMatcher<?>) argumentMatcher)
.varargMatcher()
.map(VarargMatcher::type);
}
return argumentMatcher instanceof VarargMatcher;
return argumentMatcher instanceof VarargMatcher
? Optional.of((VarargMatcher) argumentMatcher).map(VarargMatcher::type)
: Optional.empty();
}

private List<? extends ArgumentMatcher<?>> appendLastMatcherNTimes(
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/mockito/internal/matchers/Any.java
Expand Up @@ -21,4 +21,9 @@ public boolean matches(Object actual) {
public String toString() {
return "<any>";
}

@Override
public Class<?> type() {
return Object.class;
}
}
10 changes: 10 additions & 0 deletions src/main/java/org/mockito/internal/matchers/CapturingMatcher.java
Expand Up @@ -19,12 +19,17 @@
public class CapturingMatcher<T>
implements ArgumentMatcher<T>, CapturesArguments, VarargMatcher, Serializable {

private final Class<? extends T> clazz;
private final List<Object> arguments = new ArrayList<>();

private final ReadWriteLock lock = new ReentrantReadWriteLock();
private final Lock readLock = lock.readLock();
private final Lock writeLock = lock.writeLock();

public CapturingMatcher(final Class<? extends T> clazz) {
this.clazz = clazz;
}

@Override
public boolean matches(Object argument) {
return true;
Expand Down Expand Up @@ -66,4 +71,9 @@ public void captureFrom(Object argument) {
writeLock.unlock();
}
}

@Override
public Class<?> type() {
return clazz;
}
}
7 changes: 6 additions & 1 deletion src/main/java/org/mockito/internal/matchers/InstanceOf.java
Expand Up @@ -11,7 +11,7 @@

public class InstanceOf implements ArgumentMatcher<Object>, Serializable {

private final Class<?> clazz;
final Class<?> clazz;
private final String description;

public InstanceOf(Class<?> clazz) {
Expand Down Expand Up @@ -44,5 +44,10 @@ public VarArgAware(Class<?> clazz) {
public VarArgAware(Class<?> clazz, String describedAs) {
super(clazz, describedAs);
}

@Override
public Class<?> type() {
return clazz;
}
}
}
42 changes: 41 additions & 1 deletion src/main/java/org/mockito/internal/matchers/VarargMatcher.java
Expand Up @@ -10,4 +10,44 @@
* Internal interface that informs Mockito that the matcher is intended to capture varargs.
* This information is needed when mockito collects the arguments.
*/
public interface VarargMatcher extends Serializable {}
public interface VarargMatcher extends Serializable {

/**
* The type of the argument the matcher matches.
*
* <p>If a vararg aware matcher:
* <ul>
* <li>is at the parameter index of a vararg parameter</li>
* <li>is the last matcher passed</li>
* <li>matchers the raw type of the vararg parameter</li>
* </ul>
*
* Then the matcher is matched against the vararg raw parameter.
* Otherwise, the matcher will be matched against each element in the vararg raw parameters.
*
* <p>For example:
*
* <pre class="code"><code class="java">
* // Given vararg method with signature:
* int someVarargMethod(int x, String... args);
*
* // The following will match the last matcher against the contents of the `args` array:
* (as the above criteria are met)
* mock.someVarargMethod(eq(1), any(String[].class));
*
* // The following will match the last matcher against each element of the `args` array:
* // (as the type of the last matcher does not match the raw type of the vararg parameter)
* mock.someVarargMethod(eq(1), any(String.class));
*
* // The following will match only invocations with two strings in the 'args' array:
* // (as there are more matchers than raw arguments)
* mock.someVarargMethod(eq(1), any(), any());
* </code></pre>
*
* @return the type this matcher handles.
* @since 4.10.0
*/
default Class<?> type() {
return Void.class;
}
}
Expand Up @@ -136,7 +136,7 @@ public void should_be_similar_if_is_overloaded_but_used_with_different_arg() thr
public void should_capture_arguments_from_invocation() throws Exception {
// given
Invocation invocation = new InvocationBuilder().args("1", 100).toInvocation();
CapturingMatcher capturingMatcher = new CapturingMatcher();
CapturingMatcher capturingMatcher = new CapturingMatcher(List.class);
InvocationMatcher invocationMatcher =
new InvocationMatcher(invocation, (List) asList(new Equals("1"), capturingMatcher));

Expand Down Expand Up @@ -167,7 +167,7 @@ public void should_capture_varargs_as_vararg() throws Exception {
// given
mock.mixedVarargs(1, "a", "b");
Invocation invocation = getLastInvocation();
CapturingMatcher m = new CapturingMatcher();
CapturingMatcher m = new CapturingMatcher(List.class);
InvocationMatcher invocationMatcher =
new InvocationMatcher(invocation, Arrays.<ArgumentMatcher>asList(new Equals(1), m));

Expand Down
Expand Up @@ -9,6 +9,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.internal.invocation.MatcherApplicationStrategy.getMatcherApplicationStrategyFor;
import static org.mockito.internal.matchers.Any.ANY;

Expand Down Expand Up @@ -225,12 +226,32 @@ public void shouldMatchAnyEvenIfMatcherIsWrappedInHamcrestMatcher() {
recordAction.assertContainsExactly(argumentMatcher, argumentMatcher);
}

@Test
public void shouldMatchAnyThatMatchesRawVarArgType() {
// given
invocation = varargs("1", "2");
InstanceOf.VarArgAware any = new InstanceOf.VarArgAware(String[].class, "<any String[]>");
matchers = asList(any);

// when
getMatcherApplicationStrategyFor(invocation, matchers)
.forEachMatcherAndArgument(recordAction);

// then
recordAction.assertContainsExactly(any);
}

private static class IntMatcher extends BaseMatcher<Integer> implements VarargMatcher {
public boolean matches(Object o) {
return true;
}

public void describeTo(Description description) {}

@Override
public Class<?> type() {
return Integer.class;
}
}

private Invocation mixedVarargs(Object a, String... s) {
Expand Down
Expand Up @@ -19,7 +19,7 @@ public class CapturingMatcherTest extends TestBase {
@Test
public void should_capture_arguments() throws Exception {
// given
CapturingMatcher<String> m = new CapturingMatcher<String>();
CapturingMatcher<String> m = new CapturingMatcher<String>(String.class);

// when
m.captureFrom("foo");
Expand All @@ -32,7 +32,7 @@ public void should_capture_arguments() throws Exception {
@Test
public void should_know_last_captured_value() throws Exception {
// given
CapturingMatcher<String> m = new CapturingMatcher<String>();
CapturingMatcher<String> m = new CapturingMatcher<String>(String.class);

// when
m.captureFrom("foo");
Expand All @@ -45,7 +45,7 @@ public void should_know_last_captured_value() throws Exception {
@Test
public void should_scream_when_nothing_yet_captured() throws Exception {
// given
CapturingMatcher<String> m = new CapturingMatcher<String>();
CapturingMatcher<String> m = new CapturingMatcher<String>(String.class);

try {
// when
Expand All @@ -59,7 +59,7 @@ public void should_scream_when_nothing_yet_captured() throws Exception {
@Test
public void should_not_fail_when_used_in_concurrent_tests() throws Exception {
// given
final CapturingMatcher<String> m = new CapturingMatcher<String>();
final CapturingMatcher<String> m = new CapturingMatcher<String>(String.class);

// when
m.captureFrom("concurrent access");
Expand Down
4 changes: 4 additions & 0 deletions src/test/java/org/mockitousage/IMethods.java
Expand Up @@ -199,6 +199,10 @@ String sixArgumentVarArgsMethod(

Object[] mixedVarargsReturningObjectArray(Object i, String... string);

String methodWithVarargAndNonVarargVariants(String string);

String methodWithVarargAndNonVarargVariants(String... string);

List<String> listReturningMethod(Object... objects);

LinkedList<String> linkedListReturningMethod();
Expand Down
10 changes: 10 additions & 0 deletions src/test/java/org/mockitousage/MethodsImpl.java
Expand Up @@ -376,6 +376,16 @@ public Object[] mixedVarargsReturningObjectArray(Object i, String... string) {
return null;
}

@Override
public String methodWithVarargAndNonVarargVariants(String string) {
return "plain";
}

@Override
public String methodWithVarargAndNonVarargVariants(String... string) {
return "varargs";
}

public void varargsbyte(byte... bytes) {}

public List<String> listReturningMethod(Object... objects) {
Expand Down