Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support generics and inner classes in @optics #2776

Merged
merged 3 commits into from Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -38,12 +38,6 @@ class OpticsProcessor(private val codegen: CodeGenerator, private val logger: KS
return
}

// check that it does not have type arguments
if (klass.typeParameters.isNotEmpty()) {
logger.error(klass.qualifiedNameOrSimpleName.typeParametersErrorMessage, klass)
return
}

// check that the companion object exists
if (klass.companionObject == null) {
logger.error(klass.qualifiedNameOrSimpleName.noCompanion, klass)
Expand Down
Expand Up @@ -2,25 +2,34 @@ package arrow.optics.plugin.internals

import arrow.optics.plugin.companionObject
import com.google.devtools.ksp.getVisibility
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSName
import com.google.devtools.ksp.symbol.Visibility
import com.google.devtools.ksp.symbol.*
import java.util.Locale

data class ADT(val pckg: KSName, val declaration: KSClassDeclaration, val targets: List<Target>) {
val sourceClassName = declaration.qualifiedNameOrSimpleName
val sourceName = declaration.simpleName.asString().replaceFirstChar { it.lowercase(Locale.getDefault()) }
val simpleName = declaration.simpleName.asString()
val simpleName = declaration.nameWithParentClass
val packageName = pckg.asString()
val visibilityModifierName = when (declaration.companionObject?.getVisibility()) {
Visibility.INTERNAL -> "internal"
else -> "public"
}
val typeParameters: List<String> = declaration.typeParameters.map { it.simpleName.asString() }
val angledTypeParameters: String = when {
typeParameters.isEmpty() -> ""
else -> "<${typeParameters.joinToString(separator = ",")}>"
}

operator fun Snippet.plus(snippet: Snippet): Snippet =
copy(imports = imports + snippet.imports, content = "$content\n${snippet.content}")
}

val KSClassDeclaration.nameWithParentClass: String
get() = when (val parent = parentDeclaration) {
is KSClassDeclaration -> parent.nameWithParentClass + "." + simpleName.asString()
else -> simpleName.asString()
}

enum class OpticsTarget {
ISO,
LENS,
Expand Down Expand Up @@ -61,27 +70,46 @@ typealias NullableFocus = Focus.Nullable
sealed class Focus {

companion object {
operator fun invoke(fullName: String, paramName: String): Focus =
operator fun invoke(fullName: String, paramName: String, refinedType: KSType? = null): Focus =
when {
fullName.endsWith("?") -> Nullable(fullName, paramName)
fullName.startsWith("`arrow`.`core`.`Option`") -> Option(fullName, paramName)
else -> NonNull(fullName, paramName)
fullName.endsWith("?") -> Nullable(fullName, paramName, refinedType)
fullName.startsWith("`arrow`.`core`.`Option`") -> Option(fullName, paramName, refinedType)
else -> NonNull(fullName, paramName, refinedType)
}
}

abstract val className: String
abstract val paramName: String

data class Nullable(override val className: String, override val paramName: String) : Focus() {
// only used for type-refining prisms
abstract val refinedType: KSType?

val refinedArguments: List<String>
get() = refinedType?.arguments?.filter {
it.type?.resolve()?.declaration is KSTypeParameter
}?.map { it.qualifiedString() }.orEmpty()

data class Nullable(
override val className: String,
override val paramName: String,
override val refinedType: KSType?
) : Focus() {
val nonNullClassName = className.dropLast(1)
}

data class Option(override val className: String, override val paramName: String) : Focus() {
data class Option(
override val className: String,
override val paramName: String,
override val refinedType: KSType?
) : Focus() {
val nestedClassName =
Regex("`arrow`.`core`.`Option`<(.*)>$").matchEntire(className)!!.groupValues[1]
}

data class NonNull(override val className: String, override val paramName: String) : Focus()
data class NonNull(
override val className: String,
override val paramName: String,
override val refinedType: KSType?
) : Focus()
}

const val Lens = "arrow.optics.Lens"
Expand Down
Expand Up @@ -22,7 +22,8 @@ fun generatePrismDsl(ele: ADT, isoOptic: SealedClassDsl): Snippet =
)

private fun processLensSyntax(ele: ADT, foci: List<Focus>): String =
foci.joinToString(separator = "\n") { focus ->
if (ele.typeParameters.isEmpty()) {
foci.joinToString(separator = "\n") { focus ->
"""
|${ele.visibilityModifierName} inline val <S> $Iso<S, ${ele.sourceClassName}>.${focus.lensParamName()}: $Lens<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()}
|${ele.visibilityModifierName} inline val <S> $Lens<S, ${ele.sourceClassName}>.${focus.lensParamName()}: $Lens<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()}
Expand All @@ -34,17 +35,36 @@ private fun processLensSyntax(ele: ADT, foci: List<Focus>): String =
|${ele.visibilityModifierName} inline val <S> $Fold<S, ${ele.sourceClassName}>.${focus.lensParamName()}: $Fold<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()}
|${ele.visibilityModifierName} inline val <S> $Every<S, ${ele.sourceClassName}>.${focus.lensParamName()}: $Every<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()}
|""".trimMargin()
}
} else {
val sourceClassNameWithParams = "${ele.sourceClassName}${ele.angledTypeParameters}"
val joinedTypeParams = ele.typeParameters.joinToString(separator=",")
foci.joinToString(separator = "\n") { focus ->
"""
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Iso<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Lens<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Lens<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Lens<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Optional<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Optional<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Prism<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Optional<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Getter<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Getter<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Setter<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Setter<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Traversal<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Traversal<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Fold<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Fold<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Every<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Every<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|""".trimMargin()
}
}

private fun processOptionalSyntax(ele: ADT, optic: DataClassDsl) =
optic.foci.filterNot { it is NonNullFocus }.joinToString(separator = "\n") { focus ->
private fun processOptionalSyntax(ele: ADT, optic: DataClassDsl): String {
val sourceClassNameWithParams = "${ele.sourceClassName}${ele.angledTypeParameters}"
val joinedTypeParams = ele.typeParameters.joinToString(separator=",")
return optic.foci.filterNot { it is NonNullFocus }.joinToString(separator = "\n") { focus ->
val targetClassName =
when (focus) {
is NullableFocus -> focus.nonNullClassName
is OptionFocus -> focus.nestedClassName
is NonNullFocus -> ""
is Focus.Nullable -> focus.nonNullClassName
is Focus.Option -> focus.nestedClassName
is Focus.NonNull -> ""
}

if (ele.typeParameters.isEmpty()) {
"""
|${ele.visibilityModifierName} inline val <S> $Iso<S, ${ele.sourceClassName}>.${focus.paramName}: $Optional<S, $targetClassName> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Lens<S, ${ele.sourceClassName}>.${focus.paramName}: $Optional<S, $targetClassName> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
Expand All @@ -55,10 +75,24 @@ private fun processOptionalSyntax(ele: ADT, optic: DataClassDsl) =
|${ele.visibilityModifierName} inline val <S> $Fold<S, ${ele.sourceClassName}>.${focus.paramName}: $Fold<S, $targetClassName> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Every<S, ${ele.sourceClassName}>.${focus.paramName}: $Every<S, $targetClassName> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
|""".trimMargin()
} else {
"""
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Iso<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Lens<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Optional<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Prism<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Setter<S, $sourceClassNameWithParams>.${focus.paramName}(): $Setter<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Traversal<S, $sourceClassNameWithParams>.${focus.paramName}(): $Traversal<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Fold<S, $sourceClassNameWithParams>.${focus.paramName}(): $Fold<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Every<S, $sourceClassNameWithParams>.${focus.paramName}(): $Every<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|""".trimMargin()
}
}
}

private fun processPrismSyntax(ele: ADT, dsl: SealedClassDsl): String =
dsl.foci.joinToString(separator = "\n\n") { focus ->
if (ele.typeParameters.isEmpty()) {
dsl.foci.joinToString(separator = "\n\n") { focus ->
"""
|${ele.visibilityModifierName} inline val <S> $Iso<S, ${ele.sourceClassName}>.${focus.paramName}: $Prism<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Lens<S, ${ele.sourceClassName}>.${focus.paramName}: $Optional<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
Expand All @@ -69,4 +103,23 @@ private fun processPrismSyntax(ele: ADT, dsl: SealedClassDsl): String =
|${ele.visibilityModifierName} inline val <S> $Fold<S, ${ele.sourceClassName}>.${focus.paramName}: $Fold<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Every<S, ${ele.sourceClassName}>.${focus.paramName}: $Every<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
|""".trimMargin()
}
} else {
dsl.foci.joinToString(separator = "\n\n") { focus ->
val sourceClassNameWithParams = focus.refinedType?.qualifiedString() ?: "${ele.sourceClassName}${ele.angledTypeParameters}"
val joinedTypeParams = when {
focus.refinedArguments.isEmpty() -> ""
else -> focus.refinedArguments.joinToString(separator=",")
}
"""
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Iso<S, $sourceClassNameWithParams>.${focus.paramName}(): $Prism<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Lens<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Optional<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Prism<S, $sourceClassNameWithParams>.${focus.paramName}(): $Prism<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Setter<S, $sourceClassNameWithParams>.${focus.paramName}(): $Setter<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Traversal<S, $sourceClassNameWithParams>.${focus.paramName}(): $Traversal<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Fold<S, $sourceClassNameWithParams>.${focus.paramName}(): $Fold<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Every<S, $sourceClassNameWithParams>.${focus.paramName}(): $Every<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|""".trimMargin()
}
}
Expand Up @@ -63,9 +63,17 @@ private fun processElement(iso: ADT, target: Target): String {
"tuple: ${focusType()} -> ${(foci.indices).joinToString(prefix = "${iso.sourceClassName}(", postfix = ")", transform = { "tuple.${letters[it]}" })}"
}

val sourceClassNameWithParams = "${iso.sourceClassName}${iso.angledTypeParameters}"
val firstLine = when {
iso.typeParameters.isEmpty() ->
"${iso.visibilityModifierName} inline val ${iso.sourceClassName}.Companion.iso: $Iso<${iso.sourceClassName}, ${focusType()}> inline get()"
else ->
"${iso.visibilityModifierName} inline fun ${iso.angledTypeParameters} ${iso.sourceClassName}.Companion.iso(): $Iso<$sourceClassNameWithParams, ${focusType()}>"
}

return """
|${iso.visibilityModifierName} inline val ${iso.sourceClassName}.Companion.iso: $Iso<${iso.sourceClassName}, ${focusType()}> inline get()= $Iso(
| get = { ${iso.sourceName}: ${iso.sourceClassName} -> ${tupleConstructor()} },
|$firstLine = $Iso(
| get = { ${iso.sourceName}: $sourceClassNameWithParams -> ${tupleConstructor()} },
| reverseGet = { ${classConstructorFromTuple()} }
|)
|""".trimMargin()
Expand Down
Expand Up @@ -20,17 +20,24 @@ private fun String.toUpperCamelCase(): String =
}
)

private fun processElement(adt: ADT, foci: List<Focus>): String =
foci.joinToString(separator = "\n") { focus ->
private fun processElement(adt: ADT, foci: List<Focus>): String {
val sourceClassNameWithParams = "${adt.sourceClassName}${adt.angledTypeParameters}"
return foci.joinToString(separator = "\n") { focus ->
val firstLine = when {
adt.typeParameters.isEmpty() ->
"${adt.visibilityModifierName} inline val ${adt.sourceClassName}.Companion.${focus.lensParamName()}: $Lens<${adt.sourceClassName}, ${focus.className}> inline get()"
else ->
"${adt.visibilityModifierName} inline fun ${adt.angledTypeParameters} ${adt.sourceClassName}.Companion.${focus.lensParamName()}(): $Lens<$sourceClassNameWithParams, ${focus.className}>"
}
"""
|${adt.visibilityModifierName} inline val ${adt.sourceClassName}.Companion.${focus.lensParamName()}: $Lens<${adt.sourceClassName}, ${focus.className}> inline get()= $Lens(
| get = { ${adt.sourceName}: ${adt.sourceClassName} -> ${adt.sourceName}.${
|$firstLine = $Lens(
| get = { ${adt.sourceName}: $sourceClassNameWithParams -> ${adt.sourceName}.${
focus.paramName.plusIfNotBlank(
prefix = "`",
postfix = "`"
)
} },
| set = { ${adt.sourceName}: ${adt.sourceClassName}, value: ${focus.className} -> ${adt.sourceName}.copy(${
| set = { ${adt.sourceName}: $sourceClassNameWithParams, value: ${focus.className} -> ${adt.sourceName}.copy(${
focus.paramName.plusIfNotBlank(
prefix = "`",
postfix = "`"
Expand All @@ -39,6 +46,7 @@ private fun processElement(adt: ADT, foci: List<Focus>): String =
|)
|""".trimMargin()
}
}

fun Focus.lensParamName(): String =
when (this) {
Expand Down
Expand Up @@ -11,8 +11,23 @@ internal fun generateOptionals(ele: ADT, target: OptionalTarget) =

private fun processElement(ele: ADT, foci: List<Focus>): String =
foci.joinToString(separator = "\n") { focus ->

val targetClassName = when (focus) {
is NullableFocus -> focus.nonNullClassName
is OptionFocus -> focus.nestedClassName
is NonNullFocus -> return@joinToString ""
}

val sourceClassNameWithParams = "${ele.sourceClassName}${ele.angledTypeParameters}"
val firstLine = when {
ele.typeParameters.isEmpty() ->
"${ele.visibilityModifierName} inline val ${ele.sourceClassName}.Companion.${focus.paramName}: $Optional<${ele.sourceClassName}, $targetClassName> inline get()"
else ->
"${ele.visibilityModifierName} inline fun ${ele.angledTypeParameters} ${ele.sourceClassName}.Companion.${focus.paramName}(): $Optional<$sourceClassNameWithParams, $targetClassName>"
}

fun getOrModifyF(toNullable: String = "") =
"{ ${ele.sourceName}: ${ele.sourceClassName} -> ${ele.sourceName}.${
"{ ${ele.sourceName}: $sourceClassNameWithParams -> ${ele.sourceName}.${
focus.paramName.plusIfNotBlank(
prefix = "`",
postfix = "`"
Expand All @@ -21,18 +36,17 @@ private fun processElement(ele: ADT, foci: List<Focus>): String =
fun setF(fromNullable: String = "") =
"${ele.sourceName}.copy(${focus.paramName.plusIfNotBlank(prefix = "`", postfix = "`")} = value$fromNullable)"

val (targetClassName, getOrModify, set) =
val (getOrModify, set) =
when (focus) {
is NullableFocus -> Triple(focus.nonNullClassName, getOrModifyF(), setF())
is OptionFocus ->
Triple(focus.nestedClassName, getOrModifyF(".orNull()"), setF(".toOption()"))
is NullableFocus -> Pair(getOrModifyF(), setF())
is OptionFocus -> Pair(getOrModifyF(".orNull()"), setF(".toOption()"))
is NonNullFocus -> return@joinToString ""
}

"""
|${ele.visibilityModifierName} inline val ${ele.sourceClassName}.Companion.${focus.paramName}: $Optional<${ele.sourceClassName}, $targetClassName> inline get()= $Optional(
|$firstLine = $Optional(
| getOrModify = $getOrModify,
| set = { ${ele.sourceName}: ${ele.sourceClassName}, value: $targetClassName -> $set }
| set = { ${ele.sourceName}: $sourceClassNameWithParams, value: $targetClassName -> $set }
|)
|""".trimMargin()
}