diff --git a/src/main/java/org/mockito/internal/stubbing/answers/AnswerFunctionalInterfaces.java b/src/main/java/org/mockito/internal/stubbing/answers/AnswerFunctionalInterfaces.java index e4692f26d9..d8f824f471 100644 --- a/src/main/java/org/mockito/internal/stubbing/answers/AnswerFunctionalInterfaces.java +++ b/src/main/java/org/mockito/internal/stubbing/answers/AnswerFunctionalInterfaces.java @@ -4,6 +4,7 @@ */ package org.mockito.internal.stubbing.answers; +import org.mockito.invocation.Invocation; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer1; @@ -19,6 +20,8 @@ import org.mockito.stubbing.VoidAnswer5; import org.mockito.stubbing.VoidAnswer6; +import java.lang.reflect.Method; + /** * Functional interfaces to make it easy to implement answers in Java 8 * @@ -38,11 +41,11 @@ private AnswerFunctionalInterfaces() {} * @return a new answer object */ public static Answer toAnswer(final Answer1 answer) { + final Method answerMethod = findAnswerMethod(answer.getClass(), 1); return new Answer() { @Override - @SuppressWarnings("unchecked") public T answer(InvocationOnMock invocation) throws Throwable { - return answer.answer((A) invocation.getArgument(0)); + return answer.answer(lastParameter(invocation, answerMethod, 0)); } }; } @@ -54,11 +57,11 @@ public T answer(InvocationOnMock invocation) throws Throwable { * @return a new answer object */ public static Answer toAnswer(final VoidAnswer1 answer) { + final Method answerMethod = findAnswerMethod(answer.getClass(), 1); return new Answer() { @Override - @SuppressWarnings("unchecked") public Void answer(InvocationOnMock invocation) throws Throwable { - answer.answer((A) invocation.getArgument(0)); + answer.answer(lastParameter(invocation, answerMethod, 0)); return null; } }; @@ -73,11 +76,13 @@ public Void answer(InvocationOnMock invocation) throws Throwable { * @return a new answer object */ public static Answer toAnswer(final Answer2 answer) { + final Method answerMethod = findAnswerMethod(answer.getClass(), 2); return new Answer() { @Override @SuppressWarnings("unchecked") public T answer(InvocationOnMock invocation) throws Throwable { - return answer.answer((A) invocation.getArgument(0), (B) invocation.getArgument(1)); + return answer.answer( + (A) invocation.getArgument(0), lastParameter(invocation, answerMethod, 1)); } }; } @@ -90,11 +95,13 @@ public T answer(InvocationOnMock invocation) throws Throwable { * @return a new answer object */ public static Answer toAnswer(final VoidAnswer2 answer) { + final Method answerMethod = findAnswerMethod(answer.getClass(), 2); return new Answer() { @Override @SuppressWarnings("unchecked") public Void answer(InvocationOnMock invocation) throws Throwable { - answer.answer((A) invocation.getArgument(0), (B) invocation.getArgument(1)); + answer.answer( + (A) invocation.getArgument(0), lastParameter(invocation, answerMethod, 1)); return null; } }; @@ -303,4 +310,35 @@ public Void answer(InvocationOnMock invocation) throws Throwable { } }; } + + private static Method findAnswerMethod(final Class type, final int numberOfParameters) { + for (final Method m : type.getDeclaredMethods()) { + if (!m.isBridge() + && m.getName().equals("answer") + && m.getParameterTypes().length == numberOfParameters) { + return m; + } + } + // Todo: throw a descent error message + throw new RuntimeException("todo"); + } + + @SuppressWarnings("unchecked") + private static A lastParameter( + InvocationOnMock invocationOnMock, Method answerMethod, int argumentIndex) { + final Method invocationMethod = invocationOnMock.getMethod(); + if (invocationMethod.isVarArgs() + && invocationMethod.getParameterTypes().length == (argumentIndex + 1)) { + final Class invocationRawArgType = + invocationMethod.getParameterTypes()[argumentIndex]; + final Class answerRawArgType = answerMethod.getParameterTypes()[argumentIndex]; + if (answerRawArgType.isAssignableFrom( + invocationRawArgType)) { // Todo: test super types. + Invocation invocation = (Invocation) invocationOnMock; + return (A) invocation.getRawArguments()[argumentIndex]; + } + } + + return (A) invocationOnMock.getArgument(argumentIndex); + } } diff --git a/src/main/java/org/mockito/internal/stubbing/answers/ReturnsArgumentAt.java b/src/main/java/org/mockito/internal/stubbing/answers/ReturnsArgumentAt.java index cd661152c3..885f79453a 100644 --- a/src/main/java/org/mockito/internal/stubbing/answers/ReturnsArgumentAt.java +++ b/src/main/java/org/mockito/internal/stubbing/answers/ReturnsArgumentAt.java @@ -51,16 +51,18 @@ public ReturnsArgumentAt(int wantedArgumentPosition) { } @Override - public Object answer(InvocationOnMock invocation) throws Throwable { - int argumentPosition = inferWantedArgumentPosition(invocation); - validateIndexWithinInvocationRange(invocation, argumentPosition); + public Object answer(InvocationOnMock invocationOnMock) throws Throwable { + Invocation invocation = (Invocation) invocationOnMock; - if (wantedArgIndexIsVarargAndSameTypeAsReturnType( - invocation.getMethod(), argumentPosition)) { + if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation)) { // answer raw vararg array argument - return ((Invocation) invocation).getRawArguments()[argumentPosition]; + int rawArgumentPosition = inferWantedRawArgumentPosition(invocation); + return invocation.getRawArguments()[rawArgumentPosition]; } + int argumentPosition = inferWantedArgumentPosition(invocation); + validateIndexWithinInvocationRange(invocation, argumentPosition); + // answer expanded argument at wanted position return invocation.getArgument(argumentPosition); } @@ -68,7 +70,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable { @Override public void validateFor(InvocationOnMock invocation) { int argumentPosition = inferWantedArgumentPosition(invocation); - validateIndexWithinInvocationRange(invocation, argumentPosition); + validateIndexWithinTheoreticalInvocationRange(invocation, argumentPosition); validateArgumentTypeCompatibility((Invocation) invocation, argumentPosition); } @@ -80,9 +82,26 @@ private int inferWantedArgumentPosition(InvocationOnMock invocation) { return wantedArgumentPosition; } + private int inferWantedRawArgumentPosition(Invocation invocation) { + if (wantedArgumentPosition == LAST_ARGUMENT) { + return invocation.getRawArguments().length - 1; + } + + return wantedArgumentPosition; + } + private void validateIndexWithinInvocationRange( InvocationOnMock invocation, int argumentPosition) { - if (!wantedArgumentPositionIsValidForInvocation(invocation, argumentPosition)) { + + if (argumentPosition < 0 || invocation.getArguments().length <= argumentPosition) { + throw invalidArgumentPositionRangeAtInvocationTime( + invocation, wantedArgumentPosition == LAST_ARGUMENT, wantedArgumentPosition); + } + } + + private void validateIndexWithinTheoreticalInvocationRange( + InvocationOnMock invocation, int argumentPosition) { + if (!wantedArgumentPositionIsValidForTheoreticalInvocation(invocation, argumentPosition)) { throw invalidArgumentPositionRangeAtInvocationTime( invocation, wantedArgumentPosition == LAST_ARGUMENT, wantedArgumentPosition); } @@ -102,15 +121,16 @@ private void validateArgumentTypeCompatibility(Invocation invocation, int argume } } - private boolean wantedArgIndexIsVarargAndSameTypeAsReturnType( - Method method, int argumentPosition) { + private boolean wantedArgIndexIsVarargAndSameTypeAsReturnType(Invocation invocation) { + int rawArgumentPosition = inferWantedRawArgumentPosition(invocation); + Method method = invocation.getMethod(); Class[] parameterTypes = method.getParameterTypes(); return method.isVarArgs() - && argumentPosition == /* vararg index */ parameterTypes.length - 1 - && method.getReturnType().isAssignableFrom(parameterTypes[argumentPosition]); + && rawArgumentPosition == /* vararg index */ parameterTypes.length - 1 + && method.getReturnType().isAssignableFrom(parameterTypes[rawArgumentPosition]); } - private boolean wantedArgumentPositionIsValidForInvocation( + private boolean wantedArgumentPositionIsValidForTheoreticalInvocation( InvocationOnMock invocation, int argumentPosition) { if (argumentPosition < 0) { return false; @@ -145,7 +165,7 @@ private Class inferArgumentType(Invocation invocation, int argumentIndex) { return parameterTypes[argumentIndex]; } // if wanted argument is vararg - if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation.getMethod(), argumentIndex)) { + if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation)) { // return the vararg array if return type is compatible // because the user probably want to return the array itself if the return type is // compatible diff --git a/src/main/java/org/mockito/invocation/InvocationOnMock.java b/src/main/java/org/mockito/invocation/InvocationOnMock.java index 269926ea6b..039308cc7f 100644 --- a/src/main/java/org/mockito/invocation/InvocationOnMock.java +++ b/src/main/java/org/mockito/invocation/InvocationOnMock.java @@ -56,6 +56,8 @@ public interface InvocationOnMock extends Serializable { */ T getArgument(int index); + // Todo: add getRawArguments to this interface? + /** * Returns casted argument at the given index. This method is analogous to * {@link #getArgument(int)}, but is necessary to circumvent issues when dealing with generics. diff --git a/src/test/java/org/mockitousage/stubbing/StubbingWithAdditionalAnswersTest.java b/src/test/java/org/mockitousage/stubbing/StubbingWithAdditionalAnswersTest.java index 7dfed447a9..4147e8e351 100644 --- a/src/test/java/org/mockitousage/stubbing/StubbingWithAdditionalAnswersTest.java +++ b/src/test/java/org/mockitousage/stubbing/StubbingWithAdditionalAnswersTest.java @@ -5,6 +5,7 @@ package org.mockitousage.stubbing; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.within; import static org.mockito.AdditionalAnswers.answer; import static org.mockito.AdditionalAnswers.answerVoid; @@ -27,6 +28,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.exceptions.base.MockitoException; import org.mockito.junit.MockitoJUnitRunner; import org.mockito.stubbing.Answer1; import org.mockito.stubbing.Answer2; @@ -52,10 +54,39 @@ public void can_return_arguments_of_invocation() throws Exception { given(iMethods.objectArgMethod(any())).will(returnsFirstArg()); given(iMethods.threeArgumentMethod(eq(0), any(), anyString())).will(returnsSecondArg()); given(iMethods.threeArgumentMethod(eq(1), any(), anyString())).will(returnsLastArg()); + given(iMethods.mixedVarargsReturningString(eq(1), any())).will(returnsArgAt(2)); assertThat(iMethods.objectArgMethod("first")).isEqualTo("first"); assertThat(iMethods.threeArgumentMethod(0, "second", "whatever")).isEqualTo("second"); assertThat(iMethods.threeArgumentMethod(1, "whatever", "last")).isEqualTo("last"); + assertThat(iMethods.mixedVarargsReturningString(1, "a", "b")).isEqualTo("b"); + } + + @Test + public void can_return_var_arguments_of_invocation() throws Exception { + given(iMethods.mixedVarargsReturningStringArray(eq(1), any())).will(returnsLastArg()); + given(iMethods.mixedVarargsReturningObjectArray(eq(1), any())).will(returnsArgAt(1)); + + assertThat(iMethods.mixedVarargsReturningStringArray(1, "the", "var", "args")) + .containsExactlyInAnyOrder("the", "var", "args"); + assertThat(iMethods.mixedVarargsReturningObjectArray(1, "the", "var", "args")) + .containsExactlyInAnyOrder("the", "var", "args"); + } + + @Test + public void returns_arg_at_throws_on_out_of_range_var_args() throws Exception { + given(iMethods.mixedVarargsReturningString(eq(1), any())).will(returnsArgAt(3)); + + assertThatThrownBy(() -> iMethods.mixedVarargsReturningString(1, "a", "b")) + .isInstanceOf(MockitoException.class) + .hasMessageContaining("Invalid argument index"); + + assertThatThrownBy( + () -> + given(iMethods.mixedVarargsReturningStringArray(eq(1), any())) + .will(returnsArgAt(3))) + .isInstanceOf(MockitoException.class) + .hasMessageContaining("The argument of type 'String' cannot be returned"); } @Test @@ -347,4 +378,78 @@ public void answer( // expect the answer to write correctly to "target" verify(target, times(1)).simpleMethod("hello", 1, 2, 3, 4, 5); } + + @Test + public void can_return_based_on_strongly_types_one_parameter_var_args_function() + throws Exception { + given(iMethods.varargs(any())) + .will( + answer( + new Answer1() { + public Integer answer(String[] strings) { + return strings.length; + } + })); + + assertThat(iMethods.varargs("some", "args")).isEqualTo(2); + } + + @Test + public void will_execute_a_void_based_on_strongly_typed_one_parameter_var_args_function() + throws Exception { + final IMethods target = mock(IMethods.class); + + given(iMethods.varargs(any())) + .will( + answerVoid( + new VoidAnswer1() { + public void answer(String[] s) { + target.varargs(s); + } + })); + + // invoke on iMethods + iMethods.varargs("some", "args"); + + // expect the answer to write correctly to "target" + verify(target, times(1)).varargs("some", "args"); + } + + @Test + public void can_return_based_on_strongly_typed_two_parameter_var_args_function() + throws Exception { + given(iMethods.mixedVarargsReturningString(any(), any())) + .will( + answer( + new Answer2() { + public String answer(Object o, String[] s) { + return String.join("-", s); + } + })); + + assertThat(iMethods.mixedVarargsReturningString(1, "var", "args")).isEqualTo("var-args"); + } + + @Test + public void will_execute_a_void_based_on_strongly_typed_two_parameter_var_args_function() + throws Exception { + final IMethods target = mock(IMethods.class); + + given(iMethods.mixedVarargsReturningString(any(), any())) + .will( + answerVoid( + new VoidAnswer2() { + public void answer(Object o, String[] s) { + target.mixedVarargsReturningString(o, s); + } + })); + + // invoke on iMethods + iMethods.mixedVarargsReturningString(1, "var", "args"); + + // expect the answer to write correctly to "target" + verify(target, times(1)).mixedVarargsReturningString(1, "var", "args"); + } + + // todo: tests for other versions of answer and voidAnswer. }