Skip to content

Commit

Permalink
Improve Varargs handling in AdditionalAnswers
Browse files Browse the repository at this point in the history
Fixes: mockito#2644

Fixes issues around vararg handling for the following methods in `AdditionalAnswers`:
 * `returnsFirstArg`
 * `returnsSecondArg`
 * `returnsLastArg`
 * `returnsArgAt`
 * `answer`
 * `answerVoid`

These methods were not correctly handling varargs. For example,

```java
doAnswer(answerVoid(
      (VoidAnswer2<String, Object[]>) logger::info
   )).when(mock)
      .info(any(), (Object[]) any());

mock.info("Some message with {} {} {}", "three", "parameters", "");
```

Would previously have resulted in a `ClassCastException` being thrown from the `mock.info` call.  This was because the `answerVoid` method was not taking into account that the second parameter was a varargs parameter and was attempting to pass the second actual argument `"three"`, rather than the second _raw_ argument `["three", "parameters", ""]`.
  • Loading branch information
big-andy-coates committed Jun 6, 2022
1 parent 1fbef57 commit 459efad
Show file tree
Hide file tree
Showing 5 changed files with 535 additions and 30 deletions.
Expand Up @@ -4,6 +4,8 @@
*/
package org.mockito.internal.stubbing.answers;

import org.mockito.exceptions.base.MockitoException;
import org.mockito.invocation.Invocation;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.mockito.stubbing.Answer1;
Expand All @@ -19,6 +21,10 @@
import org.mockito.stubbing.VoidAnswer5;
import org.mockito.stubbing.VoidAnswer6;

import java.lang.reflect.Method;

import static org.mockito.internal.util.StringUtil.join;

/**
* Functional interfaces to make it easy to implement answers in Java 8
*
Expand All @@ -38,11 +44,11 @@ private AnswerFunctionalInterfaces() {}
* @return a new answer object
*/
public static <T, A> Answer<T> toAnswer(final Answer1<T, A> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 1);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
public T answer(InvocationOnMock invocation) throws Throwable {
return answer.answer((A) invocation.getArgument(0));
return answer.answer(lastParameter(invocation, answerMethod, 0));
}
};
}
Expand All @@ -54,11 +60,11 @@ public T answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <A> Answer<Void> toAnswer(final VoidAnswer1<A> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 1);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
public Void answer(InvocationOnMock invocation) throws Throwable {
answer.answer((A) invocation.getArgument(0));
answer.answer(lastParameter(invocation, answerMethod, 0));
return null;
}
};
Expand All @@ -73,11 +79,13 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <T, A, B> Answer<T> toAnswer(final Answer2<T, A, B> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 2);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
public T answer(InvocationOnMock invocation) throws Throwable {
return answer.answer((A) invocation.getArgument(0), (B) invocation.getArgument(1));
return answer.answer(
(A) invocation.getArgument(0), lastParameter(invocation, answerMethod, 1));
}
};
}
Expand All @@ -90,11 +98,13 @@ public T answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <A, B> Answer<Void> toAnswer(final VoidAnswer2<A, B> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 2);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
public Void answer(InvocationOnMock invocation) throws Throwable {
answer.answer((A) invocation.getArgument(0), (B) invocation.getArgument(1));
answer.answer(
(A) invocation.getArgument(0), lastParameter(invocation, answerMethod, 1));
return null;
}
};
Expand All @@ -110,14 +120,15 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <T, A, B, C> Answer<T> toAnswer(final Answer3<T, A, B, C> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 3);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
public T answer(InvocationOnMock invocation) throws Throwable {
return answer.answer(
(A) invocation.getArgument(0),
(B) invocation.getArgument(1),
(C) invocation.getArgument(2));
lastParameter(invocation, answerMethod, 2));
}
};
}
Expand All @@ -131,14 +142,15 @@ public T answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <A, B, C> Answer<Void> toAnswer(final VoidAnswer3<A, B, C> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 3);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
public Void answer(InvocationOnMock invocation) throws Throwable {
answer.answer(
(A) invocation.getArgument(0),
(B) invocation.getArgument(1),
(C) invocation.getArgument(2));
lastParameter(invocation, answerMethod, 2));
return null;
}
};
Expand All @@ -155,6 +167,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <T, A, B, C, D> Answer<T> toAnswer(final Answer4<T, A, B, C, D> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 4);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
Expand All @@ -163,7 +176,7 @@ public T answer(InvocationOnMock invocation) throws Throwable {
(A) invocation.getArgument(0),
(B) invocation.getArgument(1),
(C) invocation.getArgument(2),
(D) invocation.getArgument(3));
lastParameter(invocation, answerMethod, 3));
}
};
}
Expand All @@ -178,6 +191,7 @@ public T answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <A, B, C, D> Answer<Void> toAnswer(final VoidAnswer4<A, B, C, D> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 4);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
Expand All @@ -186,7 +200,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
(A) invocation.getArgument(0),
(B) invocation.getArgument(1),
(C) invocation.getArgument(2),
(D) invocation.getArgument(3));
lastParameter(invocation, answerMethod, 3));
return null;
}
};
Expand All @@ -204,6 +218,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <T, A, B, C, D, E> Answer<T> toAnswer(final Answer5<T, A, B, C, D, E> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 5);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
Expand All @@ -213,7 +228,7 @@ public T answer(InvocationOnMock invocation) throws Throwable {
(B) invocation.getArgument(1),
(C) invocation.getArgument(2),
(D) invocation.getArgument(3),
(E) invocation.getArgument(4));
lastParameter(invocation, answerMethod, 4));
}
};
}
Expand All @@ -229,6 +244,7 @@ public T answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <A, B, C, D, E> Answer<Void> toAnswer(final VoidAnswer5<A, B, C, D, E> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 5);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
Expand All @@ -238,7 +254,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
(B) invocation.getArgument(1),
(C) invocation.getArgument(2),
(D) invocation.getArgument(3),
(E) invocation.getArgument(4));
lastParameter(invocation, answerMethod, 4));
return null;
}
};
Expand All @@ -259,6 +275,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
*/
public static <T, A, B, C, D, E, F> Answer<T> toAnswer(
final Answer6<T, A, B, C, D, E, F> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 6);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
Expand All @@ -269,7 +286,7 @@ public T answer(InvocationOnMock invocation) throws Throwable {
(C) invocation.getArgument(2),
(D) invocation.getArgument(3),
(E) invocation.getArgument(4),
(F) invocation.getArgument(5));
lastParameter(invocation, answerMethod, 5));
}
};
}
Expand All @@ -288,6 +305,7 @@ public T answer(InvocationOnMock invocation) throws Throwable {
*/
public static <A, B, C, D, E, F> Answer<Void> toAnswer(
final VoidAnswer6<A, B, C, D, E, F> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 6);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
Expand All @@ -298,9 +316,43 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
(C) invocation.getArgument(2),
(D) invocation.getArgument(3),
(E) invocation.getArgument(4),
(F) invocation.getArgument(5));
lastParameter(invocation, answerMethod, 5));
return null;
}
};
}

private static Method findAnswerMethod(final Class<?> type, final int numberOfParameters) {
for (final Method m : type.getDeclaredMethods()) {
if (!m.isBridge()
&& m.getName().equals("answer")
&& m.getParameterTypes().length == numberOfParameters) {
return m;
}
}
throw new IllegalStateException(
"Failed to find answer() method on the supplied class: "
+ type.getName()
+ ", with the supplied number of parameters: "
+ numberOfParameters);
}

@SuppressWarnings("unchecked")
private static <A> A lastParameter(
InvocationOnMock invocationOnMock, Method answerMethod, int argumentIndex) {
final Method invocationMethod = invocationOnMock.getMethod();

if (invocationMethod.isVarArgs()
&& invocationMethod.getParameterTypes().length == (argumentIndex + 1)) {
final Class<?> invocationRawArgType =
invocationMethod.getParameterTypes()[argumentIndex];
final Class<?> answerRawArgType = answerMethod.getParameterTypes()[argumentIndex];
if (answerRawArgType.isAssignableFrom(invocationRawArgType)) {
Invocation invocation = (Invocation) invocationOnMock;
return (A) invocation.getRawArguments()[argumentIndex];
}
}

return invocationOnMock.getArgument(argumentIndex);
}
}
Expand Up @@ -51,25 +51,27 @@ public ReturnsArgumentAt(int wantedArgumentPosition) {
}

@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
int argumentPosition = inferWantedArgumentPosition(invocation);
validateIndexWithinInvocationRange(invocation, argumentPosition);
public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
Invocation invocation = (Invocation) invocationOnMock;

if (wantedArgIndexIsVarargAndSameTypeAsReturnType(
invocation.getMethod(), argumentPosition)) {
if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation)) {
// answer raw vararg array argument
return ((Invocation) invocation).getRawArguments()[argumentPosition];
return invocation.getRawArguments()[invocation.getRawArguments().length - 1];
}

int argumentPosition = inferWantedArgumentPosition(invocation);
validateIndexWithinInvocationRange(invocation, argumentPosition);

// answer expanded argument at wanted position
return invocation.getArgument(argumentPosition);
}

@Override
public void validateFor(InvocationOnMock invocation) {
public void validateFor(InvocationOnMock invocationOnMock) {
Invocation invocation = (Invocation) invocationOnMock;
int argumentPosition = inferWantedArgumentPosition(invocation);
validateIndexWithinInvocationRange(invocation, argumentPosition);
validateArgumentTypeCompatibility((Invocation) invocation, argumentPosition);
validateIndexWithinTheoreticalInvocationRange(invocation, argumentPosition);
validateArgumentTypeCompatibility(invocation, argumentPosition);
}

private int inferWantedArgumentPosition(InvocationOnMock invocation) {
Expand All @@ -80,9 +82,26 @@ private int inferWantedArgumentPosition(InvocationOnMock invocation) {
return wantedArgumentPosition;
}

private int inferWantedRawArgumentPosition(Invocation invocation) {
if (wantedArgumentPosition == LAST_ARGUMENT) {
return invocation.getRawArguments().length - 1;
}

return wantedArgumentPosition;
}

private void validateIndexWithinInvocationRange(
InvocationOnMock invocation, int argumentPosition) {
if (!wantedArgumentPositionIsValidForInvocation(invocation, argumentPosition)) {

if (invocation.getArguments().length <= argumentPosition) {
throw invalidArgumentPositionRangeAtInvocationTime(
invocation, wantedArgumentPosition == LAST_ARGUMENT, wantedArgumentPosition);
}
}

private void validateIndexWithinTheoreticalInvocationRange(
InvocationOnMock invocation, int argumentPosition) {
if (!wantedArgumentPositionIsValidForTheoreticalInvocation(invocation, argumentPosition)) {
throw invalidArgumentPositionRangeAtInvocationTime(
invocation, wantedArgumentPosition == LAST_ARGUMENT, wantedArgumentPosition);
}
Expand All @@ -102,15 +121,16 @@ private void validateArgumentTypeCompatibility(Invocation invocation, int argume
}
}

private boolean wantedArgIndexIsVarargAndSameTypeAsReturnType(
Method method, int argumentPosition) {
private boolean wantedArgIndexIsVarargAndSameTypeAsReturnType(Invocation invocation) {
int rawArgumentPosition = inferWantedRawArgumentPosition(invocation);
Method method = invocation.getMethod();
Class<?>[] parameterTypes = method.getParameterTypes();
return method.isVarArgs()
&& argumentPosition == /* vararg index */ parameterTypes.length - 1
&& method.getReturnType().isAssignableFrom(parameterTypes[argumentPosition]);
&& rawArgumentPosition == /* vararg index */ parameterTypes.length - 1
&& method.getReturnType().isAssignableFrom(parameterTypes[rawArgumentPosition]);
}

private boolean wantedArgumentPositionIsValidForInvocation(
private boolean wantedArgumentPositionIsValidForTheoreticalInvocation(
InvocationOnMock invocation, int argumentPosition) {
if (argumentPosition < 0) {
return false;
Expand Down Expand Up @@ -145,7 +165,7 @@ private Class<?> inferArgumentType(Invocation invocation, int argumentIndex) {
return parameterTypes[argumentIndex];
}
// if wanted argument is vararg
if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation.getMethod(), argumentIndex)) {
if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation)) {
// return the vararg array if return type is compatible
// because the user probably want to return the array itself if the return type is
// compatible
Expand Down
20 changes: 20 additions & 0 deletions src/test/java/org/mockitousage/IMethods.java
Expand Up @@ -133,8 +133,26 @@ String simpleMethod(

String threeArgumentMethodWithStrings(int valueOne, String valueTwo, String valueThree);

String threeArgumentVarArgsMethod(int valueOne, String valueTwo, String... valueThree);

String fourArgumentMethod(int valueOne, String valueTwo, String valueThree, boolean[] array);

String fourArgumentVarArgsMethod(
int valueOne, String valueTwo, int valueThree, String... valueFour);

String fiveArgumentVarArgsMethod(
int valueOne, String valueTwo, int valueThree, String valueFour, String... valueFive);

String sixArgumentVarArgsMethod(
int valueOne,
String valueTwo,
int valueThree,
String valueFour,
String valueFive,
String... valueSix);

int arrayVarargsMethod(String[]... arrayVarArgs);

void twoArgumentMethod(int one, int two);

void arrayMethod(String[] strings);
Expand Down Expand Up @@ -235,6 +253,8 @@ String simpleMethod(

Integer toIntWrapper(int i);

Integer toIntWrapperVarArgs(int i, Object... varargs);

String forObject(Object object);

<T> String genericToString(T arg);
Expand Down

0 comments on commit 459efad

Please sign in to comment.