Skip to content

Commit

Permalink
[bug#12009] Make ArrayBuffer's iterator fail-fast
Browse files Browse the repository at this point in the history
Make `ArrayBuffer`'s iterator fail-fast when the buffer is
mutated after the iterator's creation.
  • Loading branch information
NthPortal committed Jan 14, 2021
1 parent 593024b commit e579b7e
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 57 deletions.
20 changes: 20 additions & 0 deletions project/MimaFilters.scala
Expand Up @@ -24,6 +24,26 @@ object MimaFilters extends AutoPlugin {
// with JDK 11 and run MiMa it'll complain IteratorWrapper isn't forwards compatible with 2.13.0 - but we
// don't publish the artifact built with JDK 11 anyways
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.collection.convert.JavaCollectionWrappers#IteratorWrapper.asIterator"),

// Fixes for scala/bug#12009
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.collection.mutable.ArrayBufferView.this"),
ProblemFilters.exclude[FinalClassProblem]("scala.collection.IndexedSeqView$IndexedSeqViewIterator"),
ProblemFilters.exclude[FinalClassProblem]("scala.collection.IndexedSeqView$IndexedSeqViewReverseIterator"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$CheckedIterator"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$CheckedReverseIterator"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$Id"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$Appended"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$Prepended"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$Concat"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$Take"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$TakeRight"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$Drop"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$DropRight"),
ProblemFilters.exclude[MissingClassProblem](s"scala.collection.mutable.CheckedIndexedSeqView$$Map"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$Reverse"),
ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.CheckedIndexedSeqView$Slice"),
)

override val buildSettings = Seq(
Expand Down
10 changes: 1 addition & 9 deletions src/library/scala/collection/IndexedSeq.scala
Expand Up @@ -47,15 +47,7 @@ trait IndexedSeqOps[+A, +CC[_], +C] extends Any with SeqOps[A, CC, C] { self =>
s.asInstanceOf[S with EfficientSplit]
}

override def reverseIterator: Iterator[A] = new AbstractIterator[A] {
private[this] var i = self.length
def hasNext: Boolean = 0 < i
def next(): A =
if (0 < i) {
i -= 1
self(i)
} else Iterator.empty.next()
}
override def reverseIterator: Iterator[A] = view.reverseIterator

override def foldRight[B](z: B)(op: (A, B) => B): B = {
val it = reverseIterator
Expand Down
26 changes: 13 additions & 13 deletions src/library/scala/collection/IndexedSeqView.scala
Expand Up @@ -49,14 +49,15 @@ trait IndexedSeqView[+A] extends IndexedSeqOps[A, View, View[A]] with SeqView[A]
object IndexedSeqView {

@SerialVersionUID(3L)
private final class IndexedSeqViewIterator[A](self: IndexedSeqView[A]) extends AbstractIterator[A] with Serializable {
private[collection] class IndexedSeqViewIterator[A](self: IndexedSeqView[A]) extends AbstractIterator[A] with Serializable {
private[this] var current = 0
private[this] var remainder = self.size
private[this] var remainder = self.length
override def knownSize: Int = remainder
def hasNext = remainder > 0
@inline private[this] def _hasNext: Boolean = remainder > 0
def hasNext: Boolean = _hasNext
def next(): A =
if (hasNext) {
val r = self.apply(current)
if (_hasNext) {
val r = self(current)
current += 1
remainder -= 1
r
Expand All @@ -82,18 +83,18 @@ object IndexedSeqView {
}
}
@SerialVersionUID(3L)
private final class IndexedSeqViewReverseIterator[A](self: IndexedSeqView[A]) extends AbstractIterator[A] with Serializable {
private[this] var pos = self.size - 1
private[this] var remainder = self.size
def hasNext: Boolean = remainder > 0
private[collection] class IndexedSeqViewReverseIterator[A](self: IndexedSeqView[A]) extends AbstractIterator[A] with Serializable {
private[this] var pos = self.length - 1
private[this] var remainder = self.length
@inline private[this] def _hasNext: Boolean = remainder > 0
def hasNext: Boolean = _hasNext
def next(): A =
if (pos < 0) throw new NoSuchElementException
else {
if (_hasNext) {
val r = self(pos)
pos -= 1
remainder -= 1
r
}
} else Iterator.empty.next()

override def drop(n: Int): Iterator[A] = {
if (n > 0) {
Expand All @@ -103,7 +104,6 @@ object IndexedSeqView {
this
}


override def sliceIterator(from: Int, until: Int): Iterator[A] = {
val startCutoff = pos
val untilCutoff = startCutoff - remainder + 1
Expand Down
90 changes: 68 additions & 22 deletions src/library/scala/collection/mutable/ArrayBuffer.scala
Expand Up @@ -39,6 +39,7 @@ import scala.util.chaining._
* @define mayNotTerminateInf
* @define willNotTerminateInf
*/
@SerialVersionUID(-1582447879429021880L)
class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
extends AbstractBuffer[A]
with IndexedBuffer[A]
Expand All @@ -51,6 +52,8 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)

def this(initialSize: Int) = this(new Array[AnyRef](initialSize max 1), 0)

@transient private[this] var mutationCount: Int = 0

protected[collection] var array: Array[AnyRef] = initialElements
protected var size0 = initialSize

Expand All @@ -62,14 +65,17 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
override def knownSize: Int = super[IndexedSeqOps].knownSize

/** Ensure that the internal array has at least `n` cells. */
protected def ensureSize(n: Int): Unit =
protected def ensureSize(n: Int): Unit = {
mutationCount += 1
array = ArrayBuffer.ensureSize(array, size0, n)
}

def sizeHint(size: Int): Unit =
if(size > length && size >= 1) ensureSize(size)

/** Reduce length to `n`, nulling out all dropped elements */
private def reduceToSize(n: Int): Unit = {
mutationCount += 1
Arrays.fill(array, n, size0, null)
size0 = n
}
Expand All @@ -79,7 +85,10 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
* which may replace the array by a shorter one.
* This allows releasing some unused memory.
*/
def trimToSize(): Unit = resize(length)
def trimToSize(): Unit = {
mutationCount += 1
resize(length)
}

/** Trims the `array` buffer size down to either a power of 2
* or Int.MaxValue while keeping first `requiredLength` elements.
Expand All @@ -99,12 +108,13 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)

def update(@deprecatedName("n", "2.13.0") index: Int, elem: A): Unit = {
checkWithinBounds(index, index + 1)
mutationCount += 1
array(index) = elem.asInstanceOf[AnyRef]
}

def length = size0

override def view: ArrayBufferView[A] = new ArrayBufferView(array, size0)
override def view: ArrayBufferView[A] = new ArrayBufferView(array, size0, () => mutationCount)

override def iterableFactory: SeqFactory[ArrayBuffer] = ArrayBuffer

Expand Down Expand Up @@ -136,9 +146,12 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
override def addAll(elems: IterableOnce[A]): this.type = {
elems match {
case elems: ArrayBuffer[_] =>
ensureSize(length + elems.length)
Array.copy(elems.array, 0, array, length, elems.length)
size0 = length + elems.length
val elemsLength = elems.size0
if (elemsLength > 0) {
ensureSize(length + elemsLength)
Array.copy(elems.array, 0, array, length, elemsLength)
size0 = length + elemsLength
}
case _ => super.addAll(elems)
}
this
Expand All @@ -164,19 +177,21 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
insertAll(index, ArrayBuffer.from(this))
case elems: collection.Iterable[A] =>
val elemsLength = elems.size
ensureSize(length + elemsLength)
Array.copy(array, index, array, index + elemsLength, size0 - index)
size0 = size0 + elemsLength
elems match {
case elems: ArrayBuffer[_] =>
Array.copy(elems.array, 0, array, index, elemsLength)
case _ =>
var i = 0
val it = elems.iterator
while (i < elemsLength) {
this(index + i) = it.next()
i += 1
}
if (elemsLength > 0) {
ensureSize(length + elemsLength)
Array.copy(array, index, array, index + elemsLength, size0 - index)
size0 = size0 + elemsLength
elems match {
case elems: ArrayBuffer[_] =>
Array.copy(elems.array, 0, array, index, elemsLength)
case _ =>
var i = 0
val it = elems.iterator
while (i < elemsLength) {
this(index + i) = it.next()
i += 1
}
}
}
case _ =>
insertAll(index, ArrayBuffer.from(elems))
Expand Down Expand Up @@ -232,7 +247,10 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
* @return modified input $coll sorted according to the ordering `ord`.
*/
override def sortInPlace[B >: A]()(implicit ord: Ordering[B]): this.type = {
if (length > 1) scala.util.Sorting.stableSort(array.asInstanceOf[Array[B]], 0, length)
if (length > 1) {
mutationCount += 1
scala.util.Sorting.stableSort(array.asInstanceOf[Array[B]], 0, length)
}
this
}
}
Expand Down Expand Up @@ -297,8 +315,36 @@ object ArrayBuffer extends StrictOptimizedSeqFactory[ArrayBuffer] {
}
}

final class ArrayBufferView[A](val array: Array[AnyRef], val length: Int) extends AbstractIndexedSeqView[A] {
final class ArrayBufferView[A] private[mutable](val array: Array[AnyRef], val length: Int, mutationCount: () => Int)
extends AbstractIndexedSeqView[A] {
@deprecated("never intended to be public; call ArrayBuffer#view instead", since = "2.13.5")
def this(array: Array[AnyRef], length: Int) = {
// this won't actually track mutation, but it would be a pain to have the implementation
// check if we have a method to get the current mutation count or not on every method and
// change what it does based on that. hopefully no one ever calls this.
this(array, length, () => 0)
}

@throws[ArrayIndexOutOfBoundsException]
def apply(n: Int) = if (n < length) array(n).asInstanceOf[A] else throw new IndexOutOfBoundsException(s"$n is out of bounds (min 0, max ${length - 1})")
def apply(n: Int): A = if (n < length) array(n).asInstanceOf[A] else throw new IndexOutOfBoundsException(s"$n is out of bounds (min 0, max ${length - 1})")
override protected[this] def className = "ArrayBufferView"

// we could inherit all these from `CheckedIndexedSeqView`, except this class is public
override def iterator: Iterator[A] = new CheckedIndexedSeqView.CheckedIterator(this, mutationCount())
override def reverseIterator: Iterator[A] = new CheckedIndexedSeqView.CheckedReverseIterator(this, mutationCount())

override def appended[B >: A](elem: B): IndexedSeqView[B] = new CheckedIndexedSeqView.Appended(this, elem)(mutationCount)
override def prepended[B >: A](elem: B): IndexedSeqView[B] = new CheckedIndexedSeqView.Prepended(elem, this)(mutationCount)
override def take(n: Int): IndexedSeqView[A] = new CheckedIndexedSeqView.Take(this, n)(mutationCount)
override def takeRight(n: Int): IndexedSeqView[A] = new CheckedIndexedSeqView.TakeRight(this, n)(mutationCount)
override def drop(n: Int): IndexedSeqView[A] = new CheckedIndexedSeqView.Drop(this, n)(mutationCount)
override def dropRight(n: Int): IndexedSeqView[A] = new CheckedIndexedSeqView.DropRight(this, n)(mutationCount)
override def map[B](f: A => B): IndexedSeqView[B] = new CheckedIndexedSeqView.Map(this, f)(mutationCount)
override def reverse: IndexedSeqView[A] = new CheckedIndexedSeqView.Reverse(this)(mutationCount)
override def slice(from: Int, until: Int): IndexedSeqView[A] = new CheckedIndexedSeqView.Slice(this, from, until)(mutationCount)
override def tapEach[U](f: A => U): IndexedSeqView[A] = new CheckedIndexedSeqView.Map(this, { (a: A) => f(a); a})(mutationCount)

override def concat[B >: A](suffix: IndexedSeqView.SomeIndexedSeqOps[B]): IndexedSeqView[B] = new CheckedIndexedSeqView.Concat(this, suffix)(mutationCount)
override def appendedAll[B >: A](suffix: IndexedSeqView.SomeIndexedSeqOps[B]): IndexedSeqView[B] = new CheckedIndexedSeqView.Concat(this, suffix)(mutationCount)
override def prependedAll[B >: A](prefix: IndexedSeqView.SomeIndexedSeqOps[B]): IndexedSeqView[B] = new CheckedIndexedSeqView.Concat(prefix, this)(mutationCount)
}
117 changes: 117 additions & 0 deletions src/library/scala/collection/mutable/CheckedIndexedSeqView.scala
@@ -0,0 +1,117 @@
/*
* Scala (https://www.scala-lang.org)
*
* Copyright EPFL and Lightbend, Inc.
*
* Licensed under Apache License 2.0
* (http://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/

package scala
package collection
package mutable

private[mutable] trait CheckedIndexedSeqView[+A] extends IndexedSeqView[A] {
protected val mutationCount: () => Int

override def iterator: Iterator[A] = new CheckedIndexedSeqView.CheckedIterator(this, mutationCount())
override def reverseIterator: Iterator[A] = new CheckedIndexedSeqView.CheckedReverseIterator(this, mutationCount())

override def appended[B >: A](elem: B): IndexedSeqView[B] = new CheckedIndexedSeqView.Appended(this, elem)(mutationCount)
override def prepended[B >: A](elem: B): IndexedSeqView[B] = new CheckedIndexedSeqView.Prepended(elem, this)(mutationCount)
override def take(n: Int): IndexedSeqView[A] = new CheckedIndexedSeqView.Take(this, n)(mutationCount)
override def takeRight(n: Int): IndexedSeqView[A] = new CheckedIndexedSeqView.TakeRight(this, n)(mutationCount)
override def drop(n: Int): IndexedSeqView[A] = new CheckedIndexedSeqView.Drop(this, n)(mutationCount)
override def dropRight(n: Int): IndexedSeqView[A] = new CheckedIndexedSeqView.DropRight(this, n)(mutationCount)
override def map[B](f: A => B): IndexedSeqView[B] = new CheckedIndexedSeqView.Map(this, f)(mutationCount)
override def reverse: IndexedSeqView[A] = new CheckedIndexedSeqView.Reverse(this)(mutationCount)
override def slice(from: Int, until: Int): IndexedSeqView[A] = new CheckedIndexedSeqView.Slice(this, from, until)(mutationCount)
override def tapEach[U](f: A => U): IndexedSeqView[A] = new CheckedIndexedSeqView.Map(this, { (a: A) => f(a); a})(mutationCount)

override def concat[B >: A](suffix: IndexedSeqView.SomeIndexedSeqOps[B]): IndexedSeqView[B] = new CheckedIndexedSeqView.Concat(this, suffix)(mutationCount)
override def appendedAll[B >: A](suffix: IndexedSeqView.SomeIndexedSeqOps[B]): IndexedSeqView[B] = new CheckedIndexedSeqView.Concat(this, suffix)(mutationCount)
override def prependedAll[B >: A](prefix: IndexedSeqView.SomeIndexedSeqOps[B]): IndexedSeqView[B] = new CheckedIndexedSeqView.Concat(prefix, this)(mutationCount)
}

private[mutable] object CheckedIndexedSeqView {
import IndexedSeqView.SomeIndexedSeqOps

@SerialVersionUID(3L)
private[mutable] class CheckedIterator[A](self: IndexedSeqView[A], mutationCount: => Int)
extends IndexedSeqView.IndexedSeqViewIterator[A](self) {
private[this] val expectedCount = mutationCount
override def hasNext: Boolean = {
MutationTracker.checkMutationsForIteration(expectedCount, mutationCount)
super.hasNext
}
}

@SerialVersionUID(3L)
private[mutable] class CheckedReverseIterator[A](self: IndexedSeqView[A], mutationCount: => Int)
extends IndexedSeqView.IndexedSeqViewReverseIterator[A](self) {
private[this] val expectedCount = mutationCount
override def hasNext: Boolean = {
MutationTracker.checkMutationsForIteration(expectedCount, mutationCount)
super.hasNext
}
}

@SerialVersionUID(3L)
class Id[+A](underlying: SomeIndexedSeqOps[A])(protected val mutationCount: () => Int)
extends IndexedSeqView.Id(underlying) with CheckedIndexedSeqView[A]

@SerialVersionUID(3L)
class Appended[+A](underlying: SomeIndexedSeqOps[A], elem: A)(protected val mutationCount: () => Int)
extends IndexedSeqView.Appended(underlying, elem) with CheckedIndexedSeqView[A]

@SerialVersionUID(3L)
class Prepended[+A](elem: A, underlying: SomeIndexedSeqOps[A])(protected val mutationCount: () => Int)
extends IndexedSeqView.Prepended(elem, underlying) with CheckedIndexedSeqView[A]

@SerialVersionUID(3L)
class Concat[A](prefix: SomeIndexedSeqOps[A], suffix: SomeIndexedSeqOps[A])(protected val mutationCount: () => Int)
extends IndexedSeqView.Concat[A](prefix, suffix) with CheckedIndexedSeqView[A]

@SerialVersionUID(3L)
class Take[A](underlying: SomeIndexedSeqOps[A], n: Int)(protected val mutationCount: () => Int)
extends IndexedSeqView.Take(underlying, n) with CheckedIndexedSeqView[A]

@SerialVersionUID(3L)
class TakeRight[A](underlying: SomeIndexedSeqOps[A], n: Int)(protected val mutationCount: () => Int)
extends IndexedSeqView.TakeRight(underlying, n) with CheckedIndexedSeqView[A]

@SerialVersionUID(3L)
class Drop[A](underlying: SomeIndexedSeqOps[A], n: Int)(protected val mutationCount: () => Int)
extends IndexedSeqView.Drop[A](underlying, n) with CheckedIndexedSeqView[A]

@SerialVersionUID(3L)
class DropRight[A](underlying: SomeIndexedSeqOps[A], n: Int)(protected val mutationCount: () => Int)
extends IndexedSeqView.DropRight[A](underlying, n) with CheckedIndexedSeqView[A]

@SerialVersionUID(3L)
class Map[A, B](underlying: SomeIndexedSeqOps[A], f: A => B)(protected val mutationCount: () => Int)
extends IndexedSeqView.Map(underlying, f) with CheckedIndexedSeqView[B]

@SerialVersionUID(3L)
class Reverse[A](underlying: SomeIndexedSeqOps[A])(protected val mutationCount: () => Int)
extends IndexedSeqView.Reverse[A](underlying) with CheckedIndexedSeqView[A] {
override def reverse: IndexedSeqView[A] = underlying match {
case x: IndexedSeqView[A] => x
case _ => super.reverse
}
}

@SerialVersionUID(3L)
class Slice[A](underlying: SomeIndexedSeqOps[A], from: Int, until: Int)(protected val mutationCount: () => Int)
extends AbstractIndexedSeqView[A] with CheckedIndexedSeqView[A] {
protected val lo = from max 0
protected val hi = (until max 0) min underlying.length
protected val len = (hi - lo) max 0
@throws[IndexOutOfBoundsException]
def apply(i: Int): A = underlying(lo + i)
def length: Int = len
}
}

0 comments on commit e579b7e

Please sign in to comment.