diff --git a/mockito-kotlin/src/main/kotlin/org/mockito/kotlin/Matchers.kt b/mockito-kotlin/src/main/kotlin/org/mockito/kotlin/Matchers.kt index 631c5516..19f91a78 100644 --- a/mockito-kotlin/src/main/kotlin/org/mockito/kotlin/Matchers.kt +++ b/mockito-kotlin/src/main/kotlin/org/mockito/kotlin/Matchers.kt @@ -25,9 +25,10 @@ package org.mockito.kotlin -import org.mockito.kotlin.internal.createInstance import org.mockito.ArgumentMatcher import org.mockito.ArgumentMatchers +import org.mockito.kotlin.internal.createInstance +import kotlin.reflect.KClass /** Object argument that is equal to the given value. */ fun eq(value: T): T { @@ -51,7 +52,18 @@ inline fun anyOrNull(): T { /** Matches any vararg object, including nulls. */ inline fun anyVararg(): T { - return ArgumentMatchers.any() ?: createInstance() + return anyVararg(T::class) +} + +fun anyVararg(clazz: KClass): T { + return ArgumentMatchers.argThat(VarargMatcher(clazz.java))?: createInstance(clazz) +} + +private class VarargMatcher(private val clazz: Class) : ArgumentMatcher{ + override fun matches(t: T): Boolean = true + + // In Java >= 12 you can do clazz.arrayClass() + override fun type(): Class<*> = java.lang.reflect.Array.newInstance(clazz, 0).javaClass } /** Matches any array of type T. */ diff --git a/tests/src/test/kotlin/test/MatchersTest.kt b/tests/src/test/kotlin/test/MatchersTest.kt index 4064ec16..ac3021a5 100644 --- a/tests/src/test/kotlin/test/MatchersTest.kt +++ b/tests/src/test/kotlin/test/MatchersTest.kt @@ -8,6 +8,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.kotlin.* import org.mockito.stubbing.Answer import java.io.IOException +import kotlin.reflect.KClass class MatchersTest : TestBase() { @@ -67,6 +68,14 @@ class MatchersTest : TestBase() { } } + @Test + fun anyVarargMatching() { + mock().apply { + whenever(varargBooleanResult(anyVararg())).thenReturn(true) + expect(varargBooleanResult()).toBe(true) + } + } + @Test fun anyNull_neverVerifiesAny() { mock().apply { @@ -275,7 +284,7 @@ class MatchersTest : TestBase() { /* Given */ val t = mock() // a matcher to check if any of the varargs was equals to "b" - val matcher = VarargAnyMatcher({ "b" == it }, true, false) + val matcher = VarargAnyMatcher({ "b" == it }, String::class.java, true, false) /* When */ whenever(t.varargBooleanResult(argThat(matcher))).thenAnswer(matcher) @@ -289,7 +298,7 @@ class MatchersTest : TestBase() { /* Given */ val t = mock() // a matcher to check if any of the varargs was equals to "d" - val matcher = VarargAnyMatcher({ "d" == it }, true, false) + val matcher = VarargAnyMatcher({ "d" == it }, String::class.java, true, false) /* When */ whenever(t.varargBooleanResult(argThat(matcher))).thenAnswer(matcher) @@ -317,18 +326,20 @@ class MatchersTest : TestBase() { */ private class VarargAnyMatcher( private val match: ((T) -> Boolean), + private val clazz: Class, private val success: R, private val failure: R ) : ArgumentMatcher, Answer { private var anyMatched = false override fun matches(t: T): Boolean { - anyMatched = anyMatched or match(t) - return true + @Suppress("UNCHECKED_CAST") // No idea how to solve this better + anyMatched = (t as Array).any(match) + return anyMatched } override fun answer(i: InvocationOnMock) = if (anyMatched) success else failure - override fun type(): Class<*> = Any::class.java + override fun type(): Class<*> = java.lang.reflect.Array.newInstance(clazz, 0).javaClass } }