Skip to content

Commit

Permalink
Prevent Function0 execution during LazyList deserialization
Browse files Browse the repository at this point in the history
This PR ensures that LazyList deserialization will not execute an
arbitrary Function0 when being passed a forged serialization stream.

See the PR description for a detailed explanation.
  • Loading branch information
lrytz committed Aug 29, 2022
1 parent c46fd04 commit a40c17a
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 5 deletions.
10 changes: 6 additions & 4 deletions src/library/scala/collection/immutable/LazyList.scala
Expand Up @@ -19,7 +19,7 @@ import java.lang.{StringBuilder => JStringBuilder}

import scala.annotation.tailrec
import scala.collection.generic.SerializeEnd
import scala.collection.mutable.{ArrayBuffer, Builder, ReusableBuilder, StringBuilder}
import scala.collection.mutable.{Builder, ReusableBuilder, StringBuilder}
import scala.language.implicitConversions
import scala.runtime.Statics

Expand Down Expand Up @@ -1353,7 +1353,7 @@ object LazyList extends SeqFactory[LazyList] {
private[this] def writeObject(out: ObjectOutputStream): Unit = {
out.defaultWriteObject()
var these = coll
while(these.knownNonEmpty) {
while (these.knownNonEmpty) {
out.writeObject(these.head)
these = these.tail
}
Expand All @@ -1363,14 +1363,16 @@ object LazyList extends SeqFactory[LazyList] {

private[this] def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
val init = new ArrayBuffer[A]
val init = new mutable.ListBuffer[A]
var initRead = false
while (!initRead) in.readObject match {
case SerializeEnd => initRead = true
case a => init += a.asInstanceOf[A]
}
val tail = in.readObject().asInstanceOf[LazyList[A]]
coll = init ++: tail
// scala/scala#10118: caution that no code path can evaluate `tail.state`
// before the resulting LazyList is returned
coll = newLL(stateFromIteratorConcatSuffix(init.toList.iterator)(tail.state))
}

private[this] def readResolve(): Any = coll
Expand Down
71 changes: 70 additions & 1 deletion test/junit/scala/collection/immutable/LazyListTest.scala
Expand Up @@ -8,12 +8,77 @@ import org.junit.Assert._

import scala.annotation.unused
import scala.collection.mutable.{Builder, ListBuffer}
import scala.tools.testkit.AssertUtil
import scala.tools.testkit.{AssertUtil, ReflectUtil}
import scala.util.Try

@RunWith(classOf[JUnit4])
class LazyListTest {

@Test
def serialization(): Unit = {
import java.io._

def serialize(obj: AnyRef): Array[Byte] = {
val buffer = new ByteArrayOutputStream
val out = new ObjectOutputStream(buffer)
out.writeObject(obj)
buffer.toByteArray
}

def deserialize(a: Array[Byte]): AnyRef = {
val in = new ObjectInputStream(new ByteArrayInputStream(a))
in.readObject
}

def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T]

val l = LazyList.from(10)

val ld1 = serializeDeserialize(l)
assertEquals(l.take(10).toList, ld1.take(10).toList)

l.tail.head
val ld2 = serializeDeserialize(l)
assertEquals(l.take(10).toList, ld2.take(10).toList)

LazyListTest.serializationForceCount = 0
val u = LazyList.from(10).map(x => { LazyListTest.serializationForceCount += 1; x})

@unused def printDiff(): Unit = {
val a = serialize(u)
ReflectUtil.getFieldAccessible[LazyList[_]]("scala$collection$immutable$LazyList$$stateEvaluated").setBoolean(u, true)
val b = serialize(u)
val i = a.zip(b).indexWhere(p => p._1 != p._2)
println("difference: ")
println(s"val from = ${a.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
println(s"val to = ${b.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
}

// to update this test, comment-out `LazyList.writeReplace` and run `printDiff`
// printDiff()

val from = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 0, 115, 114, 0, 33, 106, 97, 118, 97, 46)
val to = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 1, 115, 114, 0, 33, 106, 97, 118, 97, 46)

assertEquals(LazyListTest.serializationForceCount, 0)

u.head
assertEquals(LazyListTest.serializationForceCount, 1)

val data = serialize(u)
var i = data.indexOfSlice(from)
to.foreach(x => {data(i) = x; i += 1})

val ud1 = deserialize(data).asInstanceOf[LazyList[Int]]

// this check failed before scala/scala#10118, deserialization triggered evaluation
assertEquals(LazyListTest.serializationForceCount, 1)

ud1.tail.head
assertEquals(LazyListTest.serializationForceCount, 2)

}

@Test
def t6727_and_t6440_and_8627(): Unit = {
assertTrue(LazyList.continually(()).filter(_ => true).take(2) == Seq((), ()))
Expand Down Expand Up @@ -378,3 +443,7 @@ class LazyListTest {
assertEquals(1, count)
}
}

object LazyListTest {
var serializationForceCount = 0
}

0 comments on commit a40c17a

Please sign in to comment.