Skip to content

Commit

Permalink
Protobuf packed encoding/decoding (Kotlin#1830)
Browse files Browse the repository at this point in the history
* Create an annotation to request packing of collections.

* The specification only allows packing for primitive types (wire types 1, 2 or 5) to allow
decoders decode either format independently of the proto specification.

* Make pushback work in respect to currentType/currentId. This allows it to be used to effectively
peek the type without assumptions on state.

* Clarify in the documentation that reading will (per the standard) supports inputs in either format,
independent of the annotation. The annotation only affects writing.

* Support decoding "packed" arrays as toplevels. Add tests for handling of strings
and "packed" toplevel arrays. The checking for eof works as bytesize is always >= array length.
  • Loading branch information
pdvrieze committed Apr 29, 2022
1 parent 4c34c23 commit 6047db8
Show file tree
Hide file tree
Showing 10 changed files with 288 additions and 17 deletions.
7 changes: 7 additions & 0 deletions formats/protobuf/api/kotlinx-serialization-protobuf.api
Expand Up @@ -38,6 +38,13 @@ public final class kotlinx/serialization/protobuf/ProtoNumber$Impl : kotlinx/ser
public final synthetic fun number ()I
}

public abstract interface annotation class kotlinx/serialization/protobuf/ProtoPacked : java/lang/annotation/Annotation {
}

public final class kotlinx/serialization/protobuf/ProtoPacked$Impl : kotlinx/serialization/protobuf/ProtoPacked {
public fun <init> ()V
}

public abstract interface annotation class kotlinx/serialization/protobuf/ProtoType : java/lang/annotation/Annotation {
public abstract fun type ()Lkotlinx/serialization/protobuf/ProtoIntegerType;
}
Expand Down
Expand Up @@ -32,9 +32,9 @@ public annotation class ProtoNumber(public val number: Int)
@Suppress("NO_EXPLICIT_VISIBILITY_IN_API_MODE_WARNING")
@ExperimentalSerializationApi
public enum class ProtoIntegerType(internal val signature: Long) {
DEFAULT(0L shl 32),
SIGNED(1L shl 32),
FIXED(2L shl 32);
DEFAULT(0L shl 33),
SIGNED(1L shl 33),
FIXED(2L shl 33);
}

/**
Expand All @@ -45,3 +45,12 @@ public enum class ProtoIntegerType(internal val signature: Long) {
@Target(AnnotationTarget.PROPERTY)
@ExperimentalSerializationApi
public annotation class ProtoType(public val type: ProtoIntegerType)


/**
* Instructs that a particular collection should be written as [packed array](https://developers.google.com/protocol-buffers/docs/encoding#packed)
*/
@SerialInfo
@Target(AnnotationTarget.PROPERTY)
@ExperimentalSerializationApi
public annotation class ProtoPacked
Expand Up @@ -16,7 +16,15 @@ internal const val i64 = 1
internal const val SIZE_DELIMITED = 2
internal const val i32 = 5

private const val MASK = Int.MAX_VALUE.toLong() shl 32
private const val INTTYPEMASK = (Int.MAX_VALUE.toLong() shr 1) shl 33
private const val PACKEDMASK = 1L shl 32

@Suppress("NOTHING_TO_INLINE")
internal inline fun ProtoDesc(protoId: Int, type: ProtoIntegerType, packed: Boolean): ProtoDesc {
val packedBits = if (packed) 1L shl 32 else 0L
val signature = type.signature or packedBits
return signature or protoId.toLong()
}

@Suppress("NOTHING_TO_INLINE")
internal inline fun ProtoDesc(protoId: Int, type: ProtoIntegerType): ProtoDesc {
Expand All @@ -26,25 +34,40 @@ internal inline fun ProtoDesc(protoId: Int, type: ProtoIntegerType): ProtoDesc {
internal inline val ProtoDesc.protoId: Int get() = (this and Int.MAX_VALUE.toLong()).toInt()

internal val ProtoDesc.integerType: ProtoIntegerType
get() = when(this and MASK) {
get() = when(this and INTTYPEMASK) {
ProtoIntegerType.DEFAULT.signature -> ProtoIntegerType.DEFAULT
ProtoIntegerType.SIGNED.signature -> ProtoIntegerType.SIGNED
else -> ProtoIntegerType.FIXED
}

internal val SerialDescriptor.isPackable: Boolean
@OptIn(kotlinx.serialization.ExperimentalSerializationApi::class)
get() = when (kind) {
PrimitiveKind.STRING,
!is PrimitiveKind -> false
else -> true
}

internal val ProtoDesc.isPacked: Boolean
get() = (this and PACKEDMASK) != 0L

internal fun SerialDescriptor.extractParameters(index: Int): ProtoDesc {
val annotations = getElementAnnotations(index)
var protoId: Int = index + 1
var format: ProtoIntegerType = ProtoIntegerType.DEFAULT
var protoPacked = false

for (i in annotations.indices) { // Allocation-friendly loop
val annotation = annotations[i]
if (annotation is ProtoNumber) {
protoId = annotation.number
} else if (annotation is ProtoType) {
format = annotation.type
} else if (annotation is ProtoPacked) {
protoPacked = true
}
}
return ProtoDesc(protoId, format)
return ProtoDesc(protoId, format, protoPacked)
}

internal fun extractProtoId(descriptor: SerialDescriptor, index: Int, zeroBasedDefault: Boolean): Int {
Expand Down
@@ -0,0 +1,32 @@
package kotlinx.serialization.protobuf.internal

import kotlinx.serialization.*
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.encoding.*
import kotlinx.serialization.protobuf.*

@OptIn(ExperimentalSerializationApi::class)
internal class PackedArrayDecoder(
proto: ProtoBuf,
reader: ProtobufReader,
descriptor: SerialDescriptor,
) : ProtobufDecoder(proto, reader, descriptor) {
private var nextIndex: Int = 0

// Tags are omitted in the packed array format
override fun SerialDescriptor.getTag(index: Int): ProtoDesc = MISSING_TAG

override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
throw SerializationException("Packing only supports primitive number types. The input type however was a struct: $descriptor")
}

override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
// We need eof here as there is no tag to read in packed form.
if (reader.eof) return CompositeDecoder.DECODE_DONE
return nextIndex++
}

override fun decodeTaggedString(tag: ProtoDesc): String {
throw SerializationException("Packing only supports primitive number types. The actual reading is for string.")
}
}
@@ -0,0 +1,31 @@
package kotlinx.serialization.protobuf.internal

import kotlinx.serialization.*
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.encoding.*
import kotlinx.serialization.protobuf.*

@OptIn(ExperimentalSerializationApi::class)
internal class PackedArrayEncoder(
proto: ProtoBuf,
writer: ProtobufWriter,
curTag: ProtoDesc,
descriptor: SerialDescriptor,
stream: ByteArrayOutput = ByteArrayOutput()
) : NestedRepeatedEncoder(proto, writer, curTag, descriptor, stream) {

// Triggers not writing header
override fun SerialDescriptor.getTag(index: Int): ProtoDesc = MISSING_TAG

override fun beginCollection(descriptor: SerialDescriptor, collectionSize: Int): CompositeEncoder {
throw SerializationException("Packing only supports primitive number types")
}

override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder {
throw SerializationException("Packing only supports primitive number types")
}

override fun encodeTaggedString(tag: ProtoDesc, value: String) {
throw SerializationException("Packing only supports primitive number types")
}
}
Expand Up @@ -108,6 +108,11 @@ internal open class ProtobufDecoder(
reader.readTag()
// all elements always have id = 1
RepeatedDecoder(proto, reader, ProtoDesc(1, ProtoIntegerType.DEFAULT), descriptor)

} else if (reader.currentType == SIZE_DELIMITED && descriptor.getElementDescriptor(0).isPackable) {
val sliceReader = ProtobufReader(reader.objectInput())
PackedArrayDecoder(proto, sliceReader, descriptor)

} else {
RepeatedDecoder(proto, reader, tag, descriptor)
}
Expand Down Expand Up @@ -287,7 +292,8 @@ private class RepeatedDecoder(
private fun decodeListIndexNoTag(): Int {
val size = -tagOrSize
val idx = ++index
if (idx.toLong() == size) return CompositeDecoder.DECODE_DONE
// Check for eof is here for the case that it is an out-of-spec packed array where size is bytesize not list length.
if (idx.toLong() == size || reader.eof) return CompositeDecoder.DECODE_DONE
return idx
}

Expand Down
Expand Up @@ -30,13 +30,17 @@ internal open class ProtobufEncoder(
): CompositeEncoder = when (descriptor.kind) {
StructureKind.LIST -> {
val tag = currentTagOrDefault
if (tag == MISSING_TAG) {
writer.writeInt(collectionSize)
}
if (this.descriptor.kind == StructureKind.LIST && tag != MISSING_TAG && this.descriptor != descriptor) {
NestedRepeatedEncoder(proto, writer, tag, descriptor)
if (tag.isPacked && descriptor.getElementDescriptor(0).isPackable) {
PackedArrayEncoder(proto, writer, currentTagOrDefault, descriptor)
} else {
RepeatedEncoder(proto, writer, tag, descriptor)
if (tag == MISSING_TAG) {
writer.writeInt(collectionSize)
}
if (this.descriptor.kind == StructureKind.LIST && tag != MISSING_TAG && this.descriptor != descriptor) {
NestedRepeatedEncoder(proto, writer, tag, descriptor)
} else {
RepeatedEncoder(proto, writer, tag, descriptor)
}
}
}
StructureKind.MAP -> {
Expand All @@ -47,7 +51,13 @@ internal open class ProtobufEncoder(
}

override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder = when (descriptor.kind) {
StructureKind.LIST -> RepeatedEncoder(proto, writer, currentTagOrDefault, descriptor)
StructureKind.LIST -> {
if (descriptor.getElementDescriptor(0).isPackable && currentTagOrDefault.isPacked) {
PackedArrayEncoder(proto, writer, currentTagOrDefault, descriptor)
} else {
RepeatedEncoder(proto, writer, currentTagOrDefault, descriptor)
}
}
StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> {
val tag = currentTagOrDefault
if (tag == MISSING_TAG && descriptor == this.descriptor) this
Expand Down Expand Up @@ -183,7 +193,7 @@ private class RepeatedEncoder(
override fun SerialDescriptor.getTag(index: Int) = curTag
}

private class NestedRepeatedEncoder(
internal open class NestedRepeatedEncoder(
proto: ProtoBuf,
@JvmField val writer: ProtobufWriter,
@JvmField val curTag: ProtoDesc,
Expand Down
Expand Up @@ -15,14 +15,27 @@ internal class ProtobufReader(private val input: ByteArrayInput) {
@JvmField
public var currentType = -1
private var pushBack = false
private var pushBackHeader = 0

public val eof
get() = !pushBack && input.availableBytes == 0

public fun readTag(): Int {
if (pushBack) {
pushBack = false
return currentId
val previousHeader = (currentId shl 3) or currentType
return updateIdAndType(pushBackHeader).also {
pushBackHeader = previousHeader
}
}
// Header to use when pushed back is the old id/type
pushBackHeader = (currentId shl 3) or currentType

val header = input.readVarint64(true).toInt()
return updateIdAndType(header)
}

private fun updateIdAndType(header: Int): Int {
return if (header == -1) {
currentId = -1
currentType = -1
Expand All @@ -36,6 +49,10 @@ internal class ProtobufReader(private val input: ByteArrayInput) {

public fun pushBackTag() {
pushBack = true

val nextHeader = (currentId shl 3) or currentType
updateIdAndType(pushBackHeader)
pushBackHeader = nextHeader
}

fun skipElement() {
Expand Down
Expand Up @@ -8,7 +8,7 @@ import kotlinx.serialization.*

internal class ByteArrayInput(private var array: ByteArray, private val endIndex: Int = array.size) {
private var position: Int = 0
private val availableBytes: Int get() = endIndex - position
val availableBytes: Int get() = endIndex - position

fun slice(size: Int): ByteArrayInput {
ensureEnoughBytes(size)
Expand Down

0 comments on commit 6047db8

Please sign in to comment.