From a3c2bbc6924dd0712adfaf31576634b636a1104d Mon Sep 17 00:00:00 2001 From: Atul Gupta Date: Wed, 5 Apr 2023 16:05:31 +0530 Subject: [PATCH] Add SuspendFunSwallowedCancellation rule (#5666) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: RĂ³bert Papp --- .../main/resources/default-detekt-config.yml | 2 + .../rules/coroutines/CoroutinesProvider.kt | 3 +- .../SuspendFunSwallowedCancellation.kt | 188 +++ .../SuspendFunSwallowedCancellationSpec.kt | 1280 +++++++++++++++++ 4 files changed, 1472 insertions(+), 1 deletion(-) create mode 100644 detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunSwallowedCancellation.kt create mode 100644 detekt-rules-coroutines/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunSwallowedCancellationSpec.kt diff --git a/detekt-core/src/main/resources/default-detekt-config.yml b/detekt-core/src/main/resources/default-detekt-config.yml index ce69fea8ae2..b9b98209a9d 100644 --- a/detekt-core/src/main/resources/default-detekt-config.yml +++ b/detekt-core/src/main/resources/default-detekt-config.yml @@ -190,6 +190,8 @@ coroutines: active: true SleepInsteadOfDelay: active: true + SuspendFunSwallowedCancellation: + active: false SuspendFunWithCoroutineScopeReceiver: active: false SuspendFunWithFlowReturnType: diff --git a/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/CoroutinesProvider.kt b/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/CoroutinesProvider.kt index fe9ade01a85..b4661c86d04 100644 --- a/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/CoroutinesProvider.kt +++ b/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/CoroutinesProvider.kt @@ -21,7 +21,8 @@ class CoroutinesProvider : DefaultRuleSetProvider { RedundantSuspendModifier(config), SleepInsteadOfDelay(config), SuspendFunWithFlowReturnType(config), - SuspendFunWithCoroutineScopeReceiver(config) + SuspendFunWithCoroutineScopeReceiver(config), + SuspendFunSwallowedCancellation(config), ) ) } diff --git a/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunSwallowedCancellation.kt b/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunSwallowedCancellation.kt new file mode 100644 index 00000000000..45265bf1b48 --- /dev/null +++ b/detekt-rules-coroutines/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunSwallowedCancellation.kt @@ -0,0 +1,188 @@ +package io.gitlab.arturbosch.detekt.rules.coroutines + +import io.gitlab.arturbosch.detekt.api.CodeSmell +import io.gitlab.arturbosch.detekt.api.Config +import io.gitlab.arturbosch.detekt.api.Debt +import io.gitlab.arturbosch.detekt.api.Entity +import io.gitlab.arturbosch.detekt.api.Issue +import io.gitlab.arturbosch.detekt.api.Rule +import io.gitlab.arturbosch.detekt.api.Severity +import io.gitlab.arturbosch.detekt.api.internal.RequiresTypeResolution +import org.jetbrains.kotlin.builtins.StandardNames.COROUTINES_PACKAGE_FQ_NAME +import org.jetbrains.kotlin.com.intellij.psi.PsiElement +import org.jetbrains.kotlin.descriptors.FunctionDescriptor +import org.jetbrains.kotlin.descriptors.PropertyDescriptor +import org.jetbrains.kotlin.name.FqName +import org.jetbrains.kotlin.name.Name +import org.jetbrains.kotlin.psi.KtCallExpression +import org.jetbrains.kotlin.psi.KtExpression +import org.jetbrains.kotlin.psi.KtForExpression +import org.jetbrains.kotlin.psi.KtNameReferenceExpression +import org.jetbrains.kotlin.psi.KtOperationExpression +import org.jetbrains.kotlin.psi.KtValueArgument +import org.jetbrains.kotlin.psi.psiUtil.anyDescendantOfType +import org.jetbrains.kotlin.psi.psiUtil.getParentOfType +import org.jetbrains.kotlin.resolve.BindingContext +import org.jetbrains.kotlin.resolve.calls.util.getParameterForArgument +import org.jetbrains.kotlin.resolve.calls.util.getResolvedCall +import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe +import org.jetbrains.kotlin.utils.addToStdlib.ifTrue + +/** + * `suspend` functions should not be called inside `runCatching`'s lambda block, because `runCatching` catches all the + * `Exception`s. For Coroutines to work in all cases, developers should make sure to propagate `CancellationException` + * exceptions. This means `CancellationException` should never be: + * * caught and swallowed (even if logged) + * * caught and propagated to external systems + * * caught and shown to the user + * + * they must always be rethrown in the same context. + * + * Using `runCatching` increases this risk of mis-handling cancellation. If you catch and don't rethrow all the + * `CancellationException`, your coroutines are not cancelled even if you cancel their `CoroutineScope`. + * + * This can very easily lead to: + * * unexpected crashes + * * extremely hard to diagnose bugs + * * memory leaks + * * performance issues + * * battery drain + * + * See reference, [Kotlin doc](https://kotlinlang.org/docs/cancellation-and-timeouts.html#cancellation-is-cooperative). + * + * If your project wants to use `runCatching` and `Result` objects, it is recommended to write a `coRunCatching` + * utility function which immediately re-throws `CancellationException`; and forbid `runCatching` and `suspend` + * combinations by activating this rule. + * + * + * @@Throws(IllegalStateException::class) + * suspend fun bar(delay: Long) { + * check(delay <= 1_000L) + * delay(delay) + * } + * + * suspend fun foo() { + * runCatching { + * bar(1_000L) + * } + * } + * + * + * + * @@Throws(IllegalStateException::class) + * suspend fun bar(delay: Long) { + * check(delay <= 1_000L) + * delay(delay) + * } + * + * suspend fun foo() { + * try { + * bar(1_000L) + * } catch (e: IllegalStateException) { + * // handle error + * } + * } + * + * // Alternate + * @@Throws(IllegalStateException::class) + * suspend fun foo() { + * bar(1_000L) + * } + * + * + */ +@RequiresTypeResolution +class SuspendFunSwallowedCancellation(config: Config) : Rule(config) { + override val issue = Issue( + id = javaClass.simpleName, + severity = Severity.Minor, + description = "`runCatching` does not propagate `CancellationException`, don't use it with `suspend` lambda " + + "blocks.", + debt = Debt.TEN_MINS + ) + + override fun visitCallExpression(expression: KtCallExpression) { + super.visitCallExpression(expression) + + val resultingDescriptor = expression.getResolvedCall(bindingContext)?.resultingDescriptor ?: return + + if (resultingDescriptor.fqNameSafe != RUN_CATCHING_FQ) return + + fun shouldTraverseInside(element: PsiElement): Boolean = + expression == element || shouldTraverseInside(element, bindingContext) + + expression.anyDescendantOfType(::shouldTraverseInside) { descendant -> + descendant.hasSuspendCalls() + }.ifTrue { report(expression) } + } + + @Suppress("ReturnCount") + private fun shouldTraverseInside(psiElement: PsiElement, bindingContext: BindingContext): Boolean { + return when (psiElement) { + is KtCallExpression -> { + val callableDescriptor = + (psiElement.getResolvedCall(bindingContext)?.resultingDescriptor as? FunctionDescriptor) + ?: return false + + callableDescriptor.fqNameSafe != RUN_CATCHING_FQ && callableDescriptor.isInline + } + is KtValueArgument -> { + val callExpression = psiElement.getParentOfType(true) + val valueParameterDescriptor = + callExpression?.getResolvedCall(bindingContext)?.getParameterForArgument(psiElement) ?: return false + + valueParameterDescriptor.isCrossinline.not() && valueParameterDescriptor.isNoinline.not() + } + else -> true + } + } + + @Suppress("ReturnCount") + private fun KtExpression.hasSuspendCalls(): Boolean { + return when (this) { + is KtForExpression -> { + val loopRangeIterator = bindingContext[BindingContext.LOOP_RANGE_ITERATOR_RESOLVED_CALL, loopRange] + val loopRangeHasNext = + bindingContext[BindingContext.LOOP_RANGE_HAS_NEXT_RESOLVED_CALL, loopRange] + val loopRangeNext = bindingContext[BindingContext.LOOP_RANGE_NEXT_RESOLVED_CALL, loopRange] + listOf(loopRangeIterator, loopRangeHasNext, loopRangeNext).any { + it?.resultingDescriptor?.isSuspend == true + } + } + is KtCallExpression, is KtOperationExpression -> { + val resolvedCall = getResolvedCall(bindingContext) ?: return false + (resolvedCall.resultingDescriptor as? FunctionDescriptor)?.isSuspend == true + } + is KtNameReferenceExpression -> { + val resolvedCall = getResolvedCall(bindingContext) ?: return false + val propertyDescriptor = resolvedCall.resultingDescriptor as? PropertyDescriptor + propertyDescriptor?.fqNameSafe == COROUTINE_CONTEXT_FQ_NAME + } + else -> { + false + } + } + } + + private fun report( + expression: KtCallExpression, + ) { + report( + CodeSmell( + issue, + Entity.from((expression.calleeExpression as? PsiElement) ?: expression), + "The `runCatching` has suspend call inside. You should either use specific `try-catch` " + + "only catching exception that you are expecting or rethrow the `CancellationException` if " + + "already caught." + ) + ) + } + + companion object { + private val RUN_CATCHING_FQ = FqName("kotlin.runCatching") + + // Based on code from Kotlin project: + // https://github.com/JetBrains/kotlin/commit/87bbac9d43e15557a2ff0dc3254fd41a9d5639e1 + private val COROUTINE_CONTEXT_FQ_NAME = COROUTINES_PACKAGE_FQ_NAME.child(Name.identifier("coroutineContext")) + } +} diff --git a/detekt-rules-coroutines/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunSwallowedCancellationSpec.kt b/detekt-rules-coroutines/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunSwallowedCancellationSpec.kt new file mode 100644 index 00000000000..708bc67f957 --- /dev/null +++ b/detekt-rules-coroutines/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/coroutines/SuspendFunSwallowedCancellationSpec.kt @@ -0,0 +1,1280 @@ +package io.gitlab.arturbosch.detekt.rules.coroutines + +import io.gitlab.arturbosch.detekt.api.Config +import io.gitlab.arturbosch.detekt.api.Finding +import io.gitlab.arturbosch.detekt.api.SourceLocation +import io.gitlab.arturbosch.detekt.rules.KotlinCoreEnvironmentTest +import io.gitlab.arturbosch.detekt.test.assertThat +import io.gitlab.arturbosch.detekt.test.compileAndLintWithContext +import org.jetbrains.kotlin.cli.jvm.compiler.KotlinCoreEnvironment +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test + +@KotlinCoreEnvironmentTest +class SuspendFunSwallowedCancellationSpec(private val env: KotlinCoreEnvironment) { + + private val subject = SuspendFunSwallowedCancellation(Config.empty) + + @Test + fun `does report suspend function call in runCatching`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + runCatching { + delay(1000L) + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(4, 5)), + listOf(SourceLocation(4, 16)) + ) + } + + @Test + fun `does report for in case of nested runCatching`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun bar() = delay(2000) + + suspend fun foo() { + runCatching { + delay(1000L) + runCatching { + bar() + } + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(6, 5), SourceLocation(8, 9)), + listOf(SourceLocation(6, 16), SourceLocation(8, 20)) + ) + } + + @Test + fun `does report for in case of nested runCatching with suspend fun call in inner runCatching`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + + suspend fun bar() = delay(2000) + + suspend fun foo() { + runCatching { + MainScope().launch { + delay(1000L) + runCatching { + bar() + } + } + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(11, 13)), + listOf(SourceLocation(11, 24)) + ) + } + + @Test + fun `does report for in case of nested runCatching with suspend fun call in outer runCatching`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + + suspend fun bar() = delay(2000) + + suspend fun foo() { + runCatching { + delay(1000L) + runCatching { + MainScope().launch { + bar() + } + } + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(8, 5)), + listOf(SourceLocation(8, 16)) + ) + } + + @Test + fun `does report for delay() in suspend functions`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + runCatching { + delay(1000L) + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(4, 5)), + listOf(SourceLocation(4, 16)) + ) + } + + @Test + fun `does report for coroutineContext in suspend functions`() { + val code = """ + import kotlinx.coroutines.delay + import kotlin.coroutines.coroutineContext + + suspend fun foo() { + runCatching { + coroutineContext + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(5, 5)), + listOf(SourceLocation(5, 16)) + ) + } + + @Test + fun `does not report no suspending function is used inside runBlocking`() { + val code = """ + @Suppress("RedundantSuspendModifier") + suspend fun foo() { + runCatching { + Thread.sleep(1000L) + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } + + @Test + fun `does report when _when_ is used in result`() { + val code = """ + import kotlinx.coroutines.delay + suspend fun bar() = delay(1000L) + suspend fun foo(): Result<*> { + val result = runCatching { bar() } + when(result.isSuccess) { + true -> TODO() + false -> TODO() + } + return result + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(4, 18)), + listOf(SourceLocation(4, 29)) + ) + } + + @Test + fun `does report when onSuccess is used in result`() { + val code = """ + import kotlinx.coroutines.delay + suspend fun bar() = delay(1000L) + suspend fun foo() { + runCatching { bar() }.onSuccess { + TODO() + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(4, 5)), + listOf(SourceLocation(4, 16)) + ) + } + + @Test + fun `does report when runCatching is used as function expression`() { + val code = """ + import kotlinx.coroutines.delay + suspend fun bar() = delay(1000L) + suspend fun foo() = runCatching { bar() } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(3, 21)), + listOf(SourceLocation(3, 32)) + ) + } + + @Test + fun `does not report when try catch is used`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + try { + delay(1000L) + } catch (e: IllegalStateException) { + // handle error + } + } + """.trimIndent() + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } + + @Nested + inner class WithLambda { + @Test + fun `does report when suspend fun is called inside inline function`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + runCatching { + listOf(1L, 2L, 3L).map { + delay(it) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(4, 5)), + listOf(SourceLocation(4, 16)) + ) + } + + @Test + fun `does not report when lambda in non suspending inline function is passed as crossinline`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + + inline fun foo(crossinline block: suspend () -> R) = MainScope().launch { + block() + } + suspend fun bar() { + runCatching { + foo { + delay(1000L) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } + + @Test + fun `does report when lambda in suspend inline function is passed as crossinline`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + + suspend inline fun foo(crossinline block: suspend () -> R) = block() + suspend fun bar() { + runCatching { + foo { + delay(1000L) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(7, 5)), + listOf(SourceLocation(7, 16)) + ) + } + + @Test + fun `does not report when lambda parameter chain has noinline function call`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + + inline fun inline(block: () -> R) = block() + fun noInline(block: suspend () -> R) = MainScope().launch { block() } + + suspend fun bar() { + runCatching { + inline { + noInline { + inline { + delay(1000) + } + } + + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } + + @Test + fun `does report when lambda parameter chain has all inlined function call`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + + inline fun inline(block: () -> R) = block() + + suspend fun bar() { + runCatching { + inline { + inline { + delay(1000) + } + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(8, 5)), + listOf(SourceLocation(8, 16)) + ) + } + + @Test + fun `does not report when lambda in non suspending inline function is passed as noinline`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + + inline fun foo(noinline block: suspend () -> R) = MainScope().launch { + block() + } + suspend fun suspendFun() = delay(1000) + suspend fun bar() { + runCatching { + foo { + delay(1000L) + } + + val baz = suspend { + delay(1000L) + } + foo(baz) + foo(::suspendFun) + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } + + @Test + fun `does report when lambda in suspend inline function is passed as noinline`() { + val code = """ + import kotlinx.coroutines.delay + + suspend inline fun foo(noinline block: suspend () -> R) = block() + suspend fun bar() { + runCatching { + foo { + delay(1000L) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(5, 5)), + listOf(SourceLocation(5, 16)) + ) + } + + @Test + fun `does report when suspend fun is called as extension function`() { + val code = """ + import kotlinx.coroutines.delay + @Suppress("RedundantSuspendModifier") + suspend fun List.await() = delay(this.size.toLong()) + + suspend fun foo() { + runCatching { + listOf(1L, 2L, 3L).await() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).hasSize(1) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(6, 5)), + listOf(SourceLocation(6, 16)) + ) + } + + @Test + fun `does report when inside inline function with noinline and cross inline parameters in same order`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + import kotlinx.coroutines.runBlocking + + inline fun foo( + noinline noinlineBlock: suspend () -> Unit, + inlineBlock: () -> Unit, + crossinline crossinlineBlock: suspend () -> Unit, + ) = inlineBlock().toString() + MainScope().launch { + noinlineBlock() + } + runBlocking { + crossinlineBlock() + }.toString() + + suspend fun bar() { + runCatching { + foo( + noinlineBlock = { + delay(2000L) + }, + inlineBlock = { delay(1000L) }, + ) { + + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(17, 5)), + listOf(SourceLocation(17, 16)) + ) + } + + @Test + fun `does report when inside inline function with noinline and cross inline parameters not in same order`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + import kotlinx.coroutines.runBlocking + + inline fun foo( + noinline noinlineBlock: suspend () -> Unit, + inlineBlock: () -> Unit, + crossinline crossinlineBlock: suspend () -> Unit, + ) = inlineBlock().toString() + MainScope().launch { + noinlineBlock() + } + runBlocking { + crossinlineBlock() + }.toString() + + suspend fun bar() { + runCatching { + foo( + inlineBlock = { delay(1000L) }, + noinlineBlock = { + delay(2000L) + }, + ) { + + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(17, 5)), + listOf(SourceLocation(17, 16)) + ) + } + } + + @Test + fun `does report when suspend fun in for subject is used`() { + val code = """ + @Suppress("RedundantSuspendModifier") + suspend fun bar() = 10 + + suspend fun foo() { + runCatching { + for (i in 1..bar()) { + println(i) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(5, 5)), + listOf(SourceLocation(5, 16)) + ) + } + + @Test + fun `does report in case of suspendCancellableCoroutine`() { + val code = """ + import kotlinx.coroutines.delay + import kotlinx.coroutines.suspendCancellableCoroutine + import kotlin.coroutines.resume + + suspend fun foo() { + runCatching { + suspendCancellableCoroutine { + it.resume(Unit) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(6, 5)), + listOf(SourceLocation(6, 16)) + ) + } + + @Test + fun `does report in case suspend callable reference is invoked`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun bar() = delay(1000) + suspend fun foo() { + runCatching { + ::bar.invoke() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(5, 5)), + listOf(SourceLocation(5, 16)) + ) + } + + @Test + fun `does report in case suspend local function is invoked`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + suspend fun localFun() = delay(1000L) + runCatching { + localFun() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(5, 5)), + listOf(SourceLocation(5, 16)) + ) + } + + @Test + fun `does not report coroutine is launched`() { + val code = """ + import kotlinx.coroutines.delay + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.launch + + suspend fun foo() { + runCatching { + MainScope().launch { + delay(1000L) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } + + @Test + fun `does report when job is joined`() { + val code = """ + import kotlinx.coroutines.delay + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.launch + + suspend fun foo() { + runCatching { + MainScope().launch { + delay(1000L) + }.join() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(6, 5)), + listOf(SourceLocation(6, 16)) + ) + } + + @Test + fun `does not report async is used`() { + val code = """ + import kotlinx.coroutines.async + import kotlinx.coroutines.delay + import kotlinx.coroutines.MainScope + + suspend fun foo() { + runCatching { + MainScope().async { + delay(1000L) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } + + @Test + fun `does report async is awaited`() { + val code = """ + import kotlinx.coroutines.async + import kotlinx.coroutines.delay + import kotlinx.coroutines.MainScope + + suspend fun foo() { + runCatching { + @Suppress("RedundantAsync") + MainScope().async { + delay(1000L) + }.await() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(6, 5)), + listOf(SourceLocation(6, 16)) + ) + } + + @Test + fun `does not report when suspend fun is called inside runBlocking`() { + val code = """ + import kotlinx.coroutines.delay + import kotlinx.coroutines.runBlocking + + suspend fun foo() { + runCatching { + runBlocking { + delay(1000L) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } + + @Test + fun `does report when suspend fun is called in string interpolation`() { + val code = """ + import kotlinx.coroutines.delay + @Suppress("RedundantSuspendModifier") + suspend fun foo() { + runCatching { + val string = "${'$'}{delay(1000)}" + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(4, 5)), + listOf(SourceLocation(4, 16)) + ) + } + + @Nested + inner class WithOperators { + @Test + fun `does report in case of suspend invoked operator`() { + val code = """ + import kotlinx.coroutines.delay + + class C { + suspend operator fun invoke() = delay(1000L) + } + + suspend fun foo() { + runCatching { + C()() + C().invoke() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(8, 5)), + listOf(SourceLocation(8, 16)) + ) + } + + @Test + fun `does report in case of suspend inc and dec operator`() { + val code = """ + @Suppress("RedundantSuspendModifier") + class OperatorClass { + suspend operator fun inc(): OperatorClass = OperatorClass() + suspend operator fun dec(): OperatorClass = OperatorClass() + } + + suspend fun foo() { + runCatching { + var operatorClass = OperatorClass() + operatorClass++ + } + runCatching { + var operatorClass = OperatorClass() + operatorClass-- + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(8, 5), SourceLocation(12, 5)), + listOf(SourceLocation(8, 16), SourceLocation(12, 16)) + ) + } + + @Test + fun `does report in case of suspend inc and dec operators called as function`() { + val code = """ + @Suppress("RedundantSuspendModifier") + class OperatorClass { + suspend operator fun inc(): OperatorClass = OperatorClass() + suspend operator fun dec(): OperatorClass = OperatorClass() + } + + suspend fun foo() { + runCatching { + val operatorClass = OperatorClass() + operatorClass.inc() + } + runCatching { + val operatorClass = OperatorClass() + operatorClass.dec() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(8, 5), SourceLocation(12, 5)), + listOf(SourceLocation(8, 16), SourceLocation(12, 16)) + ) + } + + @Test + fun `does report in case of suspend not operator`() { + val code = """ + class OperatorClass { + @Suppress("RedundantSuspendModifier") + suspend operator fun not() = false + } + + suspend fun foo() { + runCatching { + val operatorClass = OperatorClass() + println(!operatorClass) + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(7, 5)), + listOf(SourceLocation(7, 16)) + ) + } + + @Test + fun `does report in case of suspend unaryPlus operator`() { + val code = """ + class OperatorClass { + @Suppress("RedundantSuspendModifier") + suspend operator fun unaryPlus() = OperatorClass() + } + + suspend fun foo() { + runCatching { + val operatorClass = OperatorClass() + println(+operatorClass) + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(7, 5)), + listOf(SourceLocation(7, 16)) + ) + } + + @Test + fun `does report in case of suspend plus operator`() { + val code = """ + class OperatorClass { + @Suppress("RedundantSuspendModifier") + suspend operator fun plus(test: OperatorClass) = test + } + + suspend fun foo() { + runCatching { + val operatorClass1 = OperatorClass() + val operatorClass2 = OperatorClass() + println(operatorClass1 + operatorClass2) + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(7, 5)), + listOf(SourceLocation(7, 16)) + ) + } + + @Test + fun `does report in case of suspend div operator`() { + val code = """ + class OperatorClass { + @Suppress("RedundantSuspendModifier") + suspend operator fun div(value: Int) = OperatorClass() + } + + suspend fun foo() { + runCatching { + val operatorClass = OperatorClass() + println(operatorClass / 2) + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(7, 5)), + listOf(SourceLocation(7, 16)) + ) + } + + @Test + fun `does report in case of suspend compareTo operator`() { + val code = """ + class OperatorClass { + @Suppress("RedundantSuspendModifier") + suspend operator fun compareTo(operatorClass: OperatorClass) = 0 + } + + suspend fun foo() { + runCatching { + val operatorClass1 = OperatorClass() + val operatorClass2 = OperatorClass() + println(operatorClass1 < operatorClass2) + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(7, 5)), + listOf(SourceLocation(7, 16)) + ) + } + + @Test + fun `does report in case of suspend plusAssign operator`() { + val code = """ + class OperatorClass { + @Suppress("RedundantSuspendModifier") + suspend operator fun plusAssign(operatorClass: OperatorClass) { } + } + + suspend fun foo() { + runCatching { + val operatorClass1 = OperatorClass() + val operatorClass2 = OperatorClass() + operatorClass1 += (operatorClass2) + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(7, 5)), + listOf(SourceLocation(7, 16)) + ) + } + + @Test + fun `does report in case of suspend plus called as plusAssign operator`() { + val code = """ + class C + @Suppress("RedundantSuspendModifier") + suspend operator fun C.plus(i: Int): C = TODO() + + suspend fun f() { + runCatching { + var x = C() + x += 1 + } + + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(6, 5)), + listOf(SourceLocation(6, 16)) + ) + } + + @Test + fun `does report in case of suspend divAssign operator`() { + val code = """ + import kotlinx.coroutines.delay + class OperatorClass { + suspend operator fun divAssign(operatorClass: OperatorClass) { + delay(1000) + } + } + + suspend fun foo() { + runCatching { + val operatorClass1 = OperatorClass() + val operatorClass2 = OperatorClass() + operatorClass1 /= (operatorClass2) + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(9, 5)), + listOf(SourceLocation(9, 16)) + ) + } + + @Test + fun `does report in case of suspend rangeTo operator`() { + val code = """ + class OperatorClass { + @Suppress("RedundantSuspendModifier") + suspend operator fun rangeTo(operatorClass: OperatorClass) = OperatorClass() + } + + suspend fun foo() { + runCatching { + val operatorClass1 = OperatorClass() + val operatorClass2 = OperatorClass() + println(operatorClass1..operatorClass2) + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(7, 5)), + listOf(SourceLocation(7, 16)) + ) + } + + @Test + fun `does report in case of suspend times operator`() { + val code = """ + class OperatorClass { + @Suppress("RedundantSuspendModifier") + suspend operator fun times(operatorClass: OperatorClass) = OperatorClass() + } + + suspend fun foo() { + runCatching { + val operatorClass1 = OperatorClass() + val operatorClass2 = OperatorClass() + println(operatorClass1 * operatorClass2) + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(7, 5)), + listOf(SourceLocation(7, 16)) + ) + } + + @Nested + inner class WithSuspendingIterator { + + @Test + fun `does report when suspending iterator is used`() { + val code = """ + import kotlinx.coroutines.delay + + class SuspendingIterator { + suspend operator fun iterator(): Iterator = iterator { yield("value") } + } + + suspend fun bar() { + runCatching { + for (x in SuspendingIterator()) { + println(x) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(8, 5)), + listOf(SourceLocation(8, 16)) + ) + } + + @Test + fun `does report when nested suspending iterator is used`() { + val code = """ + import kotlinx.coroutines.delay + + class SuspendingIterator { + suspend operator fun iterator(): Iterator = iterator { yield("value") } + } + + suspend fun bar() { + runCatching { + for (x in SuspendingIterator()) { + for (y in SuspendingIterator()) { + println(x + y) + } + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(8, 5)), + listOf(SourceLocation(8, 16)) + ) + } + + @Test + fun `does report when nested iterator with one suspending iterator is used`() { + val code = """ + import kotlinx.coroutines.delay + + class SuspendingIterator { + suspend operator fun iterator(): Iterator = iterator { yield("value") } + } + + suspend fun bar() { + runCatching { + for (x in 1..10) { + for (y in SuspendingIterator()) { + println(x.toString() + y) + } + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(8, 5)), + listOf(SourceLocation(8, 16)) + ) + } + + @Test + fun `does report when suspending iterator is used withing inlined block`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + + class SuspendingIterator { + suspend operator fun iterator(): Iterator = iterator { yield("value") } + } + + inline fun foo(lambda: () -> Unit) { + lambda() + } + + suspend fun bar() { + runCatching { + foo { + for (x in SuspendingIterator()) { + println(x) + } + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(14, 5)), + listOf(SourceLocation(14, 16)) + ) + } + + @Test + fun `does not report when suspending iterator is used withing non inlined block`() { + val code = """ + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.delay + import kotlinx.coroutines.launch + + class SuspendingIterator { + suspend operator fun iterator(): Iterator = iterator { yield("value") } + } + + suspend fun bar() { + runCatching { + MainScope().launch { + for (x in SuspendingIterator()) { + println(x) + } + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } + + @Test + fun `does report when suspend function is invoked`() { + val code = """ + import kotlinx.coroutines.delay + + suspend fun foo() { + val suspendBlock = suspend { } + runCatching { + suspendBlock() + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertFindingsForSuspendCall( + findings, + listOf(SourceLocation(5, 5)), + listOf(SourceLocation(5, 16)) + ) + } + + @Test + fun `does not report when suspend block is passed to non inline function`() { + val code = """ + import kotlinx.coroutines.delay + import kotlinx.coroutines.MainScope + import kotlinx.coroutines.launch + + fun bar(lambda: suspend () -> Unit) { + MainScope().launch { lambda() } + } + + suspend fun foo() { + runCatching { + bar { + delay(1000L) + } + } + } + """.trimIndent() + + val findings = subject.compileAndLintWithContext(env, code) + assertThat(findings).isEmpty() + } + } + } + + private fun assertFindingsForSuspendCall( + findings: List, + listOfStartLocation: List, + listOfEndLocation: List, + ) { + check(listOfEndLocation.size == listOfStartLocation.size) + assertThat(findings).hasSize(listOfStartLocation.size) + assertThat(findings).hasStartSourceLocations(*listOfStartLocation.toTypedArray()) + assertThat(findings).hasEndSourceLocations(*listOfEndLocation.toTypedArray()) + } +}