Skip to content

Commit

Permalink
[bug#12009] Make ListBuffer's iterator fail-fast
Browse files Browse the repository at this point in the history
Make `ListBuffer`'s iterator fail-fast when the buffer is
mutated after the iterator's creation.

Co-authored-by: Jason Zaugg <jzaugg@gmail.com>
  • Loading branch information
NthPortal and retronym committed Oct 21, 2020
1 parent 1390315 commit 8f6e522
Show file tree
Hide file tree
Showing 11 changed files with 435 additions and 41 deletions.
5 changes: 5 additions & 0 deletions build.sbt
Expand Up @@ -451,6 +451,11 @@ val mimaFilterSettings = Seq {
// this is safe because the default cannot be used; instead the single-param overload in
// `IterableOnceOps` is chosen (https://github.com/scala/scala/pull/9232#discussion_r501554458)
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.collection.immutable.ArraySeq.copyToArray$default$2"),

// Fix for scala/bug#12009
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.MutationTracker"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.MutationTracker$"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.MutationTracker$CheckedIterator"),
),
}

Expand Down
10 changes: 3 additions & 7 deletions src/compiler/scala/tools/nsc/transform/async/AnfTransform.scala
Expand Up @@ -136,19 +136,15 @@ private[async] trait AnfTransform extends TransformUtils {
}

case ValDef(mods, name, tpt, rhs) => atOwner(tree.symbol) {
// Capture current cursor of a non-empty `stats` buffer so we can efficiently restrict the
// Capture size of `stats` buffer so we can efficiently restrict the
// `changeOwner` to the newly added items...
var statsIterator = if (currentStats.isEmpty) null else currentStats.iterator
val oldItemsCount = currentStats.length

val expr = atOwner(currentOwner.owner)(transform(rhs))

// But, ListBuffer.empty.iterator doesn't reflect later mutation. Luckily we can just start
// from the beginning of the buffer
if (statsIterator == null) statsIterator = currentStats.iterator

// Definitions within stats lifted out of the `ValDef` rhs should no longer be owned by the
// the ValDef.
statsIterator.foreach(_.changeOwner((currentOwner, currentOwner.owner)))
currentStats.iterator.drop(oldItemsCount).foreach(_.changeOwner((currentOwner, currentOwner.owner)))
val expr1 = if (isUnitType(expr.tpe)) {
currentStats += expr
literalBoxedUnit
Expand Down
41 changes: 32 additions & 9 deletions src/library/scala/collection/mutable/ListBuffer.scala
Expand Up @@ -35,13 +35,15 @@ import scala.runtime.Statics.releaseFence
* @define mayNotTerminateInf
* @define willNotTerminateInf
*/
@SerialVersionUID(-8428291952499836345L)
class ListBuffer[A]
extends AbstractBuffer[A]
with SeqOps[A, ListBuffer, ListBuffer[A]]
with StrictOptimizedSeqOps[A, ListBuffer, ListBuffer[A]]
with ReusableBuilder[A, immutable.List[A]]
with IterableFactoryDefaults[A, ListBuffer]
with DefaultSerializable {
@transient private[this] var mutationCount: Int = 0

private var first: List[A] = Nil
private var last0: ::[A] = null
Expand All @@ -50,7 +52,7 @@ class ListBuffer[A]

private type Predecessor[A0] = ::[A0] /*| Null*/

def iterator = first.iterator
def iterator: Iterator[A] = new MutationTracker.CheckedIterator(first.iterator, mutationCount)

override def iterableFactory: SeqFactory[ListBuffer] = ListBuffer

Expand All @@ -69,7 +71,12 @@ class ListBuffer[A]
aliased = false
}

private def ensureUnaliased() = if (aliased) copyElems()
// we only call this before mutating things, so it's
// a good place to track mutations for the iterator
private def ensureUnaliased(): Unit = {
mutationCount += 1
if (aliased) copyElems()
}

// Avoids copying where possible.
override def toList: List[A] = {
Expand Down Expand Up @@ -97,6 +104,7 @@ class ListBuffer[A]
}

def clear(): Unit = {
mutationCount += 1
first = Nil
len = 0
last0 = null
Expand Down Expand Up @@ -301,15 +309,17 @@ class ListBuffer[A]
}

def mapInPlace(f: A => A): this.type = {
ensureUnaliased()
mutationCount += 1
val buf = new ListBuffer[A]
for (elem <- this) buf += f(elem)
first = buf.first
last0 = buf.last0
aliased = false // we just assigned from a new instance
this
}

def flatMapInPlace(f: A => IterableOnce[A]): this.type = {
mutationCount += 1
var src = first
var dst: List[A] = null
last0 = null
Expand All @@ -325,6 +335,7 @@ class ListBuffer[A]
src = src.tail
}
first = if(dst eq null) Nil else dst
aliased = false // we just rebuilt a fresh, unaliased instance
this
}

Expand All @@ -348,12 +359,24 @@ class ListBuffer[A]
}

def patchInPlace(from: Int, patch: collection.IterableOnce[A], replaced: Int): this.type = {
val i = math.min(math.max(from, 0), length)
val n = math.min(math.max(replaced, 0), length)
ensureUnaliased()
val p = locate(i)
removeAfter(p, math.min(n, len - i))
insertAfter(p, patch.iterator)
val _len = len
val _from = math.max(from, 0) // normalized
val _replaced = math.max(replaced, 0) // normalized
val it = patch.iterator

val nonEmptyPatch = it.hasNext
val nonEmptyReplace = (_from < _len) && (_replaced > 0)

// don't want to add a mutation or check aliasing (potentially expensive)
// if there's no patching to do
if (nonEmptyPatch || nonEmptyReplace) {
ensureUnaliased()
val i = math.min(_from, _len)
val n = math.min(_replaced, _len)
val p = locate(i)
removeAfter(p, math.min(n, _len - i))
insertAfter(p, it)
}
this
}

Expand Down
78 changes: 78 additions & 0 deletions src/library/scala/collection/mutable/MutationTracker.scala
@@ -0,0 +1,78 @@
/*
* Scala (https://www.scala-lang.org)
*
* Copyright EPFL and Lightbend, Inc.
*
* Licensed under Apache License 2.0
* (http://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/

package scala
package collection
package mutable

import java.util.ConcurrentModificationException

/**
* Utilities to check that mutations to a client that tracks
* its mutations have not occurred since a given point.
* [[Iterator `Iterator`]]s that perform this check automatically
* during iteration can be created by wrapping an `Iterator`
* in a [[MutationTracker.CheckedIterator `CheckedIterator`]],
* or by manually using the [[MutationTracker.checkMutations() `checkMutations`]]
* and [[MutationTracker.checkMutationsForIteration() `checkMutationsForIteration`]]
* methods.
*/
private object MutationTracker {

/**
* Checks whether or not the actual mutation count differs from
* the expected one, throwing an exception, if it does.
*
* @param expectedCount the expected mutation count
* @param actualCount the actual mutation count
* @param message the exception message in case of mutations
* @throws ConcurrentModificationException if the expected and actual
* mutation counts differ
*/
@throws[ConcurrentModificationException]
def checkMutations(expectedCount: Int, actualCount: Int, message: String): Unit = {
if (actualCount != expectedCount) throw new ConcurrentModificationException(message)
}

/**
* Checks whether or not the actual mutation count differs from
* the expected one, throwing an exception, if it does. This method
* produces an exception message saying that it was called because a
* backing collection was mutated during iteration.
*
* @param expectedCount the expected mutation count
* @param actualCount the actual mutation count
* @throws ConcurrentModificationException if the expected and actual
* mutation counts differ
*/
@throws[ConcurrentModificationException]
@inline def checkMutationsForIteration(expectedCount: Int, actualCount: Int): Unit =
checkMutations(expectedCount, actualCount, "mutation occurred during iteration")

/**
* An iterator wrapper that checks if the underlying collection has
* been mutated.
*
* @param underlying the underlying iterator
* @param mutationCount a by-name provider of the current mutation count
* @tparam A the type of the iterator's elements
*/
final class CheckedIterator[A](underlying: Iterator[A], mutationCount: => Int) extends AbstractIterator[A] {
private[this] val expectedCount = mutationCount

def hasNext: Boolean = {
checkMutationsForIteration(expectedCount, mutationCount)
underlying.hasNext
}
def next(): A = underlying.next()
}
}
@@ -0,0 +1,49 @@
/*
* Scala (https://www.scala-lang.org)
*
* Copyright EPFL and Lightbend, Inc.
*
* Licensed under Apache License 2.0
* (http://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/

package scala.collection
package mutable

import java.util.concurrent.TimeUnit

import org.openjdk.jmh.annotations._
import org.openjdk.jmh.infra._

@BenchmarkMode(Array(Mode.AverageTime))
@Fork(2)
@Threads(1)
@Warmup(iterations = 20)
@Measurement(iterations = 20)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Benchmark)
class ConstructionBenchmark {
@Param(Array("0", "1", "10", "100"))
var size: Int = _

var values: Range = _

@Setup(Level.Trial) def init(): Unit = {
values = 1 to size
}

@Benchmark def listBuffer_new: Any = {
new ListBuffer ++= values
}

@Benchmark def listBuffer_from: Any = {
ListBuffer from values
}

@Benchmark def listBuffer_to: Any = {
values to ListBuffer
}
}
Expand Up @@ -98,4 +98,18 @@ class ListBufferBenchmark {
b.flatMapInPlace { _ => seq }
bh.consume(b)
}

@Benchmark def iteratorA(bh: Blackhole): Unit = {
val b = ref.clone()
var n = 0
for (x <- b.iterator) n += x
bh.consume(n)
bh.consume(b)
}

@Benchmark def iteratorB(bh: Blackhole): Unit = {
val b = ref.clone()
bh.consume(b.iterator.toVector)
bh.consume(b)
}
}
1 change: 0 additions & 1 deletion test/files/run/t8153.check

This file was deleted.

14 changes: 0 additions & 14 deletions test/files/run/t8153.scala

This file was deleted.

32 changes: 32 additions & 0 deletions test/junit/scala/collection/mutable/MutationTrackerTest.scala
@@ -0,0 +1,32 @@
/*
* Scala (https://www.scala-lang.org)
*
* Copyright EPFL and Lightbend, Inc.
*
* Licensed under Apache License 2.0
* (http://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/

package scala.collection.mutable

import java.util.ConcurrentModificationException

import org.junit.Test

import scala.tools.testkit.AssertUtil.assertThrows

class MutationTrackerTest {
@Test
def checkedIterator(): Unit = {
var mutationCount = 0
def it = new MutationTracker.CheckedIterator(List(1, 2, 3).iterator, mutationCount)
val it1 = it
it1.toList // does not throw
val it2 = it
mutationCount += 1
assertThrows[ConcurrentModificationException](it2.toList, _ contains "iteration")
}
}

0 comments on commit 8f6e522

Please sign in to comment.