diff --git a/formats/json-tests/commonTest/src/kotlinx/serialization/features/sealed/SealedInterfacesInlineSerialNameTest.kt b/formats/json-tests/commonTest/src/kotlinx/serialization/features/sealed/SealedInterfacesInlineSerialNameTest.kt new file mode 100644 index 000000000..c86a5d38d --- /dev/null +++ b/formats/json-tests/commonTest/src/kotlinx/serialization/features/sealed/SealedInterfacesInlineSerialNameTest.kt @@ -0,0 +1,51 @@ +/* + * Copyright 2017-2024 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.serialization.features.sealed + +import kotlinx.serialization.* +import kotlinx.serialization.json.* +import kotlin.jvm.* +import kotlin.test.* + +class SealedInterfacesInlineSerialNameTest : JsonTestBase() { + @Serializable + data class Child1Value( + val a: Int, + val b: String + ) + + @Serializable + data class Child2Value( + val c: Int, + val d: String + ) + + @Serializable + sealed interface Parent + + @Serializable + @SerialName("child1") + @JvmInline + value class Child1(val value: Child1Value) : Parent + + @Serializable + @SerialName("child2") + @JvmInline + value class Child2(val value: Child2Value) : Parent + + // From https://github.com/Kotlin/kotlinx.serialization/issues/2288 + @Test + fun testSealedInterfaceInlineSerialName() { + val messages = listOf( + Child1(Child1Value(1, "one")), + Child2(Child2Value(2, "two")) + ) + assertJsonFormAndRestored( + serializer(), + messages, + """[{"type":"child1","a":1,"b":"one"},{"type":"child2","c":2,"d":"two"}]""" + ) + } +} diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonEncoder.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonEncoder.kt index cf562de5c..31f0aa6d2 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonEncoder.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonEncoder.kt @@ -43,6 +43,7 @@ internal class StreamingJsonEncoder( // Forces serializer to wrap all values into quotes private var forceQuoting: Boolean = false private var polymorphicDiscriminator: String? = null + private var polymorphicSerialName: String? = null init { val i = mode.ordinal @@ -66,12 +67,12 @@ internal class StreamingJsonEncoder( } } - private fun encodeTypeInfo(descriptor: SerialDescriptor) { + private fun encodeTypeInfo(discriminator: String, serialName: String) { composer.nextItem() - encodeString(polymorphicDiscriminator!!) + encodeString(discriminator) composer.print(COLON) composer.space() - encodeString(descriptor.serialName) + encodeString(serialName) } override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { @@ -81,9 +82,11 @@ internal class StreamingJsonEncoder( composer.indent() } - if (polymorphicDiscriminator != null) { - encodeTypeInfo(descriptor) + val discriminator = polymorphicDiscriminator + if (discriminator != null) { + encodeTypeInfo(discriminator, polymorphicSerialName ?: descriptor.serialName) polymorphicDiscriminator = null + polymorphicSerialName = null } if (mode == newMode) { @@ -160,6 +163,7 @@ internal class StreamingJsonEncoder( when { descriptor.isUnsignedNumber -> StreamingJsonEncoder(composerAs(::ComposerForUnsignedNumbers), json, mode, null) descriptor.isUnquotedLiteral -> StreamingJsonEncoder(composerAs(::ComposerForUnquotedLiterals), json, mode, null) + polymorphicDiscriminator != null -> apply { polymorphicSerialName = descriptor.serialName } else -> super.encodeInline(descriptor) } diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonDecoder.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonDecoder.kt index 690b35e1f..5b00da72c 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonDecoder.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonDecoder.kt @@ -35,7 +35,8 @@ internal fun Json.readPolymorphicJson( private sealed class AbstractJsonTreeDecoder( override val json: Json, - open val value: JsonElement + open val value: JsonElement, + protected val polymorphicDiscriminator: String? = null ) : NamedValueDecoder(), JsonDecoder { override val serializersModule: SerializersModule @@ -63,7 +64,7 @@ private sealed class AbstractJsonTreeDecoder( { JsonTreeMapDecoder(json, cast(currentObject, descriptor)) }, { JsonTreeListDecoder(json, cast(currentObject, descriptor)) } ) - else -> JsonTreeDecoder(json, cast(currentObject, descriptor)) + else -> JsonTreeDecoder(json, cast(currentObject, descriptor), polymorphicDiscriminator) } } @@ -159,11 +160,15 @@ private sealed class AbstractJsonTreeDecoder( override fun decodeInline(descriptor: SerialDescriptor): Decoder { return if (currentTagOrNull != null) super.decodeInline(descriptor) - else JsonPrimitiveDecoder(json, value).decodeInline(descriptor) + else JsonPrimitiveDecoder(json, value, polymorphicDiscriminator).decodeInline(descriptor) } } -private class JsonPrimitiveDecoder(json: Json, override val value: JsonElement) : AbstractJsonTreeDecoder(json, value) { +private class JsonPrimitiveDecoder( + json: Json, + override val value: JsonElement, + polymorphicDiscriminator: String? = null +) : AbstractJsonTreeDecoder(json, value, polymorphicDiscriminator) { init { pushTag(PRIMITIVE_TAG) @@ -180,9 +185,9 @@ private class JsonPrimitiveDecoder(json: Json, override val value: JsonElement) private open class JsonTreeDecoder( json: Json, override val value: JsonObject, - private val polyDiscriminator: String? = null, + polymorphicDiscriminator: String? = null, private val polyDescriptor: SerialDescriptor? = null -) : AbstractJsonTreeDecoder(json, value) { +) : AbstractJsonTreeDecoder(json, value, polymorphicDiscriminator) { private var position = 0 private var forceNull: Boolean = false /* @@ -251,7 +256,7 @@ private open class JsonTreeDecoder( // in endStructure can filter polyDiscriminator out. if (descriptor === polyDescriptor) { return JsonTreeDecoder( - json, cast(currentObject(), polyDescriptor), polyDiscriminator, polyDescriptor + json, cast(currentObject(), polyDescriptor), polymorphicDiscriminator, polyDescriptor ) } @@ -271,7 +276,7 @@ private open class JsonTreeDecoder( } for (key in value.keys) { - if (key !in names && key != polyDiscriminator) { + if (key !in names && key != polymorphicDiscriminator) { throw UnknownKeyException(key, value.toString()) } } diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonEncoder.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonEncoder.kt index 5e3c80868..8363bfa25 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonEncoder.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/TreeJsonEncoder.kt @@ -35,6 +35,7 @@ private sealed class AbstractJsonTreeEncoder( protected val configuration = json.configuration private var polymorphicDiscriminator: String? = null + private var polymorphicSerialName: String? = null override fun elementName(descriptor: SerialDescriptor, index: Int): String = descriptor.getJsonElementName(json, index) @@ -112,8 +113,12 @@ private sealed class AbstractJsonTreeEncoder( } override fun encodeInline(descriptor: SerialDescriptor): Encoder { - return if (currentTagOrNull != null) super.encodeInline(descriptor) - else JsonPrimitiveEncoder(json, nodeConsumer).encodeInline(descriptor) + return if (currentTagOrNull != null) { + if (polymorphicDiscriminator != null) polymorphicSerialName = descriptor.serialName + super.encodeInline(descriptor) + } else { + JsonPrimitiveEncoder(json, nodeConsumer).encodeInline(descriptor) + } } @SuppressAnimalSniffer // Long(Integer).toUnsignedString(long) @@ -148,9 +153,11 @@ private sealed class AbstractJsonTreeEncoder( else -> JsonTreeEncoder(json, consumer) } - if (polymorphicDiscriminator != null) { - encoder.putElement(polymorphicDiscriminator!!, JsonPrimitive(descriptor.serialName)) + val discriminator = polymorphicDiscriminator + if (discriminator != null) { + encoder.putElement(discriminator, JsonPrimitive(polymorphicSerialName ?: descriptor.serialName)) polymorphicDiscriminator = null + polymorphicSerialName = null } return encoder