From 48a3014e4c3a364eec3efda0db0fedaf4e01c16e Mon Sep 17 00:00:00 2001 From: Big Andy <8012398+big-andy-coates@users.noreply.github.com> Date: Mon, 28 Nov 2022 17:51:07 +0000 Subject: [PATCH] Improve vararg handling Fixes: https://github.com/mockito/mockito/issues/2796 Add an optional method to `VarargMatcher`, which implementations can choose to override to return the type of object the matcher is matching. This is used by `MatcherApplicationStrategy` to determine if the type of matcher used to match a vararg parameter is of a type compatible with the vararg parameter. Where a vararg compatible matcher is found, the matcher is used to match the _raw_ parameters. --- src/main/java/org/mockito/ArgumentCaptor.java | 3 +- .../hamcrest/HamcrestArgumentMatcher.java | 6 + .../MatcherApplicationStrategy.java | 34 +++- .../org/mockito/internal/matchers/Any.java | 5 + .../internal/matchers/CapturingMatcher.java | 10 ++ .../mockito/internal/matchers/InstanceOf.java | 7 +- .../internal/matchers/VarargMatcher.java | 42 ++++- .../invocation/InvocationMatcherTest.java | 4 +- .../MatcherApplicationStrategyTest.java | 21 +++ .../matchers/CapturingMatcherTest.java | 8 +- src/test/java/org/mockitousage/IMethods.java | 4 + .../java/org/mockitousage/MethodsImpl.java | 10 ++ .../mockitousage/matchers/VarargsTest.java | 148 ++++++++++++++++-- 13 files changed, 274 insertions(+), 28 deletions(-) diff --git a/src/main/java/org/mockito/ArgumentCaptor.java b/src/main/java/org/mockito/ArgumentCaptor.java index afb3add2ba..ed9e398f91 100644 --- a/src/main/java/org/mockito/ArgumentCaptor.java +++ b/src/main/java/org/mockito/ArgumentCaptor.java @@ -62,11 +62,12 @@ @CheckReturnValue public class ArgumentCaptor { - private final CapturingMatcher capturingMatcher = new CapturingMatcher(); + private final CapturingMatcher capturingMatcher; private final Class clazz; private ArgumentCaptor(Class clazz) { this.clazz = clazz; + this.capturingMatcher = new CapturingMatcher(clazz); } /** diff --git a/src/main/java/org/mockito/internal/hamcrest/HamcrestArgumentMatcher.java b/src/main/java/org/mockito/internal/hamcrest/HamcrestArgumentMatcher.java index 99869faf73..a39b9159ea 100644 --- a/src/main/java/org/mockito/internal/hamcrest/HamcrestArgumentMatcher.java +++ b/src/main/java/org/mockito/internal/hamcrest/HamcrestArgumentMatcher.java @@ -9,6 +9,8 @@ import org.mockito.ArgumentMatcher; import org.mockito.internal.matchers.VarargMatcher; +import java.util.Optional; + public class HamcrestArgumentMatcher implements ArgumentMatcher { private final Matcher matcher; @@ -26,6 +28,10 @@ public boolean isVarargMatcher() { return matcher instanceof VarargMatcher; } + public Optional varargMatcher() { + return isVarargMatcher() ? Optional.of((VarargMatcher) matcher) : Optional.empty(); + } + @Override public String toString() { // TODO SF add unit tests and integ test coverage for toString() diff --git a/src/main/java/org/mockito/internal/invocation/MatcherApplicationStrategy.java b/src/main/java/org/mockito/internal/invocation/MatcherApplicationStrategy.java index 5e16f27040..796b7d88f7 100644 --- a/src/main/java/org/mockito/internal/invocation/MatcherApplicationStrategy.java +++ b/src/main/java/org/mockito/internal/invocation/MatcherApplicationStrategy.java @@ -4,8 +4,10 @@ */ package org.mockito.internal.invocation; +import java.lang.reflect.Type; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import org.mockito.ArgumentMatcher; import org.mockito.internal.hamcrest.HamcrestArgumentMatcher; @@ -58,14 +60,25 @@ public static MatcherApplicationStrategy getMatcherApplicationStrategyFor( * */ public boolean forEachMatcherAndArgument(ArgumentMatcherAction action) { - if (invocation.getArguments().length == matchers.size()) { - return argsMatch(invocation.getArguments(), matchers, action); - } - final boolean isVararg = invocation.getMethod().isVarArgs() && invocation.getRawArguments().length == matchers.size() - && isLastMatcherVarargMatcher(matchers); + && getLastMatcherVarargMatcherType(matchers).isPresent(); + + if (isVararg) { + final Class matcherType = getLastMatcherVarargMatcherType(matchers).get(); + final Class paramType = + invocation.getMethod() + .getParameterTypes()[ + invocation.getMethod().getParameterTypes().length - 1]; + if (paramType.isAssignableFrom(matcherType)) { + return argsMatch(invocation.getRawArguments(), matchers, action); + } + } + + if (invocation.getArguments().length == matchers.size()) { + return argsMatch(invocation.getArguments(), matchers, action); + } if (isVararg) { int times = varargLength(invocation); @@ -91,12 +104,17 @@ private boolean argsMatch( return true; } - private static boolean isLastMatcherVarargMatcher(List> matchers) { + private static Optional> getLastMatcherVarargMatcherType( + final List> matchers) { ArgumentMatcher argumentMatcher = lastMatcher(matchers); if (argumentMatcher instanceof HamcrestArgumentMatcher) { - return ((HamcrestArgumentMatcher) argumentMatcher).isVarargMatcher(); + return ((HamcrestArgumentMatcher) argumentMatcher) + .varargMatcher() + .map(VarargMatcher::type); } - return argumentMatcher instanceof VarargMatcher; + return argumentMatcher instanceof VarargMatcher + ? Optional.of((VarargMatcher) argumentMatcher).map(VarargMatcher::type) + : Optional.empty(); } private List> appendLastMatcherNTimes( diff --git a/src/main/java/org/mockito/internal/matchers/Any.java b/src/main/java/org/mockito/internal/matchers/Any.java index 7ad113feed..1a71a7e2bc 100644 --- a/src/main/java/org/mockito/internal/matchers/Any.java +++ b/src/main/java/org/mockito/internal/matchers/Any.java @@ -21,4 +21,9 @@ public boolean matches(Object actual) { public String toString() { return ""; } + + @Override + public Class type() { + return Object.class; + } } diff --git a/src/main/java/org/mockito/internal/matchers/CapturingMatcher.java b/src/main/java/org/mockito/internal/matchers/CapturingMatcher.java index 5138839ddb..3d7c022883 100644 --- a/src/main/java/org/mockito/internal/matchers/CapturingMatcher.java +++ b/src/main/java/org/mockito/internal/matchers/CapturingMatcher.java @@ -19,12 +19,17 @@ public class CapturingMatcher implements ArgumentMatcher, CapturesArguments, VarargMatcher, Serializable { + private final Class clazz; private final List arguments = new ArrayList<>(); private final ReadWriteLock lock = new ReentrantReadWriteLock(); private final Lock readLock = lock.readLock(); private final Lock writeLock = lock.writeLock(); + public CapturingMatcher(final Class clazz) { + this.clazz = clazz; + } + @Override public boolean matches(Object argument) { return true; @@ -66,4 +71,9 @@ public void captureFrom(Object argument) { writeLock.unlock(); } } + + @Override + public Class type() { + return clazz; + } } diff --git a/src/main/java/org/mockito/internal/matchers/InstanceOf.java b/src/main/java/org/mockito/internal/matchers/InstanceOf.java index 1b4b30e875..88558b64f2 100644 --- a/src/main/java/org/mockito/internal/matchers/InstanceOf.java +++ b/src/main/java/org/mockito/internal/matchers/InstanceOf.java @@ -11,7 +11,7 @@ public class InstanceOf implements ArgumentMatcher, Serializable { - private final Class clazz; + final Class clazz; private final String description; public InstanceOf(Class clazz) { @@ -44,5 +44,10 @@ public VarArgAware(Class clazz) { public VarArgAware(Class clazz, String describedAs) { super(clazz, describedAs); } + + @Override + public Class type() { + return clazz; + } } } diff --git a/src/main/java/org/mockito/internal/matchers/VarargMatcher.java b/src/main/java/org/mockito/internal/matchers/VarargMatcher.java index 43a27596f8..d484a09c3d 100644 --- a/src/main/java/org/mockito/internal/matchers/VarargMatcher.java +++ b/src/main/java/org/mockito/internal/matchers/VarargMatcher.java @@ -10,4 +10,44 @@ * Internal interface that informs Mockito that the matcher is intended to capture varargs. * This information is needed when mockito collects the arguments. */ -public interface VarargMatcher extends Serializable {} +public interface VarargMatcher extends Serializable { + + /** + * The type of the argument the matcher matches. + * + *

If a vararg aware matcher: + *

    + *
  • is at the parameter index of a vararg parameter
  • + *
  • is the last matcher passed
  • + *
  • matchers the raw type of the vararg parameter
  • + *
+ * + * Then the matcher is matched against the vararg raw parameter. + * Otherwise, the matcher will be matched against each element in the vararg raw parameters. + * + *

For example: + * + *


+     *  // Given vararg method with signature:
+     *  int someVarargMethod(int x, String... args);
+     *
+     *  // The following will match the last matcher against the contents of the `args` array:
+     *  (as the above criteria are met)
+     *  mock.someVarargMethod(eq(1), any(String[].class));
+     *
+     *  // The following will match the last matcher against each element of the `args` array:
+     *  // (as the type of the last matcher does not match the raw type of the vararg parameter)
+     *  mock.someVarargMethod(eq(1), any(String.class));
+     *
+     *  // The following will match only invocations with two strings in the 'args' array:
+     *  // (as there are more matchers than raw arguments)
+     *  mock.someVarargMethod(eq(1), any(), any());
+     * 
+ * + * @return the type this matcher handles. + * @since 4.10.0 + */ + default Class type() { + return Void.class; + } +} diff --git a/src/test/java/org/mockito/internal/invocation/InvocationMatcherTest.java b/src/test/java/org/mockito/internal/invocation/InvocationMatcherTest.java index 0b3d09d5f1..ab53be545f 100644 --- a/src/test/java/org/mockito/internal/invocation/InvocationMatcherTest.java +++ b/src/test/java/org/mockito/internal/invocation/InvocationMatcherTest.java @@ -136,7 +136,7 @@ public void should_be_similar_if_is_overloaded_but_used_with_different_arg() thr public void should_capture_arguments_from_invocation() throws Exception { // given Invocation invocation = new InvocationBuilder().args("1", 100).toInvocation(); - CapturingMatcher capturingMatcher = new CapturingMatcher(); + CapturingMatcher capturingMatcher = new CapturingMatcher(List.class); InvocationMatcher invocationMatcher = new InvocationMatcher(invocation, (List) asList(new Equals("1"), capturingMatcher)); @@ -167,7 +167,7 @@ public void should_capture_varargs_as_vararg() throws Exception { // given mock.mixedVarargs(1, "a", "b"); Invocation invocation = getLastInvocation(); - CapturingMatcher m = new CapturingMatcher(); + CapturingMatcher m = new CapturingMatcher(List.class); InvocationMatcher invocationMatcher = new InvocationMatcher(invocation, Arrays.asList(new Equals(1), m)); diff --git a/src/test/java/org/mockito/internal/invocation/MatcherApplicationStrategyTest.java b/src/test/java/org/mockito/internal/invocation/MatcherApplicationStrategyTest.java index baf7a5ad5a..432dc7ee1d 100644 --- a/src/test/java/org/mockito/internal/invocation/MatcherApplicationStrategyTest.java +++ b/src/test/java/org/mockito/internal/invocation/MatcherApplicationStrategyTest.java @@ -9,6 +9,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.internal.invocation.MatcherApplicationStrategy.getMatcherApplicationStrategyFor; import static org.mockito.internal.matchers.Any.ANY; @@ -225,12 +226,32 @@ public void shouldMatchAnyEvenIfMatcherIsWrappedInHamcrestMatcher() { recordAction.assertContainsExactly(argumentMatcher, argumentMatcher); } + @Test + public void shouldMatchAnyThatMatchesRawVarArgType() { + // given + invocation = varargs("1", "2"); + InstanceOf.VarArgAware any = new InstanceOf.VarArgAware(String[].class, ""); + matchers = asList(any); + + // when + getMatcherApplicationStrategyFor(invocation, matchers) + .forEachMatcherAndArgument(recordAction); + + // then + recordAction.assertContainsExactly(any); + } + private static class IntMatcher extends BaseMatcher implements VarargMatcher { public boolean matches(Object o) { return true; } public void describeTo(Description description) {} + + @Override + public Class type() { + return Integer.class; + } } private Invocation mixedVarargs(Object a, String... s) { diff --git a/src/test/java/org/mockito/internal/matchers/CapturingMatcherTest.java b/src/test/java/org/mockito/internal/matchers/CapturingMatcherTest.java index a4c6b59149..768d08d0b8 100644 --- a/src/test/java/org/mockito/internal/matchers/CapturingMatcherTest.java +++ b/src/test/java/org/mockito/internal/matchers/CapturingMatcherTest.java @@ -19,7 +19,7 @@ public class CapturingMatcherTest extends TestBase { @Test public void should_capture_arguments() throws Exception { // given - CapturingMatcher m = new CapturingMatcher(); + CapturingMatcher m = new CapturingMatcher(String.class); // when m.captureFrom("foo"); @@ -32,7 +32,7 @@ public void should_capture_arguments() throws Exception { @Test public void should_know_last_captured_value() throws Exception { // given - CapturingMatcher m = new CapturingMatcher(); + CapturingMatcher m = new CapturingMatcher(String.class); // when m.captureFrom("foo"); @@ -45,7 +45,7 @@ public void should_know_last_captured_value() throws Exception { @Test public void should_scream_when_nothing_yet_captured() throws Exception { // given - CapturingMatcher m = new CapturingMatcher(); + CapturingMatcher m = new CapturingMatcher(String.class); try { // when @@ -59,7 +59,7 @@ public void should_scream_when_nothing_yet_captured() throws Exception { @Test public void should_not_fail_when_used_in_concurrent_tests() throws Exception { // given - final CapturingMatcher m = new CapturingMatcher(); + final CapturingMatcher m = new CapturingMatcher(String.class); // when m.captureFrom("concurrent access"); diff --git a/src/test/java/org/mockitousage/IMethods.java b/src/test/java/org/mockitousage/IMethods.java index 06492f09cb..1b14af4714 100644 --- a/src/test/java/org/mockitousage/IMethods.java +++ b/src/test/java/org/mockitousage/IMethods.java @@ -199,6 +199,10 @@ String sixArgumentVarArgsMethod( Object[] mixedVarargsReturningObjectArray(Object i, String... string); + String methodWithVarargAndNonVarargVariants(String string); + + String methodWithVarargAndNonVarargVariants(String... string); + List listReturningMethod(Object... objects); LinkedList linkedListReturningMethod(); diff --git a/src/test/java/org/mockitousage/MethodsImpl.java b/src/test/java/org/mockitousage/MethodsImpl.java index e2d53ba7ba..f98a4f141e 100644 --- a/src/test/java/org/mockitousage/MethodsImpl.java +++ b/src/test/java/org/mockitousage/MethodsImpl.java @@ -376,6 +376,16 @@ public Object[] mixedVarargsReturningObjectArray(Object i, String... string) { return null; } + @Override + public String methodWithVarargAndNonVarargVariants(String string) { + return "plain"; + } + + @Override + public String methodWithVarargAndNonVarargVariants(String... string) { + return "varargs"; + } + public void varargsbyte(byte... bytes) {} public List listReturningMethod(Object... objects) { diff --git a/src/test/java/org/mockitousage/matchers/VarargsTest.java b/src/test/java/org/mockitousage/matchers/VarargsTest.java index dfb726942b..18e56af3ee 100644 --- a/src/test/java/org/mockitousage/matchers/VarargsTest.java +++ b/src/test/java/org/mockitousage/matchers/VarargsTest.java @@ -4,12 +4,14 @@ */ package org.mockitousage.matchers; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isNotNull; import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -198,7 +200,7 @@ public void shouldCaptureVarArgs_noArgs() { verify(mock).varargs(captor.capture()); - assertThat(captor).isEmpty(); + assertThatCaptor(captor).isEmpty(); } @Test @@ -208,7 +210,7 @@ public void shouldCaptureVarArgs_oneNullArg_eqNull() { verify(mock).varargs(captor.capture()); - assertThat(captor).areExactly(1, NULL); + assertThatCaptor(captor).areExactly(1, NULL); } /** @@ -221,7 +223,7 @@ public void shouldCaptureVarArgs_nullArrayArg() { mock.varargs(argArray); verify(mock).varargs(captor.capture()); - assertThat(captor).areExactly(1, NULL); + assertThatCaptor(captor).areExactly(1, NULL); } @Test @@ -230,7 +232,7 @@ public void shouldCaptureVarArgs_twoArgsOneCapture() { verify(mock).varargs(captor.capture()); - assertThat(captor).contains("1", "2"); + assertThatCaptor(captor).contains("1", "2"); } @Test @@ -239,7 +241,7 @@ public void shouldCaptureVarArgs_twoArgsTwoCaptures() { verify(mock).varargs(captor.capture(), captor.capture()); - assertThat(captor).contains("1", "2"); + assertThatCaptor(captor).contains("1", "2"); } @Test @@ -248,7 +250,7 @@ public void shouldCaptureVarArgs_oneNullArgument() { verify(mock).varargs(captor.capture()); - assertThat(captor).contains("1", (String) null); + assertThatCaptor(captor).contains("1", (String) null); } @Test @@ -257,7 +259,7 @@ public void shouldCaptureVarArgs_oneNullArgument2() { verify(mock).varargs(captor.capture(), captor.capture()); - assertThat(captor).contains("1", (String) null); + assertThatCaptor(captor).contains("1", (String) null); } @Test @@ -277,7 +279,7 @@ public void shouldCaptureVarArgs_3argsCaptorMatcherMix() { verify(mock).varargs(captor.capture(), eq("2"), captor.capture()); - assertThat(captor).containsExactly("1", "3"); + assertThatCaptor(captor).containsExactly("1", "3"); } @Test @@ -290,7 +292,7 @@ public void shouldNotCaptureVarArgs_3argsCaptorMatcherMix() { } catch (ArgumentsAreDifferent expected) { } - assertThat(captor).isEmpty(); + assertThatCaptor(captor).isEmpty(); } @Test @@ -321,7 +323,7 @@ public void shouldCaptureVarArgsAsArray() { verify(mock).varargs(varargCaptor.capture()); - assertThat(varargCaptor).containsExactly(new String[] {"1", "2"}); + assertThatCaptor(varargCaptor).containsExactly(new String[] {"1", "2"}); } @Test @@ -342,7 +344,131 @@ public void shouldNotMatchVaraArgs() { Assertions.assertThat(mock.varargsObject(1)).isNull(); } - private static AbstractListAssert> assertThat( + @Test + public void shouldDifferentiateNonVarargVariant() { + given(mock.methodWithVarargAndNonVarargVariants(any(String.class))) + .willReturn("single arg method"); + + assertThat(mock.methodWithVarargAndNonVarargVariants("a")).isEqualTo("single arg method"); + assertThat(mock.methodWithVarargAndNonVarargVariants(new String[] {"a"})).isNull(); + assertThat(mock.methodWithVarargAndNonVarargVariants("a", "b")).isNull(); + } + + @Test + public void shouldMockVarargsInvocation_single_vararg_matcher() { + given(mock.methodWithVarargAndNonVarargVariants(any(String[].class))) + .willReturn("var arg method"); + + assertThat(mock.methodWithVarargAndNonVarargVariants("a")).isNull(); + assertThat(mock.methodWithVarargAndNonVarargVariants(new String[] {"a"})) + .isEqualTo("var arg method"); + assertThat(mock.methodWithVarargAndNonVarargVariants("a", "b")).isEqualTo("var arg method"); + } + + @Test + public void shouldMockVarargsInvocation_multiple_vararg_matcher() { + given(mock.methodWithVarargAndNonVarargVariants(any(String.class), any(String.class))) + .willReturn("var arg method"); + + assertThat(mock.methodWithVarargAndNonVarargVariants("a")).isNull(); + assertThat(mock.methodWithVarargAndNonVarargVariants(new String[] {"a"})).isNull(); + assertThat(mock.methodWithVarargAndNonVarargVariants("a", "b")).isEqualTo("var arg method"); + assertThat(mock.methodWithVarargAndNonVarargVariants(new String[] {"a", "b"})) + .isEqualTo("var arg method"); + assertThat(mock.methodWithVarargAndNonVarargVariants("a", "b", "c")).isNull(); + } + + @Test + public void shouldMockVarargsInvocationUsingCasts() { + given(mock.methodWithVarargAndNonVarargVariants((String) any())) + .willReturn("single arg method"); + given(mock.methodWithVarargAndNonVarargVariants((String[]) any())) + .willReturn("var arg method"); + + assertThat(mock.methodWithVarargAndNonVarargVariants("a")).isEqualTo("single arg method"); + assertThat(mock.methodWithVarargAndNonVarargVariants()).isEqualTo("var arg method"); + assertThat(mock.methodWithVarargAndNonVarargVariants(new String[] {"a"})) + .isEqualTo("var arg method"); + assertThat(mock.methodWithVarargAndNonVarargVariants("a", "b")).isEqualTo("var arg method"); + } + + @Test + public void shouldMockVarargsInvocationForSuperType() { + given(mock.varargsReturningString(any(Object[].class))).willReturn("a"); + + assertThat(mock.varargsReturningString("a", "b")).isEqualTo("a"); + } + + @Test + public void shouldHandleArrayVarargsMethods() { + given(mock.arrayVarargsMethod(any(String[][].class))).willReturn(1); + + assertThat(mock.arrayVarargsMethod(new String[]{})).isEqualTo(1); + } + + @Test + public void shouldCaptureVarArgs_NullArrayArg1() { + mock.varargs((String[]) null); + ArgumentCaptor captor = ArgumentCaptor.forClass(String[].class); + + verify(mock).varargs(captor.capture()); + + assertThat(captor.getValue()).isNull(); + } + + @Test + public void shouldCaptureVarArgs_NullArrayArg2() { + mock.varargs((String[]) null); + ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); + + verify(mock).varargs(captor.capture()); + + assertThat(captor.getValue()).isNull(); + } + + @Test + public void shouldVerifyVarArgs_any_NullArrayArg1() { + mock.varargs((String[]) null); + + verify(mock).varargs(ArgumentMatchers.any()); + } + + @Test + public void shouldVerifyVarArgs_any_NullArrayArg2() { + mock.varargs((String) null); + + verify(mock).varargs(ArgumentMatchers.any()); + } + + @Test + public void shouldVerifyVarArgs_eq_NullArrayArg1() { + mock.varargs((String[]) null); + + verify(mock).varargs(ArgumentMatchers.eq(null)); + } + + @Test + public void shouldVerifyVarArgs_eq_NullArrayArg2() { + mock.varargs((String) null); + + verify(mock).varargs(ArgumentMatchers.eq(null)); + } + + @Test + public void shouldVerifyVarArgs_isNull_NullArrayArg1() { + mock.varargs((String[]) null); + + verify(mock).varargs(ArgumentMatchers.isNull()); + } + + @Test + public void shouldVerifyVarArgs_isNull_NullArrayArg2() { + mock.varargs((String) null); + + verify(mock).varargs(ArgumentMatchers.isNull()); + } + + private static AbstractListAssert> assertThatCaptor( ArgumentCaptor captor) { return Assertions.assertThat(captor.getAllValues()); }