Skip to content

Commit

Permalink
Improve Varargs handling in AdditionalAnswers
Browse files Browse the repository at this point in the history
Fixes: mockito#2644

Fixes issues around vararg handling for the following methods in `AdditionalAnswers`:
 * `returnsFirstArg`
 * `returnsSecondArg`
 * `returnsLastArg`
 * `returnsArgAt`
 * `answer`
 * `answerVoid`

These methods were not correctly handling varargs. For example,

```java
doAnswer(answerVoid(
      (VoidAnswer2<String, Object[]>) logger::info
   )).when(mock)
      .info(any(), (Object[]) any());

mock.info("Some message with {} {} {}", "three", "parameters", "");
```

Would previously have resulted in a `ClassCastException` being thrown from the `mock.info` call.  This was because the `answerVoid` method was not taking into account that the second parameter was a varargs parameter and was attempting to pass the second actual argument `"three"`, rather than the second _raw_ argument `["three", "parameters", ""]`.
  • Loading branch information
big-andy-coates committed Jun 2, 2022
1 parent 1fbef57 commit 057c795
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 20 deletions.
Expand Up @@ -4,6 +4,7 @@
*/
package org.mockito.internal.stubbing.answers;

import org.mockito.invocation.Invocation;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.mockito.stubbing.Answer1;
Expand All @@ -19,6 +20,8 @@
import org.mockito.stubbing.VoidAnswer5;
import org.mockito.stubbing.VoidAnswer6;

import java.lang.reflect.Method;

/**
* Functional interfaces to make it easy to implement answers in Java 8
*
Expand All @@ -38,11 +41,11 @@ private AnswerFunctionalInterfaces() {}
* @return a new answer object
*/
public static <T, A> Answer<T> toAnswer(final Answer1<T, A> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 1);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
public T answer(InvocationOnMock invocation) throws Throwable {
return answer.answer((A) invocation.getArgument(0));
return answer.answer(lastParameter(invocation, answerMethod, 0));
}
};
}
Expand All @@ -54,11 +57,11 @@ public T answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <A> Answer<Void> toAnswer(final VoidAnswer1<A> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 1);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
public Void answer(InvocationOnMock invocation) throws Throwable {
answer.answer((A) invocation.getArgument(0));
answer.answer(lastParameter(invocation, answerMethod, 0));
return null;
}
};
Expand All @@ -73,11 +76,13 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <T, A, B> Answer<T> toAnswer(final Answer2<T, A, B> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 2);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
public T answer(InvocationOnMock invocation) throws Throwable {
return answer.answer((A) invocation.getArgument(0), (B) invocation.getArgument(1));
return answer.answer(
(A) invocation.getArgument(0), lastParameter(invocation, answerMethod, 1));
}
};
}
Expand All @@ -90,11 +95,13 @@ public T answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <A, B> Answer<Void> toAnswer(final VoidAnswer2<A, B> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 2);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
public Void answer(InvocationOnMock invocation) throws Throwable {
answer.answer((A) invocation.getArgument(0), (B) invocation.getArgument(1));
answer.answer(
(A) invocation.getArgument(0), lastParameter(invocation, answerMethod, 1));
return null;
}
};
Expand Down Expand Up @@ -303,4 +310,35 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
}
};
}

private static Method findAnswerMethod(final Class<?> type, final int numberOfParameters) {
for (final Method m : type.getDeclaredMethods()) {
if (!m.isBridge()
&& m.getName().equals("answer")
&& m.getParameterTypes().length == numberOfParameters) {
return m;
}
}
// Todo: throw a descent error message
throw new RuntimeException("todo");
}

@SuppressWarnings("unchecked")
private static <A> A lastParameter(
InvocationOnMock invocationOnMock, Method answerMethod, int argumentIndex) {
final Method invocationMethod = invocationOnMock.getMethod();
if (invocationMethod.isVarArgs()
&& invocationMethod.getParameterTypes().length == (argumentIndex + 1)) {
final Class<?> invocationRawArgType =
invocationMethod.getParameterTypes()[argumentIndex];
final Class<?> answerRawArgType = answerMethod.getParameterTypes()[argumentIndex];
if (answerRawArgType.isAssignableFrom(
invocationRawArgType)) { // Todo: test super types.
Invocation invocation = (Invocation) invocationOnMock;
return (A) invocation.getRawArguments()[argumentIndex];
}
}

return (A) invocationOnMock.getArgument(argumentIndex);
}
}
Expand Up @@ -51,24 +51,26 @@ public ReturnsArgumentAt(int wantedArgumentPosition) {
}

@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
int argumentPosition = inferWantedArgumentPosition(invocation);
validateIndexWithinInvocationRange(invocation, argumentPosition);
public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
Invocation invocation = (Invocation) invocationOnMock;

if (wantedArgIndexIsVarargAndSameTypeAsReturnType(
invocation.getMethod(), argumentPosition)) {
if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation)) {
// answer raw vararg array argument
return ((Invocation) invocation).getRawArguments()[argumentPosition];
int rawArgumentPosition = inferWantedRawArgumentPosition(invocation);
return invocation.getRawArguments()[rawArgumentPosition];
}

int argumentPosition = inferWantedArgumentPosition(invocation);
validateIndexWithinInvocationRange(invocation, argumentPosition);

// answer expanded argument at wanted position
return invocation.getArgument(argumentPosition);
}

@Override
public void validateFor(InvocationOnMock invocation) {
int argumentPosition = inferWantedArgumentPosition(invocation);
validateIndexWithinInvocationRange(invocation, argumentPosition);
validateIndexWithinTheoreticalInvocationRange(invocation, argumentPosition);
validateArgumentTypeCompatibility((Invocation) invocation, argumentPosition);
}

Expand All @@ -80,9 +82,26 @@ private int inferWantedArgumentPosition(InvocationOnMock invocation) {
return wantedArgumentPosition;
}

private int inferWantedRawArgumentPosition(Invocation invocation) {
if (wantedArgumentPosition == LAST_ARGUMENT) {
return invocation.getRawArguments().length - 1;
}

return wantedArgumentPosition;
}

private void validateIndexWithinInvocationRange(
InvocationOnMock invocation, int argumentPosition) {
if (!wantedArgumentPositionIsValidForInvocation(invocation, argumentPosition)) {

if (argumentPosition < 0 || invocation.getArguments().length <= argumentPosition) {
throw invalidArgumentPositionRangeAtInvocationTime(
invocation, wantedArgumentPosition == LAST_ARGUMENT, wantedArgumentPosition);
}
}

private void validateIndexWithinTheoreticalInvocationRange(
InvocationOnMock invocation, int argumentPosition) {
if (!wantedArgumentPositionIsValidForTheoreticalInvocation(invocation, argumentPosition)) {
throw invalidArgumentPositionRangeAtInvocationTime(
invocation, wantedArgumentPosition == LAST_ARGUMENT, wantedArgumentPosition);
}
Expand All @@ -102,15 +121,16 @@ private void validateArgumentTypeCompatibility(Invocation invocation, int argume
}
}

private boolean wantedArgIndexIsVarargAndSameTypeAsReturnType(
Method method, int argumentPosition) {
private boolean wantedArgIndexIsVarargAndSameTypeAsReturnType(Invocation invocation) {
int rawArgumentPosition = inferWantedRawArgumentPosition(invocation);
Method method = invocation.getMethod();
Class<?>[] parameterTypes = method.getParameterTypes();
return method.isVarArgs()
&& argumentPosition == /* vararg index */ parameterTypes.length - 1
&& method.getReturnType().isAssignableFrom(parameterTypes[argumentPosition]);
&& rawArgumentPosition == /* vararg index */ parameterTypes.length - 1
&& method.getReturnType().isAssignableFrom(parameterTypes[rawArgumentPosition]);
}

private boolean wantedArgumentPositionIsValidForInvocation(
private boolean wantedArgumentPositionIsValidForTheoreticalInvocation(
InvocationOnMock invocation, int argumentPosition) {
if (argumentPosition < 0) {
return false;
Expand Down Expand Up @@ -145,7 +165,7 @@ private Class<?> inferArgumentType(Invocation invocation, int argumentIndex) {
return parameterTypes[argumentIndex];
}
// if wanted argument is vararg
if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation.getMethod(), argumentIndex)) {
if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation)) {
// return the vararg array if return type is compatible
// because the user probably want to return the array itself if the return type is
// compatible
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/mockito/invocation/InvocationOnMock.java
Expand Up @@ -56,6 +56,8 @@ public interface InvocationOnMock extends Serializable {
*/
<T> T getArgument(int index);

// Todo: add getRawArguments to this interface?

/**
* Returns casted argument at the given index. This method is analogous to
* {@link #getArgument(int)}, but is necessary to circumvent issues when dealing with generics.
Expand Down
Expand Up @@ -5,6 +5,7 @@
package org.mockitousage.stubbing;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.within;
import static org.mockito.AdditionalAnswers.answer;
import static org.mockito.AdditionalAnswers.answerVoid;
Expand All @@ -27,6 +28,7 @@
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.exceptions.base.MockitoException;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer1;
import org.mockito.stubbing.Answer2;
Expand All @@ -52,10 +54,39 @@ public void can_return_arguments_of_invocation() throws Exception {
given(iMethods.objectArgMethod(any())).will(returnsFirstArg());
given(iMethods.threeArgumentMethod(eq(0), any(), anyString())).will(returnsSecondArg());
given(iMethods.threeArgumentMethod(eq(1), any(), anyString())).will(returnsLastArg());
given(iMethods.mixedVarargsReturningString(eq(1), any())).will(returnsArgAt(2));

assertThat(iMethods.objectArgMethod("first")).isEqualTo("first");
assertThat(iMethods.threeArgumentMethod(0, "second", "whatever")).isEqualTo("second");
assertThat(iMethods.threeArgumentMethod(1, "whatever", "last")).isEqualTo("last");
assertThat(iMethods.mixedVarargsReturningString(1, "a", "b")).isEqualTo("b");
}

@Test
public void can_return_var_arguments_of_invocation() throws Exception {
given(iMethods.mixedVarargsReturningStringArray(eq(1), any())).will(returnsLastArg());
given(iMethods.mixedVarargsReturningObjectArray(eq(1), any())).will(returnsArgAt(1));

assertThat(iMethods.mixedVarargsReturningStringArray(1, "the", "var", "args"))
.containsExactlyInAnyOrder("the", "var", "args");
assertThat(iMethods.mixedVarargsReturningObjectArray(1, "the", "var", "args"))
.containsExactlyInAnyOrder("the", "var", "args");
}

@Test
public void returns_arg_at_throws_on_out_of_range_var_args() throws Exception {
given(iMethods.mixedVarargsReturningString(eq(1), any())).will(returnsArgAt(3));

assertThatThrownBy(() -> iMethods.mixedVarargsReturningString(1, "a", "b"))
.isInstanceOf(MockitoException.class)
.hasMessageContaining("Invalid argument index");

assertThatThrownBy(
() ->
given(iMethods.mixedVarargsReturningStringArray(eq(1), any()))
.will(returnsArgAt(3)))
.isInstanceOf(MockitoException.class)
.hasMessageContaining("The argument of type 'String' cannot be returned");
}

@Test
Expand Down Expand Up @@ -347,4 +378,78 @@ public void answer(
// expect the answer to write correctly to "target"
verify(target, times(1)).simpleMethod("hello", 1, 2, 3, 4, 5);
}

@Test
public void can_return_based_on_strongly_types_one_parameter_var_args_function()
throws Exception {
given(iMethods.varargs(any()))
.will(
answer(
new Answer1<Integer, String[]>() {
public Integer answer(String[] strings) {
return strings.length;
}
}));

assertThat(iMethods.varargs("some", "args")).isEqualTo(2);
}

@Test
public void will_execute_a_void_based_on_strongly_typed_one_parameter_var_args_function()
throws Exception {
final IMethods target = mock(IMethods.class);

given(iMethods.varargs(any()))
.will(
answerVoid(
new VoidAnswer1<String[]>() {
public void answer(String[] s) {
target.varargs(s);
}
}));

// invoke on iMethods
iMethods.varargs("some", "args");

// expect the answer to write correctly to "target"
verify(target, times(1)).varargs("some", "args");
}

@Test
public void can_return_based_on_strongly_typed_two_parameter_var_args_function()
throws Exception {
given(iMethods.mixedVarargsReturningString(any(), any()))
.will(
answer(
new Answer2<String, Object, String[]>() {
public String answer(Object o, String[] s) {
return String.join("-", s);
}
}));

assertThat(iMethods.mixedVarargsReturningString(1, "var", "args")).isEqualTo("var-args");
}

@Test
public void will_execute_a_void_based_on_strongly_typed_two_parameter_var_args_function()
throws Exception {
final IMethods target = mock(IMethods.class);

given(iMethods.mixedVarargsReturningString(any(), any()))
.will(
answerVoid(
new VoidAnswer2<Object, String[]>() {
public void answer(Object o, String[] s) {
target.mixedVarargsReturningString(o, s);
}
}));

// invoke on iMethods
iMethods.mixedVarargsReturningString(1, "var", "args");

// expect the answer to write correctly to "target"
verify(target, times(1)).mixedVarargsReturningString(1, "var", "args");
}

// todo: tests for other versions of answer and voidAnswer.
}

0 comments on commit 057c795

Please sign in to comment.