Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ListBuffer's iterator fail when the buffer is mutated #9174

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions build.sbt
Expand Up @@ -451,6 +451,11 @@ val mimaFilterSettings = Seq {
// this is safe because the default cannot be used; instead the single-param overload in
// `IterableOnceOps` is chosen (https://github.com/scala/scala/pull/9232#discussion_r501554458)
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.collection.immutable.ArraySeq.copyToArray$default$2"),

// Fix for scala/bug#12009
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.MutationTracker"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.MutationTracker$"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.MutationTracker$CheckedIterator"),
),
}

Expand Down
10 changes: 3 additions & 7 deletions src/compiler/scala/tools/nsc/transform/async/AnfTransform.scala
Expand Up @@ -136,19 +136,15 @@ private[async] trait AnfTransform extends TransformUtils {
}

case ValDef(mods, name, tpt, rhs) => atOwner(tree.symbol) {
// Capture current cursor of a non-empty `stats` buffer so we can efficiently restrict the
// Capture size of `stats` buffer so we can efficiently restrict the
// `changeOwner` to the newly added items...
var statsIterator = if (currentStats.isEmpty) null else currentStats.iterator
val oldItemsCount = currentStats.length

val expr = atOwner(currentOwner.owner)(transform(rhs))

// But, ListBuffer.empty.iterator doesn't reflect later mutation. Luckily we can just start
// from the beginning of the buffer
if (statsIterator == null) statsIterator = currentStats.iterator

NthPortal marked this conversation as resolved.
Show resolved Hide resolved
// Definitions within stats lifted out of the `ValDef` rhs should no longer be owned by the
// the ValDef.
statsIterator.foreach(_.changeOwner((currentOwner, currentOwner.owner)))
currentStats.iterator.drop(oldItemsCount).foreach(_.changeOwner((currentOwner, currentOwner.owner)))
NthPortal marked this conversation as resolved.
Show resolved Hide resolved
val expr1 = if (isUnitType(expr.tpe)) {
currentStats += expr
literalBoxedUnit
Expand Down
12 changes: 8 additions & 4 deletions src/library/scala/collection/mutable/Growable.scala
Expand Up @@ -14,7 +14,7 @@ package scala
package collection
package mutable

import scala.collection.IterableOnce
import scala.annotation.nowarn

/** This trait forms part of collections that can be augmented
* using a `+=` operator and that can be cleared of all elements using
Expand Down Expand Up @@ -56,10 +56,14 @@ trait Growable[-A] extends Clearable {
* @param xs the IterableOnce producing the elements to $add.
* @return the $coll itself.
*/
@nowarn("msg=will most likely never compare equal")
SethTisue marked this conversation as resolved.
Show resolved Hide resolved
def addAll(xs: IterableOnce[A]): this.type = {
val it = xs.iterator
while (it.hasNext) {
addOne(it.next())
if (xs.asInstanceOf[AnyRef] eq this) addAll(Buffer.from(xs)) // avoid mutating under our own iterator
else {
val it = xs.iterator
while (it.hasNext) {
addOne(it.next())
}
}
this
}
Expand Down
99 changes: 74 additions & 25 deletions src/library/scala/collection/mutable/ListBuffer.scala
Expand Up @@ -35,13 +35,15 @@ import scala.runtime.Statics.releaseFence
* @define mayNotTerminateInf
* @define willNotTerminateInf
*/
@SerialVersionUID(-8428291952499836345L)
class ListBuffer[A]
extends AbstractBuffer[A]
with SeqOps[A, ListBuffer, ListBuffer[A]]
with StrictOptimizedSeqOps[A, ListBuffer, ListBuffer[A]]
with ReusableBuilder[A, immutable.List[A]]
with IterableFactoryDefaults[A, ListBuffer]
with DefaultSerializable {
@transient private[this] var mutationCount: Int = 0

private var first: List[A] = Nil
private var last0: ::[A] = null
Expand All @@ -50,7 +52,7 @@ class ListBuffer[A]

private type Predecessor[A0] = ::[A0] /*| Null*/

def iterator = first.iterator
def iterator: Iterator[A] = new MutationTracker.CheckedIterator(first.iterator, mutationCount)

override def iterableFactory: SeqFactory[ListBuffer] = ListBuffer

Expand All @@ -69,7 +71,12 @@ class ListBuffer[A]
aliased = false
}

private def ensureUnaliased() = if (aliased) copyElems()
// we only call this before mutating things, so it's
// a good place to track mutations for the iterator
private def ensureUnaliased(): Unit = {
mutationCount += 1
if (aliased) copyElems()
}

// Avoids copying where possible.
override def toList: List[A] = {
Expand Down Expand Up @@ -97,6 +104,7 @@ class ListBuffer[A]
}

def clear(): Unit = {
mutationCount += 1
first = Nil
len = 0
last0 = null
Expand All @@ -114,18 +122,28 @@ class ListBuffer[A]

// Overridden for performance
override final def addAll(xs: IterableOnce[A]): this.type = {
val it = xs.iterator
if (it.hasNext) {
ensureUnaliased()
val last1 = new ::[A](it.next(), Nil)
if (len == 0) first = last1 else last0.next = last1
last0 = last1
len += 1
while (it.hasNext) {
if (xs.asInstanceOf[AnyRef] eq this) { // avoid mutating under our own iterator
if (len > 0) {
ensureUnaliased()
val copy = ListBuffer.from(this)
last0.next = copy.first
last0 = copy.last0
len *= 2
}
} else {
val it = xs.iterator
if (it.hasNext) {
ensureUnaliased()
val last1 = new ::[A](it.next(), Nil)
last0.next = last1
if (len == 0) first = last1 else last0.next = last1
last0 = last1
len += 1
while (it.hasNext) {
val last1 = new ::[A](it.next(), Nil)
last0.next = last1
last0 = last1
len += 1
}
}
}
this
Expand Down Expand Up @@ -230,13 +248,29 @@ class ListBuffer[A]
}

def insertAll(idx: Int, elems: IterableOnce[A]): Unit = {
ensureUnaliased()
val it = elems.iterator
if (it.hasNext) {
ensureUnaliased()
if (idx < 0 || idx > len) throw new IndexOutOfBoundsException(s"$idx is out of bounds (min 0, max ${len-1})")
if (idx == len) ++=(elems)
else insertAfter(locate(idx), it)
if (idx < 0 || idx > len) throw new IndexOutOfBoundsException(s"$idx is out of bounds (min 0, max ${len-1})")
NthPortal marked this conversation as resolved.
Show resolved Hide resolved
elems match {
case elems: AnyRef if elems eq this => // avoid mutating under our own iterator
if (len > 0) {
val copy = ListBuffer.from(this)
if (idx == 0 || idx == len) { // prepend/append
last0.next = copy.first
last0 = copy.last0
} else {
val prev = locate(idx) // cannot be `null` because other condition catches that
val follow = prev.next
prev.next = copy.first
copy.last0.next = follow
}
len *= 2
}
case elems =>
val it = elems.iterator
if (it.hasNext) {
ensureUnaliased()
if (idx == len) ++=(elems)
else insertAfter(locate(idx), it)
}
}
}

Expand Down Expand Up @@ -275,15 +309,17 @@ class ListBuffer[A]
}

def mapInPlace(f: A => A): this.type = {
ensureUnaliased()
mutationCount += 1
val buf = new ListBuffer[A]
for (elem <- this) buf += f(elem)
first = buf.first
last0 = buf.last0
aliased = false // we just assigned from a new instance
this
}

def flatMapInPlace(f: A => IterableOnce[A]): this.type = {
mutationCount += 1
var src = first
var dst: List[A] = null
last0 = null
Expand All @@ -299,6 +335,7 @@ class ListBuffer[A]
src = src.tail
}
first = if(dst eq null) Nil else dst
aliased = false // we just rebuilt a fresh, unaliased instance
this
}

Expand All @@ -322,12 +359,24 @@ class ListBuffer[A]
}

def patchInPlace(from: Int, patch: collection.IterableOnce[A], replaced: Int): this.type = {
val i = math.min(math.max(from, 0), length)
val n = math.min(math.max(replaced, 0), length)
ensureUnaliased()
val p = locate(i)
removeAfter(p, math.min(n, len - i))
insertAfter(p, patch.iterator)
val _len = len
val _from = math.max(from, 0) // normalized
val _replaced = math.max(replaced, 0) // normalized
val it = patch.iterator

val nonEmptyPatch = it.hasNext
val nonEmptyReplace = (_from < _len) && (_replaced > 0)

// don't want to add a mutation or check aliasing (potentially expensive)
// if there's no patching to do
if (nonEmptyPatch || nonEmptyReplace) {
ensureUnaliased()
val i = math.min(_from, _len)
val n = math.min(_replaced, _len)
val p = locate(i)
removeAfter(p, math.min(n, _len - i))
insertAfter(p, it)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this fails with a ConcurrentModificationException if it comes from this. might fix in a separate PR. it didn't work before this PR anyway - it failed in an unspecified way (adding/removing the wrong elements) rather than with an exception

Copy link
Contributor Author

@NthPortal NthPortal Oct 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

specifically, the following code

val b = ListBuffer(1, 2, 3)
b.patchInPlace(1, b, 1)

would previously result in ListBuffer(1, 1, 3, 3) when it should result in ListBuffer(1, 1, 2, 3, 3), but now throws an exception

}
this
}

Expand Down
78 changes: 78 additions & 0 deletions src/library/scala/collection/mutable/MutationTracker.scala
@@ -0,0 +1,78 @@
/*
* Scala (https://www.scala-lang.org)
*
* Copyright EPFL and Lightbend, Inc.
*
* Licensed under Apache License 2.0
* (http://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/

package scala
package collection
package mutable

import java.util.ConcurrentModificationException

/**
* Utilities to check that mutations to a client that tracks
* its mutations have not occurred since a given point.
* [[Iterator `Iterator`]]s that perform this check automatically
* during iteration can be created by wrapping an `Iterator`
* in a [[MutationTracker.CheckedIterator `CheckedIterator`]],
* or by manually using the [[MutationTracker.checkMutations() `checkMutations`]]
* and [[MutationTracker.checkMutationsForIteration() `checkMutationsForIteration`]]
* methods.
*/
private object MutationTracker {

/**
* Checks whether or not the actual mutation count differs from
* the expected one, throwing an exception, if it does.
*
* @param expectedCount the expected mutation count
* @param actualCount the actual mutation count
* @param message the exception message in case of mutations
* @throws ConcurrentModificationException if the expected and actual
* mutation counts differ
*/
@throws[ConcurrentModificationException]
def checkMutations(expectedCount: Int, actualCount: Int, message: String): Unit = {
if (actualCount != expectedCount) throw new ConcurrentModificationException(message)
}

/**
* Checks whether or not the actual mutation count differs from
* the expected one, throwing an exception, if it does. This method
* produces an exception message saying that it was called because a
* backing collection was mutated during iteration.
*
* @param expectedCount the expected mutation count
* @param actualCount the actual mutation count
* @throws ConcurrentModificationException if the expected and actual
* mutation counts differ
*/
@throws[ConcurrentModificationException]
@inline def checkMutationsForIteration(expectedCount: Int, actualCount: Int): Unit =
checkMutations(expectedCount, actualCount, "mutation occurred during iteration")

/**
* An iterator wrapper that checks if the underlying collection has
* been mutated.
*
* @param underlying the underlying iterator
* @param mutationCount a by-name provider of the current mutation count
* @tparam A the type of the iterator's elements
*/
final class CheckedIterator[A](underlying: Iterator[A], mutationCount: => Int) extends AbstractIterator[A] {
private[this] val expectedCount = mutationCount

def hasNext: Boolean = {
checkMutationsForIteration(expectedCount, mutationCount)
underlying.hasNext
}
def next(): A = underlying.next()
}
}
16 changes: 12 additions & 4 deletions src/library/scala/collection/mutable/Shrinkable.scala
Expand Up @@ -13,7 +13,7 @@
package scala
package collection.mutable

import scala.annotation.tailrec
import scala.annotation.{nowarn, tailrec}

/** This trait forms part of collections that can be reduced
* using a `-=` operator.
Expand Down Expand Up @@ -52,16 +52,24 @@ trait Shrinkable[-A] {
* @param xs the iterator producing the elements to remove.
* @return the $coll itself
*/
@nowarn("msg=will most likely never compare equal")
def subtractAll(xs: collection.IterableOnce[A]): this.type = {
@tailrec def loop(xs: collection.LinearSeq[A]): Unit = {
if (xs.nonEmpty) {
subtractOne(xs.head)
loop(xs.tail)
}
}
xs match {
case xs: collection.LinearSeq[A] => loop(xs)
case xs => xs.iterator.foreach(subtractOne)
if (xs.asInstanceOf[AnyRef] eq this) { // avoid mutating under our own iterator
xs match {
case xs: Clearable => xs.clear()
case xs => subtractAll(Buffer.from(xs))
}
} else {
xs match {
case xs: collection.LinearSeq[A] => loop(xs)
case xs => xs.iterator.foreach(subtractOne)
}
}
this
}
Expand Down