Skip to content

Commit

Permalink
UnnecessaryNotNullCheck: fix false negative with smart casted argumen…
Browse files Browse the repository at this point in the history
…ts (#5380)

* UnnecessaryNotNullCheck: fix false negative with smart casted arguments

* Remove unnecessary BindingContext empty check

* Add a comment for KtExpression.getDataFlowAwareTypes
  • Loading branch information
t-kameyama committed Oct 9, 2022
1 parent cdb5928 commit 2e229ec
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 34 deletions.
2 changes: 2 additions & 0 deletions detekt-psi-utils/api/detekt-psi-utils.api
Expand Up @@ -141,5 +141,7 @@ public final class io/gitlab/arturbosch/detekt/rules/TraversingKt {

public final class io/gitlab/arturbosch/detekt/rules/TypeUtilsKt {
public static final fun fqNameOrNull (Lorg/jetbrains/kotlin/types/KotlinType;)Lorg/jetbrains/kotlin/name/FqName;
public static final fun getDataFlowAwareTypes (Lorg/jetbrains/kotlin/psi/KtExpression;Lorg/jetbrains/kotlin/resolve/BindingContext;Lorg/jetbrains/kotlin/config/LanguageVersionSettings;Lorg/jetbrains/kotlin/resolve/calls/smartcasts/DataFlowValueFactory;Lorg/jetbrains/kotlin/types/KotlinType;)Ljava/util/Set;
public static synthetic fun getDataFlowAwareTypes$default (Lorg/jetbrains/kotlin/psi/KtExpression;Lorg/jetbrains/kotlin/resolve/BindingContext;Lorg/jetbrains/kotlin/config/LanguageVersionSettings;Lorg/jetbrains/kotlin/resolve/calls/smartcasts/DataFlowValueFactory;Lorg/jetbrains/kotlin/types/KotlinType;ILjava/lang/Object;)Ljava/util/Set;
}

@@ -1,10 +1,54 @@
package io.gitlab.arturbosch.detekt.rules

import org.jetbrains.kotlin.config.LanguageVersionSettings
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.psi.KtExpression
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.calls.smartcasts.DataFlowValueFactory
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameOrNull
import org.jetbrains.kotlin.types.KotlinType
import org.jetbrains.kotlin.types.TypeUtils
import org.jetbrains.kotlin.util.containingNonLocalDeclaration

fun KotlinType.fqNameOrNull(): FqName? {
return TypeUtils.getClassDescriptor(this)?.fqNameOrNull()
}

/**
* Returns types considering data flow.
*
* For Example, for `s` in `print(s)` below, [BindingContext.getType] returns String?, but this function returns String.
*
* ```kotlin
* fun foo(s: String?) {
* if (s != null) {
* println(s) // s is String (smart cast from String?)
* }
* }
* ```
*/
@Suppress("ReturnCount")
fun KtExpression.getDataFlowAwareTypes(
bindingContext: BindingContext,
languageVersionSettings: LanguageVersionSettings,
dataFlowValueFactory: DataFlowValueFactory,
originalType: KotlinType? = bindingContext.getType(this),
): Set<KotlinType> {
require(bindingContext != BindingContext.EMPTY) { "The bindingContext must not be empty" }

if (originalType == null) return emptySet()

val dataFlowInfo = bindingContext[BindingContext.EXPRESSION_TYPE_INFO, this]
?.dataFlowInfo
?: return setOf(originalType)

val containingDeclaration = containingNonLocalDeclaration()
?.let { bindingContext[BindingContext.DECLARATION_TO_DESCRIPTOR, it] }
?: return setOf(originalType)

val dataFlowValue = dataFlowValueFactory
.createDataFlowValue(this, originalType, bindingContext, containingDeclaration)

return dataFlowInfo.getStableTypes(dataFlowValue, languageVersionSettings)
.ifEmpty { setOf(originalType) }
}
Expand Up @@ -8,6 +8,7 @@ import io.gitlab.arturbosch.detekt.api.Issue
import io.gitlab.arturbosch.detekt.api.Rule
import io.gitlab.arturbosch.detekt.api.Severity
import io.gitlab.arturbosch.detekt.api.internal.RequiresTypeResolution
import io.gitlab.arturbosch.detekt.rules.getDataFlowAwareTypes
import io.gitlab.arturbosch.detekt.rules.safeAs
import org.jetbrains.kotlin.com.intellij.psi.PsiElement
import org.jetbrains.kotlin.descriptors.CallableDescriptor
Expand All @@ -20,7 +21,6 @@ import org.jetbrains.kotlin.psi.KtSafeQualifiedExpression
import org.jetbrains.kotlin.psi.KtSimpleNameExpression
import org.jetbrains.kotlin.psi.KtStringTemplateEntry
import org.jetbrains.kotlin.psi.psiUtil.getStrictParentOfType
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.calls.util.getResolvedCall
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameOrNull
import org.jetbrains.kotlin.types.isFlexible
Expand Down Expand Up @@ -86,21 +86,21 @@ class NullableToStringCall(config: Config = Config.empty) : Rule(config) {
}

private fun KtExpression.isNullable(): Boolean {
val compilerResources = compilerResources ?: return false

val safeAccessOperation = safeAs<KtSafeQualifiedExpression>()?.operationTokenNode?.safeAs<PsiElement>()
if (safeAccessOperation != null) {
return bindingContext.diagnostics.forElement(safeAccessOperation).none {
it.factory == Errors.UNNECESSARY_SAFE_CALL
}
}
val compilerResources = compilerResources ?: return false
val descriptor = descriptor() ?: return false
val originalType = descriptor.returnType ?.takeIf { it.isNullable() && !it.isFlexible() } ?: return false
val dataFlowInfo =
bindingContext[BindingContext.EXPRESSION_TYPE_INFO, this]?.dataFlowInfo ?: return false
val dataFlowValue =
compilerResources.dataFlowValueFactory.createDataFlowValue(this, originalType, bindingContext, descriptor)
val dataFlowTypes =
dataFlowInfo.getStableTypes(dataFlowValue, compilerResources.languageVersionSettings)
val originalType = descriptor()?.returnType?.takeIf { it.isNullable() && !it.isFlexible() } ?: return false
val dataFlowTypes = getDataFlowAwareTypes(
bindingContext,
compilerResources.languageVersionSettings,
compilerResources.dataFlowValueFactory,
originalType
)
return dataFlowTypes.all { it.isNullable() }
}

Expand Down
Expand Up @@ -8,15 +8,11 @@ import io.gitlab.arturbosch.detekt.api.Issue
import io.gitlab.arturbosch.detekt.api.Rule
import io.gitlab.arturbosch.detekt.api.Severity
import io.gitlab.arturbosch.detekt.api.internal.RequiresTypeResolution
import io.gitlab.arturbosch.detekt.rules.getDataFlowAwareTypes
import io.gitlab.arturbosch.detekt.rules.isCalling
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.psi.KtCallExpression
import org.jetbrains.kotlin.psi.KtExpression
import org.jetbrains.kotlin.psi.psiUtil.getCallNameExpression
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.calls.util.getType
import org.jetbrains.kotlin.types.typeUtil.TypeNullability
import org.jetbrains.kotlin.types.typeUtil.nullability
import org.jetbrains.kotlin.types.isNullable

/**
* Reports unnecessary not-null checks with `requireNotNull` or `checkNotNull` that can be removed by the user.
Expand All @@ -41,30 +37,37 @@ class UnnecessaryNotNullCheck(config: Config = Config.empty) : Rule(config) {
Debt.FIVE_MINS,
)

@Suppress("ReturnCount")
override fun visitCallExpression(expression: KtCallExpression) {
super.visitCallExpression(expression)

if (bindingContext == BindingContext.EMPTY) return
val compilerResources = compilerResources ?: return

if (expression.isCalling(requireNotNullFunctionFqName, bindingContext) ||
expression.isCalling(checkNotNullFunctionFqName, bindingContext)
) {
val argument = expression.valueArguments[0].lastChild as KtExpression
if (argument.getType(bindingContext)?.nullability() == TypeNullability.NOT_NULL) {
val callName = expression.getCallNameExpression()?.text
report(
CodeSmell(
issue = issue,
entity = Entity.from(expression),
message = "Using `$callName` on non-null `${argument.text}` is unnecessary",
)
)
}
}
val callee = expression.calleeExpression ?: return
val argument = expression.valueArguments.firstOrNull()?.getArgumentExpression() ?: return

if (!expression.isCalling(notNullCheckFunctionFqNames, bindingContext)) return

val dataFlowAwareTypes = argument.getDataFlowAwareTypes(
bindingContext,
compilerResources.languageVersionSettings,
compilerResources.dataFlowValueFactory
)
if (dataFlowAwareTypes.all { it.isNullable() }) return

report(
CodeSmell(
issue = issue,
entity = Entity.from(expression),
message = "Using `${callee.text}` on non-null `${argument.text}` is unnecessary",
)
)
}

companion object {
private val requireNotNullFunctionFqName = FqName("kotlin.requireNotNull")
private val checkNotNullFunctionFqName = FqName("kotlin.checkNotNull")
private val notNullCheckFunctionFqNames = listOf(
FqName("kotlin.requireNotNull"),
FqName("kotlin.checkNotNull"),
)
}
}
Expand Up @@ -77,6 +77,32 @@ class UnnecessaryNotNullCheckSpec(private val env: KotlinCoreEnvironment) {
assertThat(findings).hasSize(1)
assertThat(findings).hasTextLocations(16 to 58)
}

@Test
fun shouldDetectAfterNullCheck() {
val code = """
fun foo(x: Int?) {
if (x != null) {
requireNotNull(x)
}
}
""".trimIndent()
val findings = subject.lintWithContext(env, code)
assertThat(findings).hasSize(1)
}

@Test
fun shouldDetectAfterTypeCheck() {
val code = """
fun bar(x: Any?) {
if (x is String) {
requireNotNull(x)
}
}
""".trimIndent()
val findings = subject.lintWithContext(env, code)
assertThat(findings).hasSize(1)
}
}

@Nested
Expand Down

0 comments on commit 2e229ec

Please sign in to comment.