From 6047db87831b23dccc8f42ed5b94af9047c14bf6 Mon Sep 17 00:00:00 2001 From: Paul de Vrieze Date: Wed, 2 Feb 2022 13:38:22 +0000 Subject: [PATCH] Protobuf packed encoding/decoding (#1830) * Create an annotation to request packing of collections. * The specification only allows packing for primitive types (wire types 0, 1 or 5) to allow decoders to decode either format independently of the proto specification. * Make pushback work with respect to currentType/currentId. This allows it to be used to effectively peek the type without assumptions on state. * Clarify in the documentation that reading (per the standard) supports inputs in either format, independent of the annotation. The annotation only affects writing. * Support decoding "packed" arrays as toplevels. Add tests for handling of strings and "packed" toplevel arrays. The check for eof works because the byte size is always >= the array length. --- .../api/kotlinx-serialization-protobuf.api | 7 + .../serialization/protobuf/ProtoTypes.kt | 15 +- .../protobuf/internal/Helpers.kt | 29 +++- .../protobuf/internal/PackedArrayDecoder.kt | 32 +++++ .../protobuf/internal/PackedArrayEncoder.kt | 31 ++++ .../protobuf/internal/ProtobufDecoding.kt | 8 +- .../protobuf/internal/ProtobufEncoding.kt | 26 ++-- .../protobuf/internal/ProtobufReader.kt | 19 ++- .../protobuf/internal/Streams.kt | 2 +- .../protobuf/PackedArraySerializerTest.kt | 136 ++++++++++++++++++ 10 files changed, 288 insertions(+), 17 deletions(-) create mode 100644 formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/PackedArrayDecoder.kt create mode 100644 formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/PackedArrayEncoder.kt create mode 100644 formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/PackedArraySerializerTest.kt diff --git a/formats/protobuf/api/kotlinx-serialization-protobuf.api b/formats/protobuf/api/kotlinx-serialization-protobuf.api index ddc91d4f9..65093b2c8 100644 --- 
a/formats/protobuf/api/kotlinx-serialization-protobuf.api +++ b/formats/protobuf/api/kotlinx-serialization-protobuf.api @@ -38,6 +38,13 @@ public final class kotlinx/serialization/protobuf/ProtoNumber$Impl : kotlinx/ser public final synthetic fun number ()I } +public abstract interface annotation class kotlinx/serialization/protobuf/ProtoPacked : java/lang/annotation/Annotation { +} + +public final class kotlinx/serialization/protobuf/ProtoPacked$Impl : kotlinx/serialization/protobuf/ProtoPacked { + public fun ()V +} + public abstract interface annotation class kotlinx/serialization/protobuf/ProtoType : java/lang/annotation/Annotation { public abstract fun type ()Lkotlinx/serialization/protobuf/ProtoIntegerType; } diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoTypes.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoTypes.kt index ccf553942..3b62d4dc8 100644 --- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoTypes.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoTypes.kt @@ -32,9 +32,9 @@ public annotation class ProtoNumber(public val number: Int) @Suppress("NO_EXPLICIT_VISIBILITY_IN_API_MODE_WARNING") @ExperimentalSerializationApi public enum class ProtoIntegerType(internal val signature: Long) { - DEFAULT(0L shl 32), - SIGNED(1L shl 32), - FIXED(2L shl 32); + DEFAULT(0L shl 33), + SIGNED(1L shl 33), + FIXED(2L shl 33); } /** @@ -45,3 +45,12 @@ public enum class ProtoIntegerType(internal val signature: Long) { @Target(AnnotationTarget.PROPERTY) @ExperimentalSerializationApi public annotation class ProtoType(public val type: ProtoIntegerType) + + +/** + * Instructs that a particular collection should be written as [packed array](https://developers.google.com/protocol-buffers/docs/encoding#packed) + */ +@SerialInfo +@Target(AnnotationTarget.PROPERTY) +@ExperimentalSerializationApi +public annotation class ProtoPacked diff --git 
a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Helpers.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Helpers.kt index 0831b931e..59533db0f 100644 --- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Helpers.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Helpers.kt @@ -16,7 +16,15 @@ internal const val i64 = 1 internal const val SIZE_DELIMITED = 2 internal const val i32 = 5 -private const val MASK = Int.MAX_VALUE.toLong() shl 32 +private const val INTTYPEMASK = (Int.MAX_VALUE.toLong() shr 1) shl 33 +private const val PACKEDMASK = 1L shl 32 + +@Suppress("NOTHING_TO_INLINE") +internal inline fun ProtoDesc(protoId: Int, type: ProtoIntegerType, packed: Boolean): ProtoDesc { + val packedBits = if (packed) 1L shl 32 else 0L + val signature = type.signature or packedBits + return signature or protoId.toLong() +} @Suppress("NOTHING_TO_INLINE") internal inline fun ProtoDesc(protoId: Int, type: ProtoIntegerType): ProtoDesc { @@ -26,25 +34,40 @@ internal inline fun ProtoDesc(protoId: Int, type: ProtoIntegerType): ProtoDesc { internal inline val ProtoDesc.protoId: Int get() = (this and Int.MAX_VALUE.toLong()).toInt() internal val ProtoDesc.integerType: ProtoIntegerType - get() = when(this and MASK) { + get() = when(this and INTTYPEMASK) { ProtoIntegerType.DEFAULT.signature -> ProtoIntegerType.DEFAULT ProtoIntegerType.SIGNED.signature -> ProtoIntegerType.SIGNED else -> ProtoIntegerType.FIXED } +internal val SerialDescriptor.isPackable: Boolean + @OptIn(kotlinx.serialization.ExperimentalSerializationApi::class) + get() = when (kind) { + PrimitiveKind.STRING, + !is PrimitiveKind -> false + else -> true + } + +internal val ProtoDesc.isPacked: Boolean + get() = (this and PACKEDMASK) != 0L + internal fun SerialDescriptor.extractParameters(index: Int): ProtoDesc { val annotations = getElementAnnotations(index) var protoId: Int = index + 1 var format: ProtoIntegerType = 
ProtoIntegerType.DEFAULT + var protoPacked = false + for (i in annotations.indices) { // Allocation-friendly loop val annotation = annotations[i] if (annotation is ProtoNumber) { protoId = annotation.number } else if (annotation is ProtoType) { format = annotation.type + } else if (annotation is ProtoPacked) { + protoPacked = true } } - return ProtoDesc(protoId, format) + return ProtoDesc(protoId, format, protoPacked) } internal fun extractProtoId(descriptor: SerialDescriptor, index: Int, zeroBasedDefault: Boolean): Int { diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/PackedArrayDecoder.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/PackedArrayDecoder.kt new file mode 100644 index 000000000..b17d5119e --- /dev/null +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/PackedArrayDecoder.kt @@ -0,0 +1,32 @@ +package kotlinx.serialization.protobuf.internal + +import kotlinx.serialization.* +import kotlinx.serialization.descriptors.* +import kotlinx.serialization.encoding.* +import kotlinx.serialization.protobuf.* + +@OptIn(ExperimentalSerializationApi::class) +internal class PackedArrayDecoder( + proto: ProtoBuf, + reader: ProtobufReader, + descriptor: SerialDescriptor, +) : ProtobufDecoder(proto, reader, descriptor) { + private var nextIndex: Int = 0 + + // Tags are omitted in the packed array format + override fun SerialDescriptor.getTag(index: Int): ProtoDesc = MISSING_TAG + + override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { + throw SerializationException("Packing only supports primitive number types. The input type however was a struct: $descriptor") + } + + override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + // We need eof here as there is no tag to read in packed form. 
+ if (reader.eof) return CompositeDecoder.DECODE_DONE + return nextIndex++ + } + + override fun decodeTaggedString(tag: ProtoDesc): String { + throw SerializationException("Packing only supports primitive number types. The actual reading is for string.") + } +} \ No newline at end of file diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/PackedArrayEncoder.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/PackedArrayEncoder.kt new file mode 100644 index 000000000..812ca3074 --- /dev/null +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/PackedArrayEncoder.kt @@ -0,0 +1,31 @@ +package kotlinx.serialization.protobuf.internal + +import kotlinx.serialization.* +import kotlinx.serialization.descriptors.* +import kotlinx.serialization.encoding.* +import kotlinx.serialization.protobuf.* + +@OptIn(ExperimentalSerializationApi::class) +internal class PackedArrayEncoder( + proto: ProtoBuf, + writer: ProtobufWriter, + curTag: ProtoDesc, + descriptor: SerialDescriptor, + stream: ByteArrayOutput = ByteArrayOutput() +) : NestedRepeatedEncoder(proto, writer, curTag, descriptor, stream) { + + // Triggers not writing header + override fun SerialDescriptor.getTag(index: Int): ProtoDesc = MISSING_TAG + + override fun beginCollection(descriptor: SerialDescriptor, collectionSize: Int): CompositeEncoder { + throw SerializationException("Packing only supports primitive number types") + } + + override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { + throw SerializationException("Packing only supports primitive number types") + } + + override fun encodeTaggedString(tag: ProtoDesc, value: String) { + throw SerializationException("Packing only supports primitive number types") + } +} diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt 
index f47b83dd9..09773919a 100644 --- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufDecoding.kt @@ -108,6 +108,11 @@ internal open class ProtobufDecoder( reader.readTag() // all elements always have id = 1 RepeatedDecoder(proto, reader, ProtoDesc(1, ProtoIntegerType.DEFAULT), descriptor) + + } else if (reader.currentType == SIZE_DELIMITED && descriptor.getElementDescriptor(0).isPackable) { + val sliceReader = ProtobufReader(reader.objectInput()) + PackedArrayDecoder(proto, sliceReader, descriptor) + } else { RepeatedDecoder(proto, reader, tag, descriptor) } @@ -287,7 +292,8 @@ private class RepeatedDecoder( private fun decodeListIndexNoTag(): Int { val size = -tagOrSize val idx = ++index - if (idx.toLong() == size) return CompositeDecoder.DECODE_DONE + // Check for eof is here for the case that it is an out-of-spec packed array where size is bytesize not list length. 
+ if (idx.toLong() == size || reader.eof) return CompositeDecoder.DECODE_DONE return idx } diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufEncoding.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufEncoding.kt index 670143776..fab7a09df 100644 --- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufEncoding.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufEncoding.kt @@ -30,13 +30,17 @@ internal open class ProtobufEncoder( ): CompositeEncoder = when (descriptor.kind) { StructureKind.LIST -> { val tag = currentTagOrDefault - if (tag == MISSING_TAG) { - writer.writeInt(collectionSize) - } - if (this.descriptor.kind == StructureKind.LIST && tag != MISSING_TAG && this.descriptor != descriptor) { - NestedRepeatedEncoder(proto, writer, tag, descriptor) + if (tag.isPacked && descriptor.getElementDescriptor(0).isPackable) { + PackedArrayEncoder(proto, writer, currentTagOrDefault, descriptor) } else { - RepeatedEncoder(proto, writer, tag, descriptor) + if (tag == MISSING_TAG) { + writer.writeInt(collectionSize) + } + if (this.descriptor.kind == StructureKind.LIST && tag != MISSING_TAG && this.descriptor != descriptor) { + NestedRepeatedEncoder(proto, writer, tag, descriptor) + } else { + RepeatedEncoder(proto, writer, tag, descriptor) + } } } StructureKind.MAP -> { @@ -47,7 +51,13 @@ internal open class ProtobufEncoder( } override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder = when (descriptor.kind) { - StructureKind.LIST -> RepeatedEncoder(proto, writer, currentTagOrDefault, descriptor) + StructureKind.LIST -> { + if (descriptor.getElementDescriptor(0).isPackable && currentTagOrDefault.isPacked) { + PackedArrayEncoder(proto, writer, currentTagOrDefault, descriptor) + } else { + RepeatedEncoder(proto, writer, currentTagOrDefault, descriptor) + } + } StructureKind.CLASS, StructureKind.OBJECT, is 
PolymorphicKind -> { val tag = currentTagOrDefault if (tag == MISSING_TAG && descriptor == this.descriptor) this @@ -183,7 +193,7 @@ private class RepeatedEncoder( override fun SerialDescriptor.getTag(index: Int) = curTag } -private class NestedRepeatedEncoder( +internal open class NestedRepeatedEncoder( proto: ProtoBuf, @JvmField val writer: ProtobufWriter, @JvmField val curTag: ProtoDesc, diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufReader.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufReader.kt index 4bf2d968c..c7d4ea087 100644 --- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufReader.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufReader.kt @@ -15,14 +15,27 @@ internal class ProtobufReader(private val input: ByteArrayInput) { @JvmField public var currentType = -1 private var pushBack = false + private var pushBackHeader = 0 + + public val eof + get() = !pushBack && input.availableBytes == 0 public fun readTag(): Int { if (pushBack) { pushBack = false - return currentId + val previousHeader = (currentId shl 3) or currentType + return updateIdAndType(pushBackHeader).also { + pushBackHeader = previousHeader + } } + // Header to use when pushed back is the old id/type + pushBackHeader = (currentId shl 3) or currentType val header = input.readVarint64(true).toInt() + return updateIdAndType(header) + } + + private fun updateIdAndType(header: Int): Int { return if (header == -1) { currentId = -1 currentType = -1 @@ -36,6 +49,10 @@ internal class ProtobufReader(private val input: ByteArrayInput) { public fun pushBackTag() { pushBack = true + + val nextHeader = (currentId shl 3) or currentType + updateIdAndType(pushBackHeader) + pushBackHeader = nextHeader } fun skipElement() { diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Streams.kt 
b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Streams.kt index 991349cc7..575c5e742 100644 --- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Streams.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/Streams.kt @@ -8,7 +8,7 @@ import kotlinx.serialization.* internal class ByteArrayInput(private var array: ByteArray, private val endIndex: Int = array.size) { private var position: Int = 0 - private val availableBytes: Int get() = endIndex - position + val availableBytes: Int get() = endIndex - position fun slice(size: Int): ByteArrayInput { ensureEnoughBytes(size) diff --git a/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/PackedArraySerializerTest.kt b/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/PackedArraySerializerTest.kt new file mode 100644 index 000000000..e7bf67622 --- /dev/null +++ b/formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/PackedArraySerializerTest.kt @@ -0,0 +1,136 @@ +/* + * Copyright 2017-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. 
+ */ + +package kotlinx.serialization.protobuf + +import kotlinx.serialization.* +import kotlin.test.* + +class PackedArraySerializerTest { + + abstract class BaseFloatArrayCarrier { + abstract val createdAt: ULong + abstract val vector: FloatArray + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is BaseFloatArrayCarrier) return false + + if (createdAt != other.createdAt) return false + if (!vector.contentEquals(other.vector)) return false + + return true + } + + override fun hashCode(): Int { + var result = createdAt.hashCode() + result = 31 * result + vector.contentHashCode() + return result + } + } + + @Serializable + class PackedFloatArrayCarrier( + @ProtoNumber(2) + override val createdAt: ULong, + @ProtoPacked + @ProtoNumber(3) override val vector: FloatArray + ) : BaseFloatArrayCarrier() + + @Serializable + class NonPackedFloatArrayCarrier( + @ProtoNumber(2) + override val createdAt: ULong, + @ProtoNumber(3) override val vector: FloatArray + ) : BaseFloatArrayCarrier() + + @Serializable + data class PackedStringCarrier( + @ProtoNumber(0) + @ProtoPacked + val s: List + ) + + /** + * Test that when packing is specified the array is encoded as packed + */ + @Test + fun testEncodePackedFloatArrayProtobuf() { + val obj = PackedFloatArrayCarrier(1234567890L.toULong(), floatArrayOf(1f, 2f, 3f)) + val s = ProtoBuf.encodeToHexString(PackedFloatArrayCarrier.serializer(), obj).uppercase() + assertEquals("""10D285D8CC041A0C0000803F0000004000004040""", s) + } + + /** + * Test that when packing is not specified the array is not encoded as packed. Note that protobuf 3 + * should encode as packed by default. The format doesn't allow specifying versions at this point so + * the default remains the original. 
+ */ + @Test + fun testEncodeNonPackedFloatArrayProtobuf() { + val obj = NonPackedFloatArrayCarrier(1234567890L.toULong(), floatArrayOf(1f, 2f, 3f)) + val s = ProtoBuf.encodeToHexString(NonPackedFloatArrayCarrier.serializer(), obj).uppercase() + assertEquals("""10D285D8CC041D0000803F1D000000401D00004040""", s) + } + + /** + * Per the specification decoders should support both packed and repeated fields independent of whether + * a field is specified as packed in the schema. Check that decoding works with both types (packed and non-packed) + * if the data itself is packed. + */ + @Test + fun testDecodePackedFloatArrayProtobuf() { + val obj: BaseFloatArrayCarrier = PackedFloatArrayCarrier(1234567890L.toULong(), floatArrayOf(1f, 2f, 3f)) + val s = """10D285D8CC041A0C0000803F0000004000004040""" + val decodedPacked = ProtoBuf.decodeFromHexString(PackedFloatArrayCarrier.serializer(), s) + assertEquals(obj, decodedPacked) + val decodedNonPacked = ProtoBuf.decodeFromHexString(NonPackedFloatArrayCarrier.serializer(), s) + assertEquals(obj, decodedNonPacked) + } + + /** + * Per the specification decoders should support both packed and repeated fields independent of whether + * a field is specified as packed in the schema. Check that decoding works with both types (packed and non-packed) + * if the data itself is not packed. + */ + @Test + fun testDecodeNonPackedFloatArrayProtobuf() { + val obj: BaseFloatArrayCarrier = PackedFloatArrayCarrier(1234567890L.toULong(), floatArrayOf(1f, 2f, 3f)) + val s = """10D285D8CC041D0000803F1D000000401D00004040""" + val decodedPacked = ProtoBuf.decodeFromHexString(PackedFloatArrayCarrier.serializer(), s) + assertEquals(obj, decodedPacked) + val decodedNonPacked = ProtoBuf.decodeFromHexString(NonPackedFloatArrayCarrier.serializer(), s) + assertEquals(obj, decodedNonPacked) + } + + /** + * Test that serializing a list of strings is never packed, and deserialization ignores the packing annotation. 
+ */ + @Test + fun testEncodeAnnotatedStringList() { + val obj = PackedStringCarrier(listOf("aaa", "bbb", "ccc")) + val expectedHex = "020361616102036262620203636363" + val encodedHex = ProtoBuf.encodeToHexString(obj) + assertEquals(expectedHex, encodedHex) + assertEquals(obj, ProtoBuf.decodeFromHexString(expectedHex)) + + val invalidPackedHex = "020C036161610362626203636363" + val decoded = ProtoBuf.decodeFromHexString(invalidPackedHex) + val invalidString = "\u0003aaa\u0003bbb\u0003ccc" + assertEquals(PackedStringCarrier(listOf(invalidString)), decoded) + } + + /** + * Test that toplevel "packed" lists with only byte length also work. + */ + @Test + fun testDecodeToplevelPackedList() { + val input = "0feffdb6f507e6cc9933ba0180feff03" + val listData = listOf(0x7eadbeef, 0x6666666, 0xba, 0x7fff00) + val decoded = ProtoBuf.decodeFromHexString>(input) + + assertEquals(listData, decoded) + } + +}