Skip to content

Commit

Permalink
Merge pull request #9174 from NthPortal/topic/mutation-tracking-itera…
Browse files Browse the repository at this point in the history
…tors/PR

[bug#12009] Make ListBuffer's iterator fail-fast
  • Loading branch information
NthPortal committed Oct 22, 2020
2 parents 9faf48b + 8f6e522 commit 9a71698
Show file tree
Hide file tree
Showing 15 changed files with 534 additions and 67 deletions.
5 changes: 5 additions & 0 deletions build.sbt
Expand Up @@ -451,6 +451,11 @@ val mimaFilterSettings = Seq {
// this is safe because the default cannot be used; instead the single-param overload in
// `IterableOnceOps` is chosen (https://github.com/scala/scala/pull/9232#discussion_r501554458)
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.collection.immutable.ArraySeq.copyToArray$default$2"),

// Fix for scala/bug#12009
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.MutationTracker"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.MutationTracker$"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.MutationTracker$CheckedIterator"),
),
}

Expand Down
10 changes: 3 additions & 7 deletions src/compiler/scala/tools/nsc/transform/async/AnfTransform.scala
Expand Up @@ -136,19 +136,15 @@ private[async] trait AnfTransform extends TransformUtils {
}

case ValDef(mods, name, tpt, rhs) => atOwner(tree.symbol) {
// Capture current cursor of a non-empty `stats` buffer so we can efficiently restrict the
// Capture size of `stats` buffer so we can efficiently restrict the
// `changeOwner` to the newly added items...
var statsIterator = if (currentStats.isEmpty) null else currentStats.iterator
val oldItemsCount = currentStats.length

val expr = atOwner(currentOwner.owner)(transform(rhs))

// But, ListBuffer.empty.iterator doesn't reflect later mutation. Luckily we can just start
// from the beginning of the buffer
if (statsIterator == null) statsIterator = currentStats.iterator

// Definitions within stats lifted out of the `ValDef` rhs should no longer be owned by the
// the ValDef.
statsIterator.foreach(_.changeOwner((currentOwner, currentOwner.owner)))
currentStats.iterator.drop(oldItemsCount).foreach(_.changeOwner((currentOwner, currentOwner.owner)))
val expr1 = if (isUnitType(expr.tpe)) {
currentStats += expr
literalBoxedUnit
Expand Down
12 changes: 8 additions & 4 deletions src/library/scala/collection/mutable/Growable.scala
Expand Up @@ -14,7 +14,7 @@ package scala
package collection
package mutable

import scala.collection.IterableOnce
import scala.annotation.nowarn

/** This trait forms part of collections that can be augmented
* using a `+=` operator and that can be cleared of all elements using
Expand Down Expand Up @@ -56,10 +56,14 @@ trait Growable[-A] extends Clearable {
* @param xs the IterableOnce producing the elements to $add.
* @return the $coll itself.
*/
@nowarn("msg=will most likely never compare equal")
def addAll(xs: IterableOnce[A]): this.type = {
val it = xs.iterator
while (it.hasNext) {
addOne(it.next())
if (xs.asInstanceOf[AnyRef] eq this) addAll(Buffer.from(xs)) // avoid mutating under our own iterator
else {
val it = xs.iterator
while (it.hasNext) {
addOne(it.next())
}
}
this
}
Expand Down
99 changes: 74 additions & 25 deletions src/library/scala/collection/mutable/ListBuffer.scala
Expand Up @@ -35,13 +35,15 @@ import scala.runtime.Statics.releaseFence
* @define mayNotTerminateInf
* @define willNotTerminateInf
*/
@SerialVersionUID(-8428291952499836345L)
class ListBuffer[A]
extends AbstractBuffer[A]
with SeqOps[A, ListBuffer, ListBuffer[A]]
with StrictOptimizedSeqOps[A, ListBuffer, ListBuffer[A]]
with ReusableBuilder[A, immutable.List[A]]
with IterableFactoryDefaults[A, ListBuffer]
with DefaultSerializable {
@transient private[this] var mutationCount: Int = 0

private var first: List[A] = Nil
private var last0: ::[A] = null
Expand All @@ -50,7 +52,7 @@ class ListBuffer[A]

private type Predecessor[A0] = ::[A0] /*| Null*/

def iterator = first.iterator
def iterator: Iterator[A] = new MutationTracker.CheckedIterator(first.iterator, mutationCount)

override def iterableFactory: SeqFactory[ListBuffer] = ListBuffer

Expand All @@ -69,7 +71,12 @@ class ListBuffer[A]
aliased = false
}

private def ensureUnaliased() = if (aliased) copyElems()
// we only call this before mutating things, so it's
// a good place to track mutations for the iterator
private def ensureUnaliased(): Unit = {
mutationCount += 1
if (aliased) copyElems()
}

// Avoids copying where possible.
override def toList: List[A] = {
Expand Down Expand Up @@ -97,6 +104,7 @@ class ListBuffer[A]
}

def clear(): Unit = {
mutationCount += 1
first = Nil
len = 0
last0 = null
Expand All @@ -114,18 +122,28 @@ class ListBuffer[A]

// Overridden for performance
override final def addAll(xs: IterableOnce[A]): this.type = {
val it = xs.iterator
if (it.hasNext) {
ensureUnaliased()
val last1 = new ::[A](it.next(), Nil)
if (len == 0) first = last1 else last0.next = last1
last0 = last1
len += 1
while (it.hasNext) {
if (xs.asInstanceOf[AnyRef] eq this) { // avoid mutating under our own iterator
if (len > 0) {
ensureUnaliased()
val copy = ListBuffer.from(this)
last0.next = copy.first
last0 = copy.last0
len *= 2
}
} else {
val it = xs.iterator
if (it.hasNext) {
ensureUnaliased()
val last1 = new ::[A](it.next(), Nil)
last0.next = last1
if (len == 0) first = last1 else last0.next = last1
last0 = last1
len += 1
while (it.hasNext) {
val last1 = new ::[A](it.next(), Nil)
last0.next = last1
last0 = last1
len += 1
}
}
}
this
Expand Down Expand Up @@ -230,13 +248,29 @@ class ListBuffer[A]
}

def insertAll(idx: Int, elems: IterableOnce[A]): Unit = {
ensureUnaliased()
val it = elems.iterator
if (it.hasNext) {
ensureUnaliased()
if (idx < 0 || idx > len) throw new IndexOutOfBoundsException(s"$idx is out of bounds (min 0, max ${len-1})")
if (idx == len) ++=(elems)
else insertAfter(locate(idx), it)
if (idx < 0 || idx > len) throw new IndexOutOfBoundsException(s"$idx is out of bounds (min 0, max ${len-1})")
elems match {
case elems: AnyRef if elems eq this => // avoid mutating under our own iterator
if (len > 0) {
val copy = ListBuffer.from(this)
if (idx == 0 || idx == len) { // prepend/append
last0.next = copy.first
last0 = copy.last0
} else {
val prev = locate(idx) // cannot be `null` because other condition catches that
val follow = prev.next
prev.next = copy.first
copy.last0.next = follow
}
len *= 2
}
case elems =>
val it = elems.iterator
if (it.hasNext) {
ensureUnaliased()
if (idx == len) ++=(elems)
else insertAfter(locate(idx), it)
}
}
}

Expand Down Expand Up @@ -275,15 +309,17 @@ class ListBuffer[A]
}

def mapInPlace(f: A => A): this.type = {
ensureUnaliased()
mutationCount += 1
val buf = new ListBuffer[A]
for (elem <- this) buf += f(elem)
first = buf.first
last0 = buf.last0
aliased = false // we just assigned from a new instance
this
}

def flatMapInPlace(f: A => IterableOnce[A]): this.type = {
mutationCount += 1
var src = first
var dst: List[A] = null
last0 = null
Expand All @@ -299,6 +335,7 @@ class ListBuffer[A]
src = src.tail
}
first = if(dst eq null) Nil else dst
aliased = false // we just rebuilt a fresh, unaliased instance
this
}

Expand All @@ -322,12 +359,24 @@ class ListBuffer[A]
}

def patchInPlace(from: Int, patch: collection.IterableOnce[A], replaced: Int): this.type = {
val i = math.min(math.max(from, 0), length)
val n = math.min(math.max(replaced, 0), length)
ensureUnaliased()
val p = locate(i)
removeAfter(p, math.min(n, len - i))
insertAfter(p, patch.iterator)
val _len = len
val _from = math.max(from, 0) // normalized
val _replaced = math.max(replaced, 0) // normalized
val it = patch.iterator

val nonEmptyPatch = it.hasNext
val nonEmptyReplace = (_from < _len) && (_replaced > 0)

// don't want to add a mutation or check aliasing (potentially expensive)
// if there's no patching to do
if (nonEmptyPatch || nonEmptyReplace) {
ensureUnaliased()
val i = math.min(_from, _len)
val n = math.min(_replaced, _len)
val p = locate(i)
removeAfter(p, math.min(n, _len - i))
insertAfter(p, it)
}
this
}

Expand Down
78 changes: 78 additions & 0 deletions src/library/scala/collection/mutable/MutationTracker.scala
@@ -0,0 +1,78 @@
/*
* Scala (https://www.scala-lang.org)
*
* Copyright EPFL and Lightbend, Inc.
*
* Licensed under Apache License 2.0
* (http://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/

package scala
package collection
package mutable

import java.util.ConcurrentModificationException

/**
* Utilities to check that mutations to a client that tracks
* its mutations have not occurred since a given point.
* [[Iterator `Iterator`]]s that perform this check automatically
* during iteration can be created by wrapping an `Iterator`
* in a [[MutationTracker.CheckedIterator `CheckedIterator`]],
* or by manually using the [[MutationTracker.checkMutations() `checkMutations`]]
* and [[MutationTracker.checkMutationsForIteration() `checkMutationsForIteration`]]
* methods.
*/
private object MutationTracker {

/**
* Checks whether or not the actual mutation count differs from
* the expected one, throwing an exception, if it does.
*
* @param expectedCount the expected mutation count
* @param actualCount the actual mutation count
* @param message the exception message in case of mutations
* @throws ConcurrentModificationException if the expected and actual
* mutation counts differ
*/
@throws[ConcurrentModificationException]
def checkMutations(expectedCount: Int, actualCount: Int, message: String): Unit = {
if (actualCount != expectedCount) throw new ConcurrentModificationException(message)
}

/**
* Checks whether or not the actual mutation count differs from
* the expected one, throwing an exception, if it does. This method
* produces an exception message saying that it was called because a
* backing collection was mutated during iteration.
*
* @param expectedCount the expected mutation count
* @param actualCount the actual mutation count
* @throws ConcurrentModificationException if the expected and actual
* mutation counts differ
*/
@throws[ConcurrentModificationException]
@inline def checkMutationsForIteration(expectedCount: Int, actualCount: Int): Unit =
checkMutations(expectedCount, actualCount, "mutation occurred during iteration")

/**
* An iterator wrapper that checks if the underlying collection has
* been mutated.
*
* @param underlying the underlying iterator
* @param mutationCount a by-name provider of the current mutation count
* @tparam A the type of the iterator's elements
*/
final class CheckedIterator[A](underlying: Iterator[A], mutationCount: => Int) extends AbstractIterator[A] {
private[this] val expectedCount = mutationCount

def hasNext: Boolean = {
checkMutationsForIteration(expectedCount, mutationCount)
underlying.hasNext
}
def next(): A = underlying.next()
}
}
16 changes: 12 additions & 4 deletions src/library/scala/collection/mutable/Shrinkable.scala
Expand Up @@ -13,7 +13,7 @@
package scala
package collection.mutable

import scala.annotation.tailrec
import scala.annotation.{nowarn, tailrec}

/** This trait forms part of collections that can be reduced
* using a `-=` operator.
Expand Down Expand Up @@ -52,16 +52,24 @@ trait Shrinkable[-A] {
* @param xs the iterator producing the elements to remove.
* @return the $coll itself
*/
@nowarn("msg=will most likely never compare equal")
def subtractAll(xs: collection.IterableOnce[A]): this.type = {
@tailrec def loop(xs: collection.LinearSeq[A]): Unit = {
if (xs.nonEmpty) {
subtractOne(xs.head)
loop(xs.tail)
}
}
xs match {
case xs: collection.LinearSeq[A] => loop(xs)
case xs => xs.iterator.foreach(subtractOne)
if (xs.asInstanceOf[AnyRef] eq this) { // avoid mutating under our own iterator
xs match {
case xs: Clearable => xs.clear()
case xs => subtractAll(Buffer.from(xs))
}
} else {
xs match {
case xs: collection.LinearSeq[A] => loop(xs)
case xs => xs.iterator.foreach(subtractOne)
}
}
this
}
Expand Down

0 comments on commit 9a71698

Please sign in to comment.