Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Support analyzing context receivers #1475

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -74,4 +74,9 @@ interface KSClassDeclaration : KSDeclaration, KSDeclarationContainer {
* @return A type with all type parameters applied with star projection.
*/
fun asStarProjectedType(): KSType

/**
* The class may have context receivers at the class level, which affect all constructors
*/
val contextReceivers: List<KSTypeReference>
}
Expand Up @@ -53,6 +53,13 @@ interface KSFunction {
*/
val extensionReceiverType: KSType?

/**
* The context receiver types of the function
*
* @see KSFunctionDeclaration.contextReceivers
*/
val contextReceiverTypes: List<KSType>

/**
* True if the compiler couldn't resolve the function.
*/
Expand Down
Expand Up @@ -44,6 +44,11 @@ interface KSFunctionDeclaration : KSDeclaration, KSDeclarationContainer {
*/
val extensionReceiver: KSTypeReference?

/**
* Context receivers of this function
*/
val contextReceivers: List<KSTypeReference>

/**
* Return type of this function.
* Can be null if an error occurred during resolution.
Expand Down
Expand Up @@ -72,6 +72,10 @@ class KSClassDeclarationDescriptorImpl private constructor(val descriptor: Class

override fun getAllProperties(): Sequence<KSPropertyDeclaration> = descriptor.getAllProperties()

override val contextReceivers: List<KSTypeReference> by lazy {
descriptor.getAllContextReceivers(this)
}

override val primaryConstructor: KSFunctionDeclaration? by lazy {
descriptor.unsubstitutedPrimaryConstructor?.let { KSFunctionDeclarationDescriptorImpl.getCached(it) }
}
Expand Down Expand Up @@ -190,6 +194,12 @@ internal fun ClassDescriptor.getAllProperties(): Sequence<KSPropertyDeclaration>
}
}

internal fun ClassDescriptor.getAllContextReceivers(node: KSNode): List<KSTypeReference> {
return contextReceivers.map {
KSTypeReferenceDescriptorImpl.getCached(it.type, origin, node)
}
}

internal fun ClassDescriptor.sealedSubclassesSequence(): Sequence<KSClassDeclaration> {
// TODO record incremental subclass lookups in Kotlin 1.5.x?
return sealedSubclasses
Expand Down
Expand Up @@ -58,6 +58,12 @@ class KSFunctionDeclarationDescriptorImpl private constructor(val descriptor: Fu
}
}

override val contextReceivers: List<KSTypeReference> by lazy {
descriptor.contextReceiverParameters.map {
KSTypeReferenceDescriptorImpl.getCached(it.type, origin, this)
}
}

override val functionKind: FunctionKind by lazy {

when {
Expand Down
Expand Up @@ -105,6 +105,8 @@ class KSClassDeclarationJavaEnumEntryImpl private constructor(val psi: PsiEnumCo
return getKSTypeCached(descriptor!!.defaultType)
}

override val contextReceivers: List<KSTypeReference> = emptyList()

override fun <D, R> accept(visitor: KSVisitor<D, R>, data: D): R {
return visitor.visitClassDeclaration(this, data)
}
Expand Down
Expand Up @@ -166,6 +166,8 @@ class KSClassDeclarationJavaImpl private constructor(val psi: PsiClass) :
} ?: KSErrorType
}

override val contextReceivers: List<KSTypeReference> = emptyList()

override fun <D, R> accept(visitor: KSVisitor<D, R>, data: D): R {
return visitor.visitClassDeclaration(this, data)
}
Expand Down
Expand Up @@ -61,6 +61,8 @@ class KSFunctionDeclarationJavaImpl private constructor(val psi: PsiMethod) :

override val extensionReceiver: KSTypeReference? = null

override val contextReceivers: List<KSTypeReference> = emptyList()

override val functionKind: FunctionKind = when {
psi.hasModifier(JvmModifier.STATIC) -> FunctionKind.STATIC
else -> FunctionKind.MEMBER
Expand Down
Expand Up @@ -25,6 +25,7 @@ import com.google.devtools.ksp.processing.impl.KSTypeReferenceSyntheticImpl
import com.google.devtools.ksp.processing.impl.ResolverImpl
import com.google.devtools.ksp.symbol.*
import com.google.devtools.ksp.symbol.impl.*
import com.google.devtools.ksp.symbol.impl.binary.getAllContextReceivers
import com.google.devtools.ksp.symbol.impl.binary.getAllFunctions
import com.google.devtools.ksp.symbol.impl.binary.getAllProperties
import com.google.devtools.ksp.symbol.impl.binary.sealedSubclassesSequence
Expand Down Expand Up @@ -132,6 +133,10 @@ class KSClassDeclarationImpl private constructor(val ktClassOrObject: KtClassOrO
return getKSTypeCached(descriptor.defaultType.replaceArgumentsWithStarProjections())
}

override val contextReceivers: List<KSTypeReference> by lazy {
descriptor.getAllContextReceivers(this)
}

override fun <D, R> accept(visitor: KSVisitor<D, R>, data: D): R {
return visitor.visitClassDeclaration(this, data)
}
Expand Down
Expand Up @@ -25,6 +25,7 @@ import com.google.devtools.ksp.processing.impl.KSPCompilationError
import com.google.devtools.ksp.processing.impl.ResolverImpl
import com.google.devtools.ksp.symbol.*
import com.google.devtools.ksp.symbol.impl.*
import com.google.devtools.ksp.symbol.impl.binary.KSTypeReferenceDescriptorImpl
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.KtConstructor
Expand Down Expand Up @@ -78,6 +79,12 @@ class KSFunctionDeclarationImpl private constructor(val ktFunction: KtFunction)
}
}

override val contextReceivers: List<KSTypeReference> by lazy {
ktFunction.contextReceivers.map {
KSTypeReferenceImpl.getCached(it.typeReference()!!)
}
}

override val functionKind: FunctionKind by lazy {
if (parentDeclaration == null) {
FunctionKind.TOP_LEVEL
Expand Down
Expand Up @@ -43,6 +43,11 @@ class KSFunctionErrorImpl(
KSErrorType
}

override val contextReceiverTypes: List<KSType>
get() = declaration.contextReceivers.let {
listOf(KSErrorType)
}

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
Expand Down
Expand Up @@ -54,6 +54,12 @@ class KSFunctionImpl(val descriptor: CallableDescriptor) : KSFunction {
descriptor.extensionReceiverParameter?.type?.let(::getKSTypeCached)
}

override val contextReceiverTypes: List<KSType> by lazy(LazyThreadSafetyMode.PUBLICATION) {
descriptor.contextReceiverParameters.map {
getKSTypeCached(it.type)
}
}

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
Expand All @@ -64,6 +70,7 @@ class KSFunctionImpl(val descriptor: CallableDescriptor) : KSFunction {
if (parameterTypes != other.parameterTypes) return false
if (typeParameters != other.typeParameters) return false
if (extensionReceiverType != other.extensionReceiverType) return false
if (contextReceiverTypes != other.contextReceiverTypes) return false

return true
}
Expand Down
Expand Up @@ -39,6 +39,8 @@ class KSConstructorSyntheticImpl private constructor(val ksClassDeclaration: KSC

override val extensionReceiver: KSTypeReference? = null

override val contextReceivers: List<KSTypeReference> = emptyList()

override val parameters: List<KSValueParameter> = emptyList()

override val functionKind: FunctionKind = FunctionKind.MEMBER
Expand Down
Expand Up @@ -64,6 +64,8 @@ object KSErrorTypeClassDeclaration : KSClassDeclaration {
return ResolverImpl.instance!!.builtIns.nothingType
}

override val contextReceivers: List<KSTypeReference> = emptyList()

override fun asType(typeArguments: List<KSTypeArgument>): KSType {
return ResolverImpl.instance!!.builtIns.nothingType
}
Expand Down
Expand Up @@ -62,6 +62,14 @@ class KSFunctionDeclarationImpl private constructor(internal val ktFunctionSymbo
}
}

override val contextReceivers: List<KSTypeReference> by lazy {
analyze {
ktFunctionSymbol.contextReceivers.map {
KSTypeReferenceImpl.getCached(it.type, this@KSFunctionDeclarationImpl)
}
}
}

override val returnType: KSTypeReference? by lazy {
analyze {
// Constructors
Expand Down
Expand Up @@ -224,7 +224,12 @@ class AsMemberOfProcessor : AbstractTestProcessor() {
} else {
""
}
return "$receiverSignature$paramTypesSignature($params) -> $returnType"
val contextSignature = if (contextReceiverTypes.isNotEmpty()) {
contextReceiverTypes.map { it.toSignature() }.joinToString(prefix = "context(", postfix = ") ", separator = ",")
} else {
""
}
return "$contextSignature$receiverSignature$paramTypesSignature($params) -> $returnType"
}

private fun Nullability.toSignature() = when (this) {
Expand Down
Expand Up @@ -56,8 +56,15 @@ class ConstructorDeclarationsProcessor : AbstractTestProcessor() {
listOf("class: " + it.key.qualifiedName!!.asString()) + it.value
}
}
fun KSFunctionDeclaration.toSignature(): String {
fun KSFunctionDeclaration.toSignature(classDeclaration: KSClassDeclaration): String {
val contextSignature = if (classDeclaration.contextReceivers.isNotEmpty()) {
classDeclaration.contextReceivers.map { it.resolve().declaration.qualifiedName?.asString() }
.joinToString(prefix = " context(", postfix = ") ", separator = ",")
} else {
""
}
return this.simpleName.asString() +
contextSignature +
"(${this.parameters.map {
buildString {
append(it.type.resolve().declaration.qualifiedName?.asString())
Expand All @@ -74,7 +81,7 @@ class ConstructorDeclarationsProcessor : AbstractTestProcessor() {
val declarations = mutableListOf<String>()
declarations.addAll(
classDeclaration.getConstructors().map {
it.toSignature()
it.toSignature(classDeclaration)
}.sorted()
)
// TODO add some assertions that if we go through he path of getDeclarations
Expand Down
Expand Up @@ -22,11 +22,18 @@ import com.intellij.openapi.Disposable
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.Disposer
import com.intellij.testFramework.TestDataFile
import org.jetbrains.kotlin.cli.common.arguments.CommonCompilerArguments
import org.jetbrains.kotlin.cli.common.arguments.K2JVMCompilerArguments
import org.jetbrains.kotlin.cli.common.setupCommonArguments
import org.jetbrains.kotlin.cli.jvm.compiler.EnvironmentConfigFiles
import org.jetbrains.kotlin.cli.jvm.compiler.KotlinCoreEnvironment
import org.jetbrains.kotlin.cli.jvm.config.addJavaSourceRoot
import org.jetbrains.kotlin.cli.jvm.config.addJvmClasspathRoots
import org.jetbrains.kotlin.cli.jvm.setupJvmSpecificArguments
import org.jetbrains.kotlin.codegen.GenerationUtils
import org.jetbrains.kotlin.config.CompilerConfigurationKey
import org.jetbrains.kotlin.config.LanguageFeature
import org.jetbrains.kotlin.config.languageVersionSettings
import org.jetbrains.kotlin.platform.jvm.JvmPlatforms
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.test.ExecutionListenerBasedDisposableProvider
Expand Down Expand Up @@ -111,6 +118,10 @@ abstract class AbstractKSPTest(frontend: FrontendKind<*>) : DisposableTest() {
this@globalDefaults.frontend = frontend
targetPlatform = JvmPlatforms.defaultJvmPlatform
dependencyKind = DependencyKind.Source
languageSettings {
// TODO: when would this be removed after they become stable, as it is version specific?
this.enable(LanguageFeature.ContextReceivers)
}
}
useConfigurators(
::CommonEnvironmentConfigurator,
Expand Down
23 changes: 22 additions & 1 deletion test-utils/testData/api/asMemberOf.kt
Expand Up @@ -26,6 +26,9 @@
// errorType: <Error>?
// extensionProperty: kotlin.String?
// returnInt: () -> kotlin.Int!!
// returnInt2: context(kotlin.Int!!,kotlin.String!!) () -> kotlin.Int!!
// returnInt3: context(kotlin.String!!) kotlin.Int!!.() -> kotlin.Int!!
// returnInt4: kotlin.Int!!.() -> kotlin.Int!!
// returnArg1: () -> kotlin.Int!!
// returnArg1Nullable: () -> kotlin.Int?
// returnArg2: () -> kotlin.String?
Expand All @@ -43,6 +46,9 @@
// errorType: <Error>?
// extensionProperty: kotlin.Any?
// returnInt: () -> kotlin.Int!!
// returnInt2: context(kotlin.Int!!,kotlin.String!!) () -> kotlin.Int!!
// returnInt3: context(kotlin.String!!) kotlin.Int!!.() -> kotlin.Int!!
// returnInt4: kotlin.Int!!.() -> kotlin.Int!!
// returnArg1: () -> kotlin.Any?
// returnArg1Nullable: () -> kotlin.Any?
// returnArg2: () -> kotlin.Any?
Expand All @@ -60,6 +66,9 @@
// errorType: <Error>?
// extensionProperty: kotlin.String?
// returnInt: () -> kotlin.Int!!
// returnInt2: context(kotlin.Int!!,kotlin.String!!) () -> kotlin.Int!!
// returnInt3: context(kotlin.String!!) kotlin.Int!!.() -> kotlin.Int!!
// returnInt4: kotlin.Int!!.() -> kotlin.Int!!
// returnArg1: () -> kotlin.String!!
// returnArg1Nullable: () -> kotlin.String?
// returnArg2: () -> kotlin.String?
Expand All @@ -77,6 +86,9 @@
// errorType: java.lang.IllegalArgumentException: NotAChild is not a sub type of the class/interface that contains `errorType` (Base)
// extensionProperty: java.lang.IllegalArgumentException: NotAChild is not a sub type of the class/interface that contains `extensionProperty` (Base)
// returnInt: java.lang.IllegalArgumentException: NotAChild is not a sub type of the class/interface that contains `returnInt` (Base)
// returnInt2: java.lang.IllegalArgumentException: NotAChild is not a sub type of the class/interface that contains `returnInt2` (Base)
// returnInt3: java.lang.IllegalArgumentException: NotAChild is not a sub type of the class/interface that contains `returnInt3` (Base)
// returnInt4: java.lang.IllegalArgumentException: NotAChild is not a sub type of the class/interface that contains `returnInt4` (Base)
// returnArg1: java.lang.IllegalArgumentException: NotAChild is not a sub type of the class/interface that contains `returnArg1` (Base)
// returnArg1Nullable: java.lang.IllegalArgumentException: NotAChild is not a sub type of the class/interface that contains `returnArg1Nullable` (Base)
// returnArg2: java.lang.IllegalArgumentException: NotAChild is not a sub type of the class/interface that contains `returnArg2` (Base)
Expand Down Expand Up @@ -104,7 +116,7 @@
// fileLevelFunction: java.lang.IllegalArgumentException: Cannot call asMemberOf with a function that is not declared in a class or an interface
// fileLevelExtensionFunction: java.lang.IllegalArgumentException: Cannot call asMemberOf with a function that is not declared in a class or an interface
// fileLevelProperty: java.lang.IllegalArgumentException: Cannot call asMemberOf with a property that is not declared in a class or an interface
// errorType: (<Error>?) -> <Error>?
// errorType: context(<Error>?) (<Error>?) -> <Error>?
// expected comparison failures
// <BaseTypeArg1: kotlin.Any?>(Base.functionArgType.BaseTypeArg1?) -> kotlin.String?
// () -> kotlin.Int!!
Expand All @@ -119,6 +131,15 @@ open class Base<BaseTypeArg1, BaseTypeArg2> {
val typePair: Pair<BaseTypeArg2, BaseTypeArg1> = TODO()
val errorType: NonExistType = TODO()
fun returnInt():Int = TODO()

context(Int, String)
fun returnInt2():Int = TODO()

context(String)
fun Int.returnInt3():Int = TODO()

fun Int.returnInt4():Int = TODO()

fun returnArg1(): BaseTypeArg1 = TODO()
fun returnArg1Nullable(): BaseTypeArg1? = TODO()
fun returnArg2(): BaseTypeArg2 = TODO()
Expand Down