Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Varargs handling in AdditionalAnswers #2664

Merged
merged 3 commits into from Jun 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -19,6 +19,8 @@
import org.mockito.stubbing.VoidAnswer5;
import org.mockito.stubbing.VoidAnswer6;

import java.lang.reflect.Method;

/**
* Functional interfaces to make it easy to implement answers in Java 8
*
Expand All @@ -38,11 +40,11 @@ private AnswerFunctionalInterfaces() {}
* @return a new answer object
*/
public static <T, A> Answer<T> toAnswer(final Answer1<T, A> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 1);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
public T answer(InvocationOnMock invocation) throws Throwable {
return answer.answer((A) invocation.getArgument(0));
return answer.answer(lastParameter(invocation, answerMethod, 0));
}
};
}
Expand All @@ -54,11 +56,11 @@ public T answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <A> Answer<Void> toAnswer(final VoidAnswer1<A> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 1);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
public Void answer(InvocationOnMock invocation) throws Throwable {
answer.answer((A) invocation.getArgument(0));
answer.answer(lastParameter(invocation, answerMethod, 0));
return null;
}
};
Expand All @@ -73,11 +75,13 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <T, A, B> Answer<T> toAnswer(final Answer2<T, A, B> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 2);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
public T answer(InvocationOnMock invocation) throws Throwable {
return answer.answer((A) invocation.getArgument(0), (B) invocation.getArgument(1));
return answer.answer(
(A) invocation.getArgument(0), lastParameter(invocation, answerMethod, 1));
}
};
}
Expand All @@ -90,11 +94,13 @@ public T answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <A, B> Answer<Void> toAnswer(final VoidAnswer2<A, B> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 2);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
public Void answer(InvocationOnMock invocation) throws Throwable {
answer.answer((A) invocation.getArgument(0), (B) invocation.getArgument(1));
answer.answer(
(A) invocation.getArgument(0), lastParameter(invocation, answerMethod, 1));
return null;
}
};
Expand All @@ -110,14 +116,15 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <T, A, B, C> Answer<T> toAnswer(final Answer3<T, A, B, C> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 3);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
public T answer(InvocationOnMock invocation) throws Throwable {
return answer.answer(
(A) invocation.getArgument(0),
(B) invocation.getArgument(1),
(C) invocation.getArgument(2));
lastParameter(invocation, answerMethod, 2));
}
};
}
Expand All @@ -131,14 +138,15 @@ public T answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <A, B, C> Answer<Void> toAnswer(final VoidAnswer3<A, B, C> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 3);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
public Void answer(InvocationOnMock invocation) throws Throwable {
answer.answer(
(A) invocation.getArgument(0),
(B) invocation.getArgument(1),
(C) invocation.getArgument(2));
lastParameter(invocation, answerMethod, 2));
return null;
}
};
Expand All @@ -155,6 +163,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <T, A, B, C, D> Answer<T> toAnswer(final Answer4<T, A, B, C, D> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 4);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
Expand All @@ -163,7 +172,7 @@ public T answer(InvocationOnMock invocation) throws Throwable {
(A) invocation.getArgument(0),
(B) invocation.getArgument(1),
(C) invocation.getArgument(2),
(D) invocation.getArgument(3));
lastParameter(invocation, answerMethod, 3));
}
};
}
Expand All @@ -178,6 +187,7 @@ public T answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <A, B, C, D> Answer<Void> toAnswer(final VoidAnswer4<A, B, C, D> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 4);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
Expand All @@ -186,7 +196,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
(A) invocation.getArgument(0),
(B) invocation.getArgument(1),
(C) invocation.getArgument(2),
(D) invocation.getArgument(3));
lastParameter(invocation, answerMethod, 3));
return null;
}
};
Expand All @@ -204,6 +214,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <T, A, B, C, D, E> Answer<T> toAnswer(final Answer5<T, A, B, C, D, E> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 5);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
Expand All @@ -213,7 +224,7 @@ public T answer(InvocationOnMock invocation) throws Throwable {
(B) invocation.getArgument(1),
(C) invocation.getArgument(2),
(D) invocation.getArgument(3),
(E) invocation.getArgument(4));
lastParameter(invocation, answerMethod, 4));
}
};
}
Expand All @@ -229,6 +240,7 @@ public T answer(InvocationOnMock invocation) throws Throwable {
* @return a new answer object
*/
public static <A, B, C, D, E> Answer<Void> toAnswer(final VoidAnswer5<A, B, C, D, E> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 5);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
Expand All @@ -238,7 +250,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
(B) invocation.getArgument(1),
(C) invocation.getArgument(2),
(D) invocation.getArgument(3),
(E) invocation.getArgument(4));
lastParameter(invocation, answerMethod, 4));
return null;
}
};
Expand All @@ -259,6 +271,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
*/
public static <T, A, B, C, D, E, F> Answer<T> toAnswer(
final Answer6<T, A, B, C, D, E, F> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 6);
return new Answer<T>() {
@Override
@SuppressWarnings("unchecked")
Expand All @@ -269,7 +282,7 @@ public T answer(InvocationOnMock invocation) throws Throwable {
(C) invocation.getArgument(2),
(D) invocation.getArgument(3),
(E) invocation.getArgument(4),
(F) invocation.getArgument(5));
lastParameter(invocation, answerMethod, 5));
}
};
}
Expand All @@ -288,6 +301,7 @@ public T answer(InvocationOnMock invocation) throws Throwable {
*/
public static <A, B, C, D, E, F> Answer<Void> toAnswer(
final VoidAnswer6<A, B, C, D, E, F> answer) {
final Method answerMethod = findAnswerMethod(answer.getClass(), 6);
return new Answer<Void>() {
@Override
@SuppressWarnings("unchecked")
Expand All @@ -298,9 +312,42 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
(C) invocation.getArgument(2),
(D) invocation.getArgument(3),
(E) invocation.getArgument(4),
(F) invocation.getArgument(5));
lastParameter(invocation, answerMethod, 5));
return null;
}
};
}

private static Method findAnswerMethod(final Class<?> type, final int numberOfParameters) {
for (final Method m : type.getDeclaredMethods()) {
if (!m.isBridge()
&& m.getName().equals("answer")
&& m.getParameterTypes().length == numberOfParameters) {
return m;
}
}
throw new IllegalStateException(
"Failed to find answer() method on the supplied class: "
+ type.getName()
+ ", with the supplied number of parameters: "
+ numberOfParameters);
}

@SuppressWarnings("unchecked")
private static <A> A lastParameter(
InvocationOnMock invocation, Method answerMethod, int argumentIndex) {
final Method invocationMethod = invocation.getMethod();

if (invocationMethod.isVarArgs()
&& invocationMethod.getParameterTypes().length == (argumentIndex + 1)) {
final Class<?> invocationRawArgType =
invocationMethod.getParameterTypes()[argumentIndex];
final Class<?> answerRawArgType = answerMethod.getParameterTypes()[argumentIndex];
if (answerRawArgType.isAssignableFrom(invocationRawArgType)) {
return (A) invocation.getRawArguments()[argumentIndex];
}
}

return invocation.getArgument(argumentIndex);
}
}
Expand Up @@ -52,24 +52,24 @@ public ReturnsArgumentAt(int wantedArgumentPosition) {

@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
int argumentPosition = inferWantedArgumentPosition(invocation);
validateIndexWithinInvocationRange(invocation, argumentPosition);

if (wantedArgIndexIsVarargAndSameTypeAsReturnType(
invocation.getMethod(), argumentPosition)) {
if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation)) {
// answer raw vararg array argument
return ((Invocation) invocation).getRawArguments()[argumentPosition];
TimvdLippe marked this conversation as resolved.
Show resolved Hide resolved
return invocation.getRawArguments()[invocation.getRawArguments().length - 1];
}

int argumentPosition = inferWantedArgumentPosition(invocation);
validateIndexWithinInvocationRange(invocation, argumentPosition);

// answer expanded argument at wanted position
return invocation.getArgument(argumentPosition);
}

@Override
public void validateFor(InvocationOnMock invocation) {
public void validateFor(InvocationOnMock invocationOnMock) {
Invocation invocation = (Invocation) invocationOnMock;
int argumentPosition = inferWantedArgumentPosition(invocation);
validateIndexWithinInvocationRange(invocation, argumentPosition);
validateArgumentTypeCompatibility((Invocation) invocation, argumentPosition);
validateIndexWithinTheoreticalInvocationRange(invocation, argumentPosition);
big-andy-coates marked this conversation as resolved.
Show resolved Hide resolved
validateArgumentTypeCompatibility(invocation, argumentPosition);
}

private int inferWantedArgumentPosition(InvocationOnMock invocation) {
Expand All @@ -80,9 +80,26 @@ private int inferWantedArgumentPosition(InvocationOnMock invocation) {
return wantedArgumentPosition;
}

private int inferWantedRawArgumentPosition(InvocationOnMock invocation) {
if (wantedArgumentPosition == LAST_ARGUMENT) {
return invocation.getRawArguments().length - 1;
}

return wantedArgumentPosition;
}

private void validateIndexWithinInvocationRange(
InvocationOnMock invocation, int argumentPosition) {
if (!wantedArgumentPositionIsValidForInvocation(invocation, argumentPosition)) {

if (invocation.getArguments().length <= argumentPosition) {
throw invalidArgumentPositionRangeAtInvocationTime(
invocation, wantedArgumentPosition == LAST_ARGUMENT, wantedArgumentPosition);
}
}

private void validateIndexWithinTheoreticalInvocationRange(
InvocationOnMock invocation, int argumentPosition) {
if (!wantedArgumentPositionIsValidForTheoreticalInvocation(invocation, argumentPosition)) {
throw invalidArgumentPositionRangeAtInvocationTime(
invocation, wantedArgumentPosition == LAST_ARGUMENT, wantedArgumentPosition);
}
Expand All @@ -102,15 +119,16 @@ private void validateArgumentTypeCompatibility(Invocation invocation, int argume
}
}

private boolean wantedArgIndexIsVarargAndSameTypeAsReturnType(
Method method, int argumentPosition) {
private boolean wantedArgIndexIsVarargAndSameTypeAsReturnType(InvocationOnMock invocation) {
int rawArgumentPosition = inferWantedRawArgumentPosition(invocation);
Method method = invocation.getMethod();
Class<?>[] parameterTypes = method.getParameterTypes();
return method.isVarArgs()
&& argumentPosition == /* vararg index */ parameterTypes.length - 1
&& method.getReturnType().isAssignableFrom(parameterTypes[argumentPosition]);
&& rawArgumentPosition == /* vararg index */ parameterTypes.length - 1
&& method.getReturnType().isAssignableFrom(parameterTypes[rawArgumentPosition]);
}

private boolean wantedArgumentPositionIsValidForInvocation(
private boolean wantedArgumentPositionIsValidForTheoreticalInvocation(
InvocationOnMock invocation, int argumentPosition) {
if (argumentPosition < 0) {
return false;
Expand Down Expand Up @@ -145,7 +163,7 @@ private Class<?> inferArgumentType(Invocation invocation, int argumentIndex) {
return parameterTypes[argumentIndex];
}
// if wanted argument is vararg
if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation.getMethod(), argumentIndex)) {
if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation)) {
// return the vararg array if return type is compatible
// because the user probably want to return the array itself if the return type is
// compatible
Expand Down
Expand Up @@ -12,7 +12,6 @@
import java.lang.reflect.Method;

import org.mockito.internal.configuration.plugins.Plugins;
import org.mockito.invocation.Invocation;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.plugins.MemberAccessor;
import org.mockito.stubbing.Answer;
Expand Down Expand Up @@ -45,7 +44,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable {
}

MemberAccessor accessor = Plugins.getMemberAccessor();
Object[] rawArguments = ((Invocation) invocation).getRawArguments();
Object[] rawArguments = invocation.getRawArguments();
return accessor.invoke(delegateMethod, delegatedObject, rawArguments);
} catch (NoSuchMethodException e) {
throw delegatedMethodDoesNotExistOnDelegate(
Expand Down
9 changes: 1 addition & 8 deletions src/main/java/org/mockito/invocation/Invocation.java
Expand Up @@ -41,14 +41,6 @@ public interface Invocation extends InvocationOnMock, DescribedInvocation {
@Override
Location getLocation();

/**
* Returns unprocessed arguments whereas {@link #getArguments()} returns
* arguments already processed (e.g. varargs expended, etc.).
*
* @return unprocessed arguments, exactly as provided to this invocation.
*/
Object[] getRawArguments();

/**
* Wraps each argument using {@link org.mockito.ArgumentMatchers#eq(Object)} or
* {@link org.mockito.AdditionalMatchers#aryEq(Object[])}
Expand All @@ -64,6 +56,7 @@ public interface Invocation extends InvocationOnMock, DescribedInvocation {
* arguments already processed (e.g. varargs expended, etc.).
*
* @return unprocessed arguments, exactly as provided to this invocation.
* @since 4.7.0
*/
Class<?> getRawReturnType();

Expand Down
9 changes: 9 additions & 0 deletions src/main/java/org/mockito/invocation/InvocationOnMock.java
Expand Up @@ -32,6 +32,15 @@ public interface InvocationOnMock extends Serializable {
*/
Method getMethod();

/**
* Returns unprocessed arguments whereas {@link #getArguments()} returns
* arguments already processed (e.g. varargs expended, etc.).
*
* @return unprocessed arguments, exactly as provided to this invocation.
big-andy-coates marked this conversation as resolved.
Show resolved Hide resolved
* @since 4.7.0
*/
Object[] getRawArguments();

/**
* Returns arguments passed to the method.
*
Expand Down