diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/OpticsProcessor.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/OpticsProcessor.kt index 7b3b72c23c5..5916b14a4a5 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/OpticsProcessor.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/OpticsProcessor.kt @@ -38,12 +38,6 @@ class OpticsProcessor(private val codegen: CodeGenerator, private val logger: KS return } - // check that it does not have type arguments - if (klass.typeParameters.isNotEmpty()) { - logger.error(klass.qualifiedNameOrSimpleName.typeParametersErrorMessage, klass) - return - } - // check that the companion object exists if (klass.companionObject == null) { logger.error(klass.qualifiedNameOrSimpleName.noCompanion, klass) diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/domain.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/domain.kt index cfa759e3875..e3bd7b7952d 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/domain.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/domain.kt @@ -2,25 +2,34 @@ package arrow.optics.plugin.internals import arrow.optics.plugin.companionObject import com.google.devtools.ksp.getVisibility -import com.google.devtools.ksp.symbol.KSClassDeclaration -import com.google.devtools.ksp.symbol.KSName -import com.google.devtools.ksp.symbol.Visibility +import com.google.devtools.ksp.symbol.* import java.util.Locale data class ADT(val pckg: KSName, val declaration: KSClassDeclaration, val targets: List) { val sourceClassName = declaration.qualifiedNameOrSimpleName val sourceName = declaration.simpleName.asString().replaceFirstChar { it.lowercase(Locale.getDefault()) } - val simpleName = declaration.simpleName.asString() + val simpleName = declaration.nameWithParentClass val packageName = pckg.asString() val visibilityModifierName = when (declaration.companionObject?.getVisibility()) { Visibility.INTERNAL -> "internal" else -> "public" } + val typeParameters: List = declaration.typeParameters.map { it.simpleName.asString() } + val angledTypeParameters: String = when { + typeParameters.isEmpty() -> "" + else -> "<${typeParameters.joinToString(separator = ",")}>" + } operator fun Snippet.plus(snippet: Snippet): Snippet = copy(imports = imports + snippet.imports, content = "$content\n${snippet.content}") } +val KSClassDeclaration.nameWithParentClass: String + get() = when (val parent = parentDeclaration) { + is KSClassDeclaration -> parent.nameWithParentClass + "." + simpleName.asString() + else -> simpleName.asString() + } + enum class OpticsTarget { ISO, LENS, @@ -61,27 +70,46 @@ typealias NullableFocus = Focus.Nullable sealed class Focus { companion object { - operator fun invoke(fullName: String, paramName: String): Focus = + operator fun invoke(fullName: String, paramName: String, refinedType: KSType? = null): Focus = when { - fullName.endsWith("?") -> Nullable(fullName, paramName) - fullName.startsWith("`arrow`.`core`.`Option`") -> Option(fullName, paramName) - else -> NonNull(fullName, paramName) + fullName.endsWith("?") -> Nullable(fullName, paramName, refinedType) + fullName.startsWith("`arrow`.`core`.`Option`") -> Option(fullName, paramName, refinedType) + else -> NonNull(fullName, paramName, refinedType) } } abstract val className: String abstract val paramName: String - - data class Nullable(override val className: String, override val paramName: String) : Focus() { + // only used for type-refining prisms + abstract val refinedType: KSType? + + val refinedArguments: List + get() = refinedType?.arguments?.filter { + it.type?.resolve()?.declaration is KSTypeParameter + }?.map { it.qualifiedString() }.orEmpty() + + data class Nullable( + override val className: String, + override val paramName: String, + override val refinedType: KSType? + ) : Focus() { val nonNullClassName = className.dropLast(1) } - data class Option(override val className: String, override val paramName: String) : Focus() { + data class Option( + override val className: String, + override val paramName: String, + override val refinedType: KSType? + ) : Focus() { val nestedClassName = Regex("`arrow`.`core`.`Option`<(.*)>$").matchEntire(className)!!.groupValues[1] } - data class NonNull(override val className: String, override val paramName: String) : Focus() + data class NonNull( + override val className: String, + override val paramName: String, + override val refinedType: KSType? + ) : Focus() } const val Lens = "arrow.optics.Lens" diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt index a4e97e9eda8..a42decca9e3 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt @@ -22,7 +22,8 @@ fun generatePrismDsl(ele: ADT, isoOptic: SealedClassDsl): Snippet = ) private fun processLensSyntax(ele: ADT, foci: List): String = - foci.joinToString(separator = "\n") { focus -> + if (ele.typeParameters.isEmpty()) { + foci.joinToString(separator = "\n") { focus -> """ |${ele.visibilityModifierName} inline val $Iso.${focus.lensParamName()}: $Lens inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} |${ele.visibilityModifierName} inline val $Lens.${focus.lensParamName()}: $Lens inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} @@ -34,17 +35,36 @@ private fun processLensSyntax(ele: ADT, foci: List): String = |${ele.visibilityModifierName} inline val $Fold.${focus.lensParamName()}: $Fold inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} |${ele.visibilityModifierName} inline val $Every.${focus.lensParamName()}: $Every inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()} |""".trimMargin() + } + } else { + val sourceClassNameWithParams = "${ele.sourceClassName}${ele.angledTypeParameters}" + val joinedTypeParams = ele.typeParameters.joinToString(separator=",") + foci.joinToString(separator = "\n") { focus -> + """ + |${ele.visibilityModifierName} inline fun $Iso.${focus.lensParamName()}(): $Lens = this + ${ele.sourceClassName}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Lens.${focus.lensParamName()}(): $Lens = this + ${ele.sourceClassName}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Optional.${focus.lensParamName()}(): $Optional = this + ${ele.sourceClassName}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Prism.${focus.lensParamName()}(): $Optional = this + ${ele.sourceClassName}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Getter.${focus.lensParamName()}(): $Getter = this + ${ele.sourceClassName}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Setter.${focus.lensParamName()}(): $Setter = this + ${ele.sourceClassName}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Traversal.${focus.lensParamName()}(): $Traversal = this + ${ele.sourceClassName}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Fold.${focus.lensParamName()}(): $Fold = this + ${ele.sourceClassName}.${focus.lensParamName()}() + |${ele.visibilityModifierName} inline fun $Every.${focus.lensParamName()}(): $Every = this + ${ele.sourceClassName}.${focus.lensParamName()}() + |""".trimMargin() + } } -private fun processOptionalSyntax(ele: ADT, optic: DataClassDsl) = - optic.foci.filterNot { it is NonNullFocus }.joinToString(separator = "\n") { focus -> +private fun processOptionalSyntax(ele: ADT, optic: DataClassDsl): String { + val sourceClassNameWithParams = "${ele.sourceClassName}${ele.angledTypeParameters}" + val joinedTypeParams = ele.typeParameters.joinToString(separator=",") + return optic.foci.filterNot { it is NonNullFocus }.joinToString(separator = "\n") { focus -> val targetClassName = when (focus) { - is NullableFocus -> focus.nonNullClassName - is OptionFocus -> focus.nestedClassName - is NonNullFocus -> "" + is Focus.Nullable -> focus.nonNullClassName + is Focus.Option -> focus.nestedClassName + is Focus.NonNull -> "" } - + if (ele.typeParameters.isEmpty()) { """ |${ele.visibilityModifierName} inline val $Iso.${focus.paramName}: $Optional inline get() = this + ${ele.sourceClassName}.${focus.paramName} |${ele.visibilityModifierName} inline val $Lens.${focus.paramName}: $Optional inline get() = this + ${ele.sourceClassName}.${focus.paramName} @@ -55,10 +75,24 @@ private fun processOptionalSyntax(ele: ADT, optic: DataClassDsl) = |${ele.visibilityModifierName} inline val $Fold.${focus.paramName}: $Fold inline get() = this + ${ele.sourceClassName}.${focus.paramName} |${ele.visibilityModifierName} inline val $Every.${focus.paramName}: $Every inline get() = this + ${ele.sourceClassName}.${focus.paramName} |""".trimMargin() + } else { + """ + |${ele.visibilityModifierName} inline fun $Iso.${focus.paramName}(): $Optional = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Lens.${focus.paramName}(): $Optional = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Optional.${focus.paramName}(): $Optional = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Prism.${focus.paramName}(): $Optional = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Setter.${focus.paramName}(): $Setter = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Traversal.${focus.paramName}(): $Traversal = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Fold.${focus.paramName}(): $Fold = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Every.${focus.paramName}(): $Every = this + ${ele.sourceClassName}.${focus.paramName}() + |""".trimMargin() + } } +} private fun processPrismSyntax(ele: ADT, dsl: SealedClassDsl): String = - dsl.foci.joinToString(separator = "\n\n") { focus -> + if (ele.typeParameters.isEmpty()) { + dsl.foci.joinToString(separator = "\n\n") { focus -> """ |${ele.visibilityModifierName} inline val $Iso.${focus.paramName}: $Prism inline get() = this + ${ele.sourceClassName}.${focus.paramName} |${ele.visibilityModifierName} inline val $Lens.${focus.paramName}: $Optional inline get() = this + ${ele.sourceClassName}.${focus.paramName} @@ -69,4 +103,23 @@ private fun processPrismSyntax(ele: ADT, dsl: SealedClassDsl): String = |${ele.visibilityModifierName} inline val $Fold.${focus.paramName}: $Fold inline get() = this + ${ele.sourceClassName}.${focus.paramName} |${ele.visibilityModifierName} inline val $Every.${focus.paramName}: $Every inline get() = this + ${ele.sourceClassName}.${focus.paramName} |""".trimMargin() + } + } else { + dsl.foci.joinToString(separator = "\n\n") { focus -> + val sourceClassNameWithParams = focus.refinedType?.qualifiedString() ?: "${ele.sourceClassName}${ele.angledTypeParameters}" + val joinedTypeParams = when { + focus.refinedArguments.isEmpty() -> "" + else -> focus.refinedArguments.joinToString(separator=",") + } + """ + |${ele.visibilityModifierName} inline fun $Iso.${focus.paramName}(): $Prism = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Lens.${focus.paramName}(): $Optional = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Optional.${focus.paramName}(): $Optional = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Prism.${focus.paramName}(): $Prism = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Setter.${focus.paramName}(): $Setter = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Traversal.${focus.paramName}(): $Traversal = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Fold.${focus.paramName}(): $Fold = this + ${ele.sourceClassName}.${focus.paramName}() + |${ele.visibilityModifierName} inline fun $Every.${focus.paramName}(): $Every = this + ${ele.sourceClassName}.${focus.paramName}() + |""".trimMargin() + } } diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/isos.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/isos.kt index 4a75add2cd9..dce33e5b3e0 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/isos.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/isos.kt @@ -63,9 +63,17 @@ private fun processElement(iso: ADT, target: Target): String { "tuple: ${focusType()} -> ${(foci.indices).joinToString(prefix = "${iso.sourceClassName}(", postfix = ")", transform = { "tuple.${letters[it]}" })}" } + val sourceClassNameWithParams = "${iso.sourceClassName}${iso.angledTypeParameters}" + val firstLine = when { + iso.typeParameters.isEmpty() -> + "${iso.visibilityModifierName} inline val ${iso.sourceClassName}.Companion.iso: $Iso<${iso.sourceClassName}, ${focusType()}> inline get()" + else -> + "${iso.visibilityModifierName} inline fun ${iso.angledTypeParameters} ${iso.sourceClassName}.Companion.iso(): $Iso<$sourceClassNameWithParams, ${focusType()}>" + } + return """ - |${iso.visibilityModifierName} inline val ${iso.sourceClassName}.Companion.iso: $Iso<${iso.sourceClassName}, ${focusType()}> inline get()= $Iso( - | get = { ${iso.sourceName}: ${iso.sourceClassName} -> ${tupleConstructor()} }, + |$firstLine = $Iso( + | get = { ${iso.sourceName}: $sourceClassNameWithParams -> ${tupleConstructor()} }, | reverseGet = { ${classConstructorFromTuple()} } |) |""".trimMargin() diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/lenses.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/lenses.kt index 2ed586fd3c5..6b75addd841 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/lenses.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/lenses.kt @@ -20,17 +20,24 @@ private fun String.toUpperCamelCase(): String = } ) -private fun processElement(adt: ADT, foci: List): String = - foci.joinToString(separator = "\n") { focus -> +private fun processElement(adt: ADT, foci: List): String { + val sourceClassNameWithParams = "${adt.sourceClassName}${adt.angledTypeParameters}" + return foci.joinToString(separator = "\n") { focus -> + val firstLine = when { + adt.typeParameters.isEmpty() -> + "${adt.visibilityModifierName} inline val ${adt.sourceClassName}.Companion.${focus.lensParamName()}: $Lens<${adt.sourceClassName}, ${focus.className}> inline get()" + else -> + "${adt.visibilityModifierName} inline fun ${adt.angledTypeParameters} ${adt.sourceClassName}.Companion.${focus.lensParamName()}(): $Lens<$sourceClassNameWithParams, ${focus.className}>" + } """ - |${adt.visibilityModifierName} inline val ${adt.sourceClassName}.Companion.${focus.lensParamName()}: $Lens<${adt.sourceClassName}, ${focus.className}> inline get()= $Lens( - | get = { ${adt.sourceName}: ${adt.sourceClassName} -> ${adt.sourceName}.${ + |$firstLine = $Lens( + | get = { ${adt.sourceName}: $sourceClassNameWithParams -> ${adt.sourceName}.${ focus.paramName.plusIfNotBlank( prefix = "`", postfix = "`" ) } }, - | set = { ${adt.sourceName}: ${adt.sourceClassName}, value: ${focus.className} -> ${adt.sourceName}.copy(${ + | set = { ${adt.sourceName}: $sourceClassNameWithParams, value: ${focus.className} -> ${adt.sourceName}.copy(${ focus.paramName.plusIfNotBlank( prefix = "`", postfix = "`" @@ -39,6 +46,7 @@ private fun processElement(adt: ADT, foci: List): String = |) |""".trimMargin() } +} fun Focus.lensParamName(): String = when (this) { diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/optional.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/optional.kt index 9430e318d8c..6a23c765d67 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/optional.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/optional.kt @@ -11,8 +11,23 @@ internal fun generateOptionals(ele: ADT, target: OptionalTarget) = private fun processElement(ele: ADT, foci: List): String = foci.joinToString(separator = "\n") { focus -> + + val targetClassName = when (focus) { + is NullableFocus -> focus.nonNullClassName + is OptionFocus -> focus.nestedClassName + is NonNullFocus -> return@joinToString "" + } + + val sourceClassNameWithParams = "${ele.sourceClassName}${ele.angledTypeParameters}" + val firstLine = when { + ele.typeParameters.isEmpty() -> + "${ele.visibilityModifierName} inline val ${ele.sourceClassName}.Companion.${focus.paramName}: $Optional<${ele.sourceClassName}, $targetClassName> inline get()" + else -> + "${ele.visibilityModifierName} inline fun ${ele.angledTypeParameters} ${ele.sourceClassName}.Companion.${focus.paramName}(): $Optional<$sourceClassNameWithParams, $targetClassName>" + } + fun getOrModifyF(toNullable: String = "") = - "{ ${ele.sourceName}: ${ele.sourceClassName} -> ${ele.sourceName}.${ + "{ ${ele.sourceName}: $sourceClassNameWithParams -> ${ele.sourceName}.${ focus.paramName.plusIfNotBlank( prefix = "`", postfix = "`" @@ -21,18 +36,17 @@ private fun processElement(ele: ADT, foci: List): String = fun setF(fromNullable: String = "") = "${ele.sourceName}.copy(${focus.paramName.plusIfNotBlank(prefix = "`", postfix = "`")} = value$fromNullable)" - val (targetClassName, getOrModify, set) = + val (getOrModify, set) = when (focus) { - is NullableFocus -> Triple(focus.nonNullClassName, getOrModifyF(), setF()) - is OptionFocus -> - Triple(focus.nestedClassName, getOrModifyF(".orNull()"), setF(".toOption()")) + is NullableFocus -> Pair(getOrModifyF(), setF()) + is OptionFocus -> Pair(getOrModifyF(".orNull()"), setF(".toOption()")) is NonNullFocus -> return@joinToString "" } """ - |${ele.visibilityModifierName} inline val ${ele.sourceClassName}.Companion.${focus.paramName}: $Optional<${ele.sourceClassName}, $targetClassName> inline get()= $Optional( + |$firstLine = $Optional( | getOrModify = $getOrModify, - | set = { ${ele.sourceName}: ${ele.sourceClassName}, value: $targetClassName -> $set } + | set = { ${ele.sourceName}: $sourceClassNameWithParams, value: $targetClassName -> $set } |) |""".trimMargin() } diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/prism.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/prism.kt index 044de977fc0..be20c59c6f0 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/prism.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/prism.kt @@ -1,5 +1,7 @@ package arrow.optics.plugin.internals +import com.google.devtools.ksp.symbol.KSTypeParameter + internal fun generatePrisms(ele: ADT, target: PrismTarget) = Snippet( `package` = ele.packageName, @@ -9,11 +11,24 @@ internal fun generatePrisms(ele: ADT, target: PrismTarget) = content = processElement(ele, target.foci) ) -private fun processElement(ele: ADT, foci: List): String = - foci.joinToString(separator = "\n\n") { focus -> +private fun processElement(ele: ADT, foci: List): String { + return foci.joinToString(separator = "\n\n") { focus -> + val sourceClassNameWithParams = + focus.refinedType?.qualifiedString() ?: "${ele.sourceClassName}${ele.angledTypeParameters}" + val angledTypeParameters = when { + focus.refinedArguments.isEmpty() -> "" + else -> focus.refinedArguments.joinToString(prefix = "<", separator = ",", postfix = ">") + } + val firstLine = when { + ele.typeParameters.isEmpty() -> + "${ele.visibilityModifierName} inline val ${ele.sourceClassName}.Companion.${focus.paramName}: $Prism<${ele.sourceClassName}, ${focus.className}> inline get()" + else -> + "${ele.visibilityModifierName} inline fun $angledTypeParameters ${ele.sourceClassName}.Companion.${focus.paramName}(): $Prism<$sourceClassNameWithParams, ${focus.className}>" + } + """ - |${ele.visibilityModifierName} inline val ${ele.sourceClassName}.Companion.${focus.paramName}: $Prism<${ele.sourceClassName}, ${focus.className}> inline get()= $Prism( - | getOrModify = { ${ele.sourceName}: ${ele.sourceClassName} -> + |$firstLine = $Prism( + | getOrModify = { ${ele.sourceName}: $sourceClassNameWithParams -> | when (${ele.sourceName}) { | is ${focus.className} -> ${ele.sourceName}.right() | else -> ${ele.sourceName}.left() @@ -23,3 +38,4 @@ private fun processElement(ele: ADT, foci: List): String = |) |""".trimMargin() } +} diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/processor.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/processor.kt index c974b948d91..8cc47d8755e 100644 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/processor.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/processor.kt @@ -3,10 +3,7 @@ package arrow.optics.plugin.internals import arrow.optics.plugin.isData import arrow.optics.plugin.isSealed import com.google.devtools.ksp.processing.KSPLogger -import com.google.devtools.ksp.symbol.KSClassDeclaration -import com.google.devtools.ksp.symbol.KSDeclaration -import com.google.devtools.ksp.symbol.KSType -import com.google.devtools.ksp.symbol.KSTypeArgument +import com.google.devtools.ksp.symbol.* import java.util.Locale internal fun adt(c: KSClassDeclaration, logger: KSPLogger): ADT = @@ -69,12 +66,13 @@ internal fun evalAnnotatedPrismElement( ): List = when { element.isSealed -> - element.sealedSubclassFqNameList().map { + element.getSealedSubclasses().map { Focus( - it, - it.substringAfterLast(".").replaceFirstChar { c -> c.lowercase(Locale.getDefault()) } + it.primaryConstructor?.returnType?.resolve()?.qualifiedString() ?: it.qualifiedNameOrSimpleName, + it.simpleName.asString().replaceFirstChar { c -> c.lowercase(Locale.getDefault()) }, + it.superTypes.first().resolve() ) - } + }.toList() else -> { logger.error(element.qualifiedNameOrSimpleName.prismErrorMessage, element) emptyList() @@ -135,14 +133,20 @@ internal fun evalAnnotatedIsoElement(element: KSClassDeclaration, logger: KSPLog internal fun KSClassDeclaration.getConstructorTypesNames(): List = primaryConstructor?.parameters?.map { it.type.resolve().qualifiedString() }.orEmpty() -internal fun KSType.qualifiedString(): String = when (val qname = declaration.qualifiedName?.asString()) { - null -> toString() - else -> { - val withArgs = when { - arguments.isEmpty() -> qname - else -> "$qname<${arguments.joinToString(separator = ", ") { it.qualifiedString() }}>" +internal fun KSType.qualifiedString(): String = when (declaration) { + is KSTypeParameter -> { + val n = declaration.simpleName.asString() + if (isMarkedNullable) "$n?" else n + } + else -> when (val qname = declaration.qualifiedName?.asString()) { + null -> toString() + else -> { + val withArgs = when { + arguments.isEmpty() -> qname + else -> "$qname<${arguments.joinToString(separator = ", ") { it.qualifiedString() }}>" + } + if (isMarkedNullable) "$withArgs?" else withArgs } - if (isMarkedNullable) "$withArgs?" else withArgs } } diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/IsoTests.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/IsoTests.kt index 6ce658d54a3..117ffcf7192 100755 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/IsoTests.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/IsoTests.kt @@ -20,6 +20,20 @@ class IsoTests { """.evals("r" to true) } + @Test + fun `Isos will be generated for generic data class`() { + """ + |$imports + |@optics + |data class IsoData( + | val field1: A + |) { companion object } + | + |val i: Iso, String> = IsoData.iso() + |val r = i != null + """.evals("r" to true) + } + @Test fun `Isos will be generated for data class with secondary constructors`() { """ diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/LensTests.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/LensTests.kt index d913d5006a2..8afcae52b03 100755 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/LensTests.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/LensTests.kt @@ -58,8 +58,54 @@ class LensTests { | companion object |} | - |val i: Lens = OpticsTest.time + |val i: Lens, Int> = OpticsTest.field() |val r = i != null - """.failsWith { it.contains("OpticsTest".typeParametersErrorMessage) } + """.evals("r" to true) + } + + @Test + fun `Lenses for nested classes`() { + """ + |$imports + |@optics + |data class LensData(val field1: String) { + | @optics + | data class InnerLensData(val field2: String) { + | companion object + | } + | companion object + |} + | + |val i: Lens = LensData.InnerLensData.field2 + |val r = i != null + """.evals("r" to true) + } + + @Test + fun `Lenses for nested classes with repeated names (#2718)`() { + """ + |$imports + |@optics + |data class LensData(val field1: String) { + | @optics + | data class InnerLensData(val field2: String) { + | companion object + | } + | companion object + |} + | + |@optics + |data class OtherLensData(val field1: String) { + | @optics + | data class InnerLensData(val field2: String) { + | companion object + | } + | companion object + |} + | + |val i: Lens = LensData.InnerLensData.field2 + |val j: Lens = OtherLensData.InnerLensData.field2 + |val r = i != null && j != null + """.evals("r" to true) } } diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/OptionalTests.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/OptionalTests.kt index c7b87fc3b0a..be661dfe02d 100755 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/OptionalTests.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/OptionalTests.kt @@ -18,6 +18,20 @@ class OptionalTests { """.evals("r" to true) } + @Test + fun `Optional will be generated for generic data class`() { + """ + |$imports + |@optics + |data class OptionalData( + | val field1: A? + |) { companion object } + | + |val i: Optional, String> = OptionalData.field1() + |val r = i != null + """.evals("r" to true) + } + @Test fun `Optional will be generated for data class with secondary constructors`() { """ diff --git a/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/PrismTests.kt b/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/PrismTests.kt index 5a8eacc5a51..cdc4e7cc4e9 100755 --- a/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/PrismTests.kt +++ b/arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/PrismTests.kt @@ -19,6 +19,21 @@ class PrismTests { """.evals("r" to true) } + @Test + fun `Prism will be generated for generic sealed class`() { + """ + |$imports + |@optics + |sealed class PrismSealed(val field: A, val nullable: B?) { + | data class PrismSealed1(private val a: String?) : PrismSealed("", a) + | data class PrismSealed2(private val b: C?) : PrismSealed("", b) + | companion object + |} + |val i: Prism, PrismSealed.PrismSealed1> = PrismSealed.prismSealed1() + |val r = i != null + """.evals("r" to true) + } + @Test fun `Prism will not be generated for sealed class if DSL Target is specified`() { """