diff --git a/detekt-core/src/main/resources/default-detekt-config.yml b/detekt-core/src/main/resources/default-detekt-config.yml index 4874ec93ed69..eb4d620d923c 100644 --- a/detekt-core/src/main/resources/default-detekt-config.yml +++ b/detekt-core/src/main/resources/default-detekt-config.yml @@ -190,6 +190,8 @@ coroutines: active: true SleepInsteadOfDelay: active: true + SuspendFunInsideRunCatching: + active: false SuspendFunWithCoroutineScopeReceiver: active: false SuspendFunWithFlowReturnType: diff --git a/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/CoroutinesProvider.kt b/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/CoroutinesProvider.kt index fe9ade01a858..79b477b77f86 100644 --- a/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/CoroutinesProvider.kt +++ b/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/CoroutinesProvider.kt @@ -21,7 +21,8 @@ class CoroutinesProvider : DefaultRuleSetProvider { RedundantSuspendModifier(config), SleepInsteadOfDelay(config), SuspendFunWithFlowReturnType(config), - SuspendFunWithCoroutineScopeReceiver(config) + SuspendFunWithCoroutineScopeReceiver(config), + SuspendFunInsideRunCatching(config), ) ) } diff --git a/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunInsideRunCatching.kt b/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunInsideRunCatching.kt new file mode 100644 index 000000000000..8c1dd75c6c3d --- /dev/null +++ b/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunInsideRunCatching.kt @@ -0,0 +1,140 @@ +package io.gitlab.arturbosch.detekt.rules.coroutines + +import io.gitlab.arturbosch.detekt.api.CodeSmell +import io.gitlab.arturbosch.detekt.api.Config +import io.gitlab.arturbosch.detekt.api.Debt +import io.gitlab.arturbosch.detekt.api.Entity +import io.gitlab.arturbosch.detekt.api.Issue +import io.gitlab.arturbosch.detekt.api.Rule +import io.gitlab.arturbosch.detekt.api.Severity +import io.gitlab.arturbosch.detekt.api.internal.RequiresTypeResolution +import org.jetbrains.kotlin.backend.common.descriptors.isSuspend +import org.jetbrains.kotlin.descriptors.CallableDescriptor +import org.jetbrains.kotlin.descriptors.FunctionDescriptor +import org.jetbrains.kotlin.name.FqName +import org.jetbrains.kotlin.psi.KtCallExpression +import org.jetbrains.kotlin.psi.psiUtil.findDescendantOfType +import org.jetbrains.kotlin.psi.psiUtil.forEachDescendantOfType +import org.jetbrains.kotlin.psi.psiUtil.getParentOfTypesAndPredicate +import org.jetbrains.kotlin.resolve.BindingContext +import org.jetbrains.kotlin.resolve.calls.util.getResolvedCall +import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe + +/** + * Suspend functions should not be called inside `runCatching` as `runCatching` catches + * all the exception while for Coroutine cooperative cancellation to work, we have to + * never catch the `CancellationException` exception or rethrowing it again if caught + * + * See https://kotlinlang.org/docs/cancellation-and-timeouts.html#cancellation-is-cooperative + * + * + * suspend fun bar(delay: Long) { + * check(delay <= 1_000L) + * delay(delay) + * } + * + * suspend fun foo() { + * runCatching { + * bar(1_000L) + * } + * } + * + * + * + * suspend fun bar(delay: Long) { + * check(delay <= 1_000L) + * delay(delay) + * } + * + * suspend fun foo() { + * try { + * bar(1_000L) + * } catch (e: IllegalStateException) { + * // handle error + * } + * } + * + * // Alternate + * suspend fun foo() { + * bar(1_000L) + * } + * + * + */ +@RequiresTypeResolution +class SuspendFunInsideRunCatching(config: Config) : Rule(config) { + override val issue = Issue( + id = "SuspendFunInsideRunCatching", + severity = Severity.Minor, + description = "The `suspend` functions should be called inside `runCatching` block as it also swallows " + + "`CancellationException` which is important for cooperative cancellation." + + "You should either use specific `try-catch` only catching exception that you are expecting" + + " or rethrow the `CancellationException` if already caught", + debt = Debt.TEN_MINS + ) + + override fun visitCallExpression(expression: KtCallExpression) { + super.visitCallExpression(expression) + + val resultingDescriptor = expression.getResolvedCall(bindingContext)?.resultingDescriptor + resultingDescriptor ?: return + if (resultingDescriptor.fqNameSafe != RUN_CATCHING_FQ) return + + expression.forEachDescendantOfType { descendant -> + if (descendant.getResolvedCall(bindingContext)?.resultingDescriptor?.isSuspend == true && shouldReport( + resultingDescriptor, + descendant, + bindingContext, + ) + ) { + val message = + "The suspend function call ${descendant.text} is inside `runCatching`. You should either " + + "use specific `try-catch` only catching exception that you are expecting or rethrow the " + + "`CancellationException` if already caught" + report( + CodeSmell( + issue, + Entity.from(expression), + message + ) + ) + } + } + } + + private fun shouldReport( + runCatchingCallableDescriptor: CallableDescriptor, + callExpression: KtCallExpression, + bindingContext: BindingContext, + ): Boolean { + val firstNonInlineOrRunCatchingParent = + callExpression.getParentOfTypesAndPredicate(true, KtCallExpression::class.java) { parentCallExp -> + val parentCallFunctionDescriptor = + parentCallExp.getResolvedCall(bindingContext)?.resultingDescriptor as? FunctionDescriptor + parentCallFunctionDescriptor ?: return@getParentOfTypesAndPredicate false + + val isParentRunCatching = parentCallFunctionDescriptor.fqNameSafe == RUN_CATCHING_FQ + val isInline = parentCallFunctionDescriptor.isInline + val noInlineAndCrossInlineValueParametersIndex = + parentCallFunctionDescriptor.valueParameters.filter { valueParameterDescriptor -> + valueParameterDescriptor.isCrossinline || valueParameterDescriptor.isNoinline + }.map { + it.index + } + val callExpressionIndexInParentCall = parentCallExp.valueArguments.indexOfFirst { valueArgument -> + valueArgument?.findDescendantOfType { + it == callExpression + } != null + } + isParentRunCatching || + isInline.not() || + noInlineAndCrossInlineValueParametersIndex.contains(callExpressionIndexInParentCall) + } + return firstNonInlineOrRunCatchingParent.getResolvedCall(bindingContext)?.resultingDescriptor == + runCatchingCallableDescriptor + } + + companion object { + private val RUN_CATCHING_FQ = FqName("kotlin.runCatching") + } +} diff --git a/detekt-rules-coroutines/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunInsideRunCatchingSpec.kt b/detekt-rules-coroutines/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunInsideRunCatchingSpec.kt new file mode 100644 index 000000000000..aad89f0109b9 --- /dev/null +++ b/detekt-rules-coroutines/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunInsideRunCatchingSpec.kt @@ -0,0 +1,531 @@ +package io.gitlab.arturbosch.detekt.rules.coroutines + +import io.gitlab.arturbosch.detekt.api.Config +import io.gitlab.arturbosch.detekt.api.Finding +import io.gitlab.arturbosch.detekt.rules.KotlinCoreEnvironmentTest +import io.gitlab.arturbosch.detekt.test.compileAndLintWithContext +import org.assertj.core.api.Assertions.assertThat +import org.jetbrains.kotlin.cli.jvm.compiler.KotlinCoreEnvironment +import org.junit.jupiter.api.Test + +@KotlinCoreEnvironmentTest +class SuspendFunInsideRunCatchingSpec(private val env: KotlinCoreEnvironment) { + + private val subject = SuspendFunInsideRunCatching(Config.empty) + + @Test + fun `does report suspend function call in runCatching`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + runCatching { + delay(1000L) + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "delay(1000L)") + } + + @Test + fun `does report for in case of nested runCatching`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + runCatching { + delay(1000L) + runCatching { + delay(2000L) + } + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "delay(1000L)", "delay(2000L)") + } + + @Test + fun `does report for delay() in suspend functions`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + runCatching { + delay(1000L) + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "delay(1000L)") + } + + @Test + fun `does not report no suspending function is used inside runBlocking`() { + val code = """ + suspend fun foo() { + runCatching { + Thread.sleep(1000L) + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings) + } + + @Test + fun `does not report when try catch is used`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + try { + delay(1000L) + } catch (e: IllegalStateException) { + // handle error + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings) + } + + @Test + fun `does report when suspend fun is called inside inline function`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + runCatching { + listOf(1L, 2L, 3L).map { + delay(it) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "delay(it)") + } + + @Test + fun `does report when inside inline function with noinline and cross inline parameters in same order`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + import kotlinx.coroutines.runBlocking + + inline fun foo( + noinline noinlineBlock: suspend () -> Unit, + inlineBlock: () -> Unit, + crossinline crossinlineBlock: suspend () -> Unit, + ) = inlineBlock().toString() + MainScope().launch { + noinlineBlock() + }.toString() + runBlocking { + crossinlineBlock() + }.toString() + noinlineBlock() + + suspend fun bar() + { + runCatching { + foo( + noinlineBlock = { + delay(2000L) + }, + inlineBlock = { delay(1000L) }, + ) { + + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "delay(1000L)") + } + + // Failing + @Test + fun `does report when inside inline function with noinline and cross inline parameters not in same order`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + import kotlinx.coroutines.runBlocking + + inline fun foo( + noinline noinlineBlock: suspend () -> Unit, + inlineBlock: () -> Unit, + crossinline crossinlineBlock: suspend () -> Unit, + ) = inlineBlock().toString() + MainScope().launch { + noinlineBlock() + }.toString() + runBlocking { + crossinlineBlock() + }.toString() + noinlineBlock() + + suspend fun bar() + { + runCatching { + foo( + inlineBlock = { delay(1000L) }, + noinlineBlock = { + delay(2000L) + }, + ) { + + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "delay(1000L)") + } + + @Test + fun `does report when lambda in inline function is passed as crossinline`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + + inline fun foo(crossinline block: suspend () -> R) = MainScope().launch { + block() + } + suspend fun bar() { + runCatching { + foo { + delay(1000L) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings) + } + + @Test + fun `does not report when lambda in inline function is passed as noinline`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + + inline fun foo(noinline block: suspend () -> R) = MainScope().launch { + block() + } + suspend fun suspendFun() = delay(1000) + suspend fun bar() { + runCatching { + foo { + delay(1000L) + } + + val baz = suspend { + delay(1000L) + } + foo(baz) + foo(::suspendFun) + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings) + } + + @Test + fun `does report when suspend fun is called as extension function`() { + val code = """ + import kotlinx.coroutines.delay + + private suspend fun List.await() = delay(100L) + + suspend fun foo() { + runCatching { + listOf(1L, 2L, 3L).await() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).hasSize(1) + assertFindings(findings, "await()") + } + + // Failing + @Test + fun `does report when suspending iterator is used`() { + val code = """ + import kotlinx.coroutines.delay + + class SuspendingIterator { + suspend operator fun iterator(): Iterator = iterator { yield("value") } + } + + suspend fun bar() { + runCatching { + for (x in SuspendingIterator()) { + println(x) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "iterator.next()") + } + + @Test + fun `does report when suspend function is invoked`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + val suspendBlock = suspend { } + runCatching { + suspendBlock() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "suspendBlock()") + } + + @Test + fun `does report in case of suspendCancellableCoroutine`() { + val code = """ + import kotlinx.coroutines.delay + import kotlinx.coroutines.suspendCancellableCoroutine + + suspend fun foo() { + runCatching { + suspendCancellableCoroutine { + it.resume(Unit) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).hasSize(1) + assertFindings( + findings, + """ + suspendCancellableCoroutine { + it.resume(Unit) + } + """.trimIndent() + ) + } + + @Test + fun `does report in case suspend callable refernce is invoked`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun bar() = delay(1000) + suspend fun foo() { + runCatching { + ::bar.invoke() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "invoke()") + } + + @Test + fun `does report in case suspend local function is invoked`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + suspend fun localFun() = delay(1000L) + runCatching { + localFun() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).hasSize(1) + assertFindings(findings, "localFun()") + } + + @Test + fun `does report in suspend operator is invoked`() { + val code = """ + import kotlinx.coroutines.delay + + class C { + suspend operator fun invoke() = delay(1000L) + } + + suspend fun foo() { + runCatching { + C()() + C().invoke() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "C()()", "invoke()") + } + + @Test + fun `does not report coroutine is launched`() { + val code = """ + import kotlinx.coroutines.delay + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.launch + + suspend fun foo() { + runCatching { + MainScope().launch { + delay(1000L) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings) + } + + @Test + fun `does report when job is joined`() { + val code = """ + import kotlinx.coroutines.delay + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.launch + + suspend fun foo() { + runCatching { + MainScope().launch { + delay(1000L) + }.join() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "join()") + } + + @Test + fun `does not report async is used`() { + val code = """ + import kotlinx.coroutines.async + import kotlinx.coroutines.delay + import kotlinx.coroutines.MainScope + + suspend fun foo() { + runCatching { + MainScope().async { + delay(1000L) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings) + } + + @Test + fun `does report async is awaited`() { + val code = """ + import kotlinx.coroutines.async + import kotlinx.coroutines.delay + import kotlinx.coroutines.MainScope + + suspend fun foo() { + runCatching { + MainScope().async { + delay(1000L) + }.await() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "await()") + } + + @Test + fun `does not report when suspend block is passed to non inline function`() { + val code = """ + import kotlinx.coroutines.delay + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.launch + + fun bar(lambda: suspend () -> Unit) { + MainScope().launch { lambda() } + } + + suspend fun foo() { + runCatching { + bar { + delay(1000L) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings) + } + + @Test + fun `does not report when suspend fun is called inside runBlocking`() { + val code = """ + import kotlinx.coroutines.delay + import kotlinx.coroutines.runBlocking + + suspend fun foo() { + runCatching { + runBlocking { + delay(1000L) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings) + } + + @Test + fun `does report when suspend fun is called in string interpolation`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + runCatching { + val string = "${'$'}{delay(1000)}" + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindings(findings, "delay(1000)") + } + + private fun assertFindings(findings: List, vararg funCallExpression: String) { + assertThat(findings).hasSize(funCallExpression.size) + assertThat(findings.map { it.message }).containsExactlyInAnyOrder( + *funCallExpression.map { + "The suspend function call $it is inside `runCatching`. You should either " + + "use specific `try-catch` only catching exception that you are expecting or rethrow the " + + "`CancellationException` if already caught" + }.toTypedArray() + ) + } +}