Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OptionalUnit] Allow a function to declare a Unit return type when it uses a generic function initializer #4371

Merged
Expand Up @@ -9,17 +9,21 @@ import io.gitlab.arturbosch.detekt.api.Rule
import io.gitlab.arturbosch.detekt.api.Severity
import io.gitlab.arturbosch.detekt.rules.isOverride
import org.jetbrains.kotlin.cfg.WhenChecker
import org.jetbrains.kotlin.js.translate.callTranslator.getReturnType
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea why this is in a JavaScript package, but it appears to work in JVM code.

import org.jetbrains.kotlin.psi.KtBlockExpression
import org.jetbrains.kotlin.psi.KtExpression
import org.jetbrains.kotlin.psi.KtIfExpression
import org.jetbrains.kotlin.psi.KtNameReferenceExpression
import org.jetbrains.kotlin.psi.KtNamedFunction
import org.jetbrains.kotlin.psi.KtTypeReference
import org.jetbrains.kotlin.psi.KtWhenExpression
import org.jetbrains.kotlin.psi.psiUtil.siblings
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.bindingContextUtil.isUsedAsExpression
import org.jetbrains.kotlin.resolve.calls.callUtil.getResolvedCall
import org.jetbrains.kotlin.resolve.calls.callUtil.getType
import org.jetbrains.kotlin.types.typeUtil.isNothing
import org.jetbrains.kotlin.types.typeUtil.isTypeParameter
import org.jetbrains.kotlin.types.typeUtil.isUnit
import org.jetbrains.kotlin.utils.addToStdlib.firstIsInstanceOrNull

Expand Down Expand Up @@ -56,8 +60,9 @@ class OptionalUnit(config: Config = Config.empty) : Rule(config) {
)

override fun visitNamedFunction(function: KtNamedFunction) {
if (function.hasDeclaredReturnType()) {
checkFunctionWithExplicitReturnType(function)
val typeReference = function.typeReference
if (typeReference != null) {
checkFunctionWithExplicitReturnType(function, typeReference)
} else if (!function.isOverride()) {
checkFunctionWithInferredReturnType(function)
}
Expand Down Expand Up @@ -105,11 +110,11 @@ class OptionalUnit(config: Config = Config.empty) : Rule(config) {
}
}

private fun checkFunctionWithExplicitReturnType(function: KtNamedFunction) {
val typeReference = function.typeReference
val typeElementText = typeReference?.typeElement?.text
private fun checkFunctionWithExplicitReturnType(function: KtNamedFunction, typeReference: KtTypeReference) {
val typeElementText = typeReference.typeElement?.text
if (typeElementText == UNIT) {
if (function.initializer.isNothingType()) return
val initializer = function.initializer
if (initializer?.isGenericOrNothingType() == true) return
report(CodeSmell(issue, Entity.from(typeReference), createMessage(function)))
}
}
Expand All @@ -124,8 +129,14 @@ class OptionalUnit(config: Config = Config.empty) : Rule(config) {
private fun createMessage(function: KtNamedFunction) = "The function ${function.name} " +
"defines a return type of Unit. This is unnecessary and can safely be removed."

private fun KtExpression?.isNothingType() =
bindingContext != BindingContext.EMPTY && this?.getType(bindingContext)?.isNothing() == true
private fun KtExpression.isGenericOrNothingType(): Boolean {
if (bindingContext == BindingContext.EMPTY) return false
val isGenericType = getResolvedCall(bindingContext)?.getReturnType()?.isTypeParameter() == true
val isNothingType = getType(bindingContext)?.isNothing() == true
// Either the function initializer returns Nothing or it is a generic function
// into which Unit is passed, but not both.
return (isGenericType && !isNothingType) || (isNothingType && !isGenericType)
}

companion object {
private const val UNIT = "Unit"
Expand Down
Expand Up @@ -18,6 +18,21 @@ class OptionalUnitSpec : Spek({
val env: KotlinCoreEnvironment by memoized()

describe("OptionalUnit rule") {
it("should report when a function has an explicit Unit return type with context") {
val code = """
fun foo(): Unit { }
""".trimIndent()
val findings = subject.compileAndLintWithContext(env, code)
assertThat(findings).hasSize(1)
}

it("should not report when a function has a non-unit body expression") {
val code = """
fun foo() = String
""".trimIndent()
val findings = subject.compileAndLintWithContext(env, code)
assertThat(findings).isEmpty()
}

context("several functions which return Unit") {

Expand Down Expand Up @@ -296,14 +311,72 @@ class OptionalUnitSpec : Spek({
val findings = subject.compileAndLintWithContext(env, code)
assertThat(findings).hasSize(1)
}

it("another object is used as the last expression") {
val code = """
fun foo() {
String
}
""".trimIndent()
val findings = subject.compileAndLintWithContext(env, code)
assertThat(findings).isEmpty()
}
}

it("should not report when function initializer is Nothing") {
val code = """
context("function initializers") {
it("should not report when function initializer is Nothing") {
val code = """
fun test(): Unit = throw UnsupportedOperationException()
"""
val findings = subject.compileAndLintWithContext(env, code)
assertThat(findings).isEmpty()
val findings = subject.compileAndLintWithContext(env, code)
assertThat(findings).isEmpty()
}

it("should not report when the function initializer requires a type") {
val code = """
fun <T> foo(block: (List<T>) -> Unit): T {
val list = listOf<T>()
block(list)
return list.first()
}

fun doFoo(): Unit = foo {}
""".trimIndent()
val findings = subject.compileAndLintWithContext(env, code)
assertThat(findings).isEmpty()
}

it("should report on function initializers when there is no context") {
val code = """
fun test(): Unit = throw UnsupportedOperationException()
"""
val findings = subject.compileAndLint(code)
assertThat(findings).hasSize(1)
}

it("should report when the function initializer takes in the type Nothing") {
val code = """
fun <T> foo(block: (List<T>) -> Unit): T {
val list = listOf<T>()
block(list)
return list.first()
}

fun doFoo(): Unit = foo<Nothing> {}
""".trimIndent()
val findings = subject.compileAndLintWithContext(env, code)
assertThat(findings).hasSize(1)
}

it("should report when the function initializer does not provide a different type") {
val code = """
fun foo() {}

fun doFoo(): Unit = foo()
""".trimIndent()
val findings = subject.compileAndLintWithContext(env, code)
assertThat(findings).hasSize(1)
}
}
}
})