Skip to content

Commit

Permalink
Match extension functions (#4459)
Browse files Browse the repository at this point in the history
* Add missing tests

* Match extension functions
  • Loading branch information
BraisGabin committed Mar 24, 2022
1 parent 93d1f43 commit d7ac501
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 5 deletions.
Expand Up @@ -48,7 +48,9 @@ class ForbiddenMethodCall(config: Config = Config.empty) : Rule(config) {
"Methods can be defined without full signature (i.e. `java.time.LocalDate.now`) which will report " +
"calls of all methods with this name or with full signature " +
"(i.e. `java.time.LocalDate(java.time.Clock)`) which would report only call " +
"with this concrete signature."
"with this concrete signature. If you want to forbid an extension function like" +
"`fun String.hello(a: Int)` you should add the receiver parameter as the first parameter like this: " +
"`hello(kotlin.String, kotlin.Int)`"
)
private val methods: List<FunctionMatcher> by config(
listOf(
Expand Down
Expand Up @@ -318,5 +318,21 @@ class ForbiddenMethodCallSpec : Spek({
).compileAndLintWithContext(env, code)
assertThat(findings).hasSize(1)
}

it("should report extension functions") {
val code = """
package org.example
fun String.bar() = Unit
fun foo() {
"".bar()
}
"""
val findings = ForbiddenMethodCall(
TestConfig(mapOf(METHODS to listOf("org.example.bar(kotlin.String)")))
).compileAndLintWithContext(env, code)
assertThat(findings).hasSize(1)
}
}
})
Expand Up @@ -35,8 +35,9 @@ sealed class FunctionMatcher {
override fun match(callableDescriptor: CallableDescriptor): Boolean {
if (callableDescriptor.fqNameOrNull()?.asString() != fullyQualifiedName) return false

val encounteredParamTypes = callableDescriptor.valueParameters
.map { it.type.fqNameOrNull()?.asString() }
val encounteredParamTypes =
(listOfNotNull(callableDescriptor.extensionReceiverParameter) + callableDescriptor.valueParameters)
.map { it.type.fqNameOrNull()?.asString() }

return encounteredParamTypes == parameters
}
Expand All @@ -45,8 +46,9 @@ sealed class FunctionMatcher {
if (bindingContext == BindingContext.EMPTY) return false
if (function.name != fullyQualifiedName) return false

val encounteredParameters = function.valueParameters
.map { bindingContext[BindingContext.TYPE, it.typeReference]?.fqNameOrNull()?.toString() }
val encounteredParameters =
(listOfNotNull(function.receiverTypeReference) + function.valueParameters.map { it.typeReference })
.map { bindingContext[BindingContext.TYPE, it]?.fqNameOrNull()?.toString() }

return encounteredParameters == parameters
}
Expand Down
Expand Up @@ -186,6 +186,66 @@ class FunctionMatcherSpec(private val env: KotlinCoreEnvironment) {
val methodSignature = FunctionMatcher.fromFunctionSignature("toString(String)")
assertThat(methodSignature.match(function, bindingContext)).isEqualTo(result)
}

@DisplayName("When lambdas foo(() -> kotlin.String)")
@ParameterizedTest(name = "in case {0} it return {1}")
@CsvSource(
"fun foo(a: () -> String), true",
"fun foo(a: () -> Unit), true",
"fun foo(a: (String) -> String), false",
"fun foo(a: (String) -> Unit), false",
"fun foo(a: (Int) -> Unit), false",
)
fun `When foo(() - kotlin#String)`(code: String, result: Boolean) {
val (function, bindingContext) = buildKtFunction(env, code)
val methodSignature = FunctionMatcher.fromFunctionSignature("foo(() -> kotlin.String)")
assertThat(methodSignature.match(function, bindingContext)).isEqualTo(result)
}

@DisplayName("When lambdas foo((kotlin.String) -> Unit)")
@ParameterizedTest(name = "in case {0} it return {1}")
@CsvSource(
"fun foo(a: () -> String), false",
"fun foo(a: () -> Unit), false",
"fun foo(a: (String) -> String), true",
"fun foo(a: (String) -> Unit), true",
"fun foo(a: (Int) -> Unit), true",
)
fun `When foo((kotlin#String) - Unit)`(code: String, result: Boolean) {
val (function, bindingContext) = buildKtFunction(env, code)
val methodSignature = FunctionMatcher.fromFunctionSignature("foo((kotlin.String) -> Unit)")
assertThat(methodSignature.match(function, bindingContext)).isEqualTo(result)
}

@DisplayName("When extension functions foo(kotlin.String)")
@ParameterizedTest(name = "in case {0} it return {1}")
@CsvSource(
"fun String.foo(), true",
"fun foo(a: String), true",
"fun Int.foo(), false",
"fun String.foo(a: Int), false",
"'fun foo(a: String, ba: Int)', false",
)
fun `When foo(kotlin#String)`(code: String, result: Boolean) {
val (function, bindingContext) = buildKtFunction(env, code)
val methodSignature = FunctionMatcher.fromFunctionSignature("foo(kotlin.String)")
assertThat(methodSignature.match(function, bindingContext)).isEqualTo(result)
}

@DisplayName("When extension functions foo(kotlin.String, kotlin.Int)")
@ParameterizedTest(name = "in case {0} it return {1}")
@CsvSource(
"fun String.foo(), false",
"fun foo(a: String), false",
"fun Int.foo(), false",
"fun String.foo(a: Int), true",
"'fun foo(a: String, ba: Int)', true",
)
fun `When foo(kotlin#String, kotlin#Int)`(code: String, result: Boolean) {
val (function, bindingContext) = buildKtFunction(env, code)
val methodSignature = FunctionMatcher.fromFunctionSignature("foo(kotlin.String, kotlin.Int)")
assertThat(methodSignature.match(function, bindingContext)).isEqualTo(result)
}
}
}

Expand Down

0 comments on commit d7ac501

Please sign in to comment.