Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

For security, prevent Function0 execution during LazyList deserialization #10118

Merged
merged 1 commit into from Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/library/scala/collection/immutable/LazyList.scala
Expand Up @@ -19,7 +19,7 @@ import java.lang.{StringBuilder => JStringBuilder}

import scala.annotation.tailrec
import scala.collection.generic.SerializeEnd
import scala.collection.mutable.{ArrayBuffer, Builder, ReusableBuilder, StringBuilder}
import scala.collection.mutable.{Builder, ReusableBuilder, StringBuilder}
import scala.language.implicitConversions
import scala.runtime.Statics

Expand Down Expand Up @@ -1353,7 +1353,7 @@ object LazyList extends SeqFactory[LazyList] {
private[this] def writeObject(out: ObjectOutputStream): Unit = {
out.defaultWriteObject()
var these = coll
while(these.knownNonEmpty) {
while (these.knownNonEmpty) {
out.writeObject(these.head)
these = these.tail
}
Expand All @@ -1363,14 +1363,17 @@ object LazyList extends SeqFactory[LazyList] {

private[this] def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
val init = new ArrayBuffer[A]
val init = new mutable.ListBuffer[A]
var initRead = false
while (!initRead) in.readObject match {
case SerializeEnd => initRead = true
case a => init += a.asInstanceOf[A]
}
val tail = in.readObject().asInstanceOf[LazyList[A]]
NthPortal marked this conversation as resolved.
Show resolved Hide resolved
coll = init ++: tail
// scala/scala#10118: caution that no code path can evaluate `tail.state`
// before the resulting LazyList is returned
val it = init.toList.iterator
coll = newLL(stateFromIteratorConcatSuffix(it)(tail.state))
}

private[this] def readResolve(): Any = coll
Expand Down
73 changes: 72 additions & 1 deletion test/junit/scala/collection/immutable/LazyListTest.scala
Expand Up @@ -8,12 +8,79 @@ import org.junit.Assert._

import scala.annotation.unused
import scala.collection.mutable.{Builder, ListBuffer}
import scala.tools.testkit.AssertUtil
import scala.tools.testkit.{AssertUtil, ReflectUtil}
import scala.util.Try

@RunWith(classOf[JUnit4])
class LazyListTest {

@Test
def serialization(): Unit = {
import java.io._

def serialize(obj: AnyRef): Array[Byte] = {
val buffer = new ByteArrayOutputStream
val out = new ObjectOutputStream(buffer)
out.writeObject(obj)
buffer.toByteArray
}

def deserialize(a: Array[Byte]): AnyRef = {
val in = new ObjectInputStream(new ByteArrayInputStream(a))
in.readObject
}

def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T]

val l = LazyList.from(10)

val ld1 = serializeDeserialize(l)
assertEquals(l.take(10).toList, ld1.take(10).toList)

l.tail.head
val ld2 = serializeDeserialize(l)
assertEquals(l.take(10).toList, ld2.take(10).toList)

LazyListTest.serializationForceCount = 0
val u = LazyList.from(10).map(x => { LazyListTest.serializationForceCount += 1; x })

@unused def printDiff(): Unit = {
val a = serialize(u)
ReflectUtil.getFieldAccessible[LazyList[_]]("scala$collection$immutable$LazyList$$stateEvaluated").setBoolean(u, true)
val b = serialize(u)
val i = a.zip(b).indexWhere(p => p._1 != p._2)
println("difference: ")
println(s"val from = ${a.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
println(s"val to = ${b.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
}

// to update this test, comment-out `LazyList.writeReplace` and run `printDiff`
// printDiff()

val from = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 0, 115, 114, 0, 33, 106, 97, 118, 97, 46)
val to = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 1, 115, 114, 0, 33, 106, 97, 118, 97, 46)

assertEquals(LazyListTest.serializationForceCount, 0)

u.head
assertEquals(LazyListTest.serializationForceCount, 1)

val data = serialize(u)
var i = data.indexOfSlice(from)
to.foreach(x => {data(i) = x; i += 1})

val ud1 = deserialize(data).asInstanceOf[LazyList[Int]]

// this check failed before scala/scala#10118, deserialization triggered evaluation
assertEquals(LazyListTest.serializationForceCount, 1)

ud1.tail.head
assertEquals(LazyListTest.serializationForceCount, 2)
lrytz marked this conversation as resolved.
Show resolved Hide resolved

u.tail.head
assertEquals(LazyListTest.serializationForceCount, 3)
}

@Test
def t6727_and_t6440_and_8627(): Unit = {
assertTrue(LazyList.continually(()).filter(_ => true).take(2) == Seq((), ()))
Expand Down Expand Up @@ -378,3 +445,7 @@ class LazyListTest {
assertEquals(1, count)
}
}

object LazyListTest {
var serializationForceCount = 0
}