Skip to content

Commit

Permalink
Merge pull request #9641 from som-snytt/issue/12403
Browse files Browse the repository at this point in the history
Fix `ArrayOps` bugs (by avoiding using `ArraySeq#array`, which does not guarantee element type)
  • Loading branch information
lrytz committed Jun 2, 2021
2 parents f74ccec + 836c5a9 commit 074cae1
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/library/scala/collection/ArrayOps.scala
Expand Up @@ -1569,18 +1569,18 @@ final class ArrayOps[A](private val xs: Array[A]) extends AnyVal {
* ''n'' times in `that`, then the first ''n'' occurrences of `x` will not form
* part of the result, but any following occurrences will.
*/
def diff[B >: A](that: Seq[B]): Array[A] = mutable.ArraySeq.make(xs).diff(that).array.asInstanceOf[Array[A]]
def diff[B >: A](that: Seq[B]): Array[A] = mutable.ArraySeq.make(xs).diff(that).toArray[A]

/** Computes the multiset intersection between this array and another sequence.
*
* @param that the sequence of elements to intersect with.
* @return a new array which contains all elements of this array
* which also appear in `that`.
* If an element value `x` appears
* ''n'' times in `that`, then the first ''n'' occurrences of `x` will be retained
* in the result, but any following occurrences will be omitted.
*/
def intersect[B >: A](that: Seq[B]): Array[A] = mutable.ArraySeq.make(xs).intersect(that).array.asInstanceOf[Array[A]]
*
* @param that the sequence of elements to intersect with.
* @return a new array which contains all elements of this array
* which also appear in `that`.
* If an element value `x` appears
* ''n'' times in `that`, then the first ''n'' occurrences of `x` will be retained
* in the result, but any following occurrences will be omitted.
*/
def intersect[B >: A](that: Seq[B]): Array[A] = mutable.ArraySeq.make(xs).intersect(that).toArray[A]

/** Groups elements in fixed size blocks by passing a "sliding window"
* over them (as opposed to partitioning them, as is done in grouped.)
Expand All @@ -1592,7 +1592,7 @@ final class ArrayOps[A](private val xs: Array[A]) extends AnyVal {
* last element (which may be the only element) will be truncated
* if there are fewer than `size` elements remaining to be grouped.
*/
def sliding(size: Int, step: Int = 1): Iterator[Array[A]] = mutable.ArraySeq.make(xs).sliding(size, step).map(_.array.asInstanceOf[Array[A]])
def sliding(size: Int, step: Int = 1): Iterator[Array[A]] = mutable.ArraySeq.make(xs).sliding(size, step).map(_.toArray[A])

/** Iterates over combinations. A _combination_ of length `n` is a subsequence of
* the original array, with the elements taken in order. Thus, `Array("x", "y")` and `Array("y", "y")`
Expand All @@ -1609,7 +1609,7 @@ final class ArrayOps[A](private val xs: Array[A]) extends AnyVal {
* Array("a", "b", "b", "b", "c").combinations(2) == Iterator(Array(a, b), Array(a, c), Array(b, b), Array(b, c))
* }}}
*/
def combinations(n: Int): Iterator[Array[A]] = mutable.ArraySeq.make(xs).combinations(n).map(_.array.asInstanceOf[Array[A]])
def combinations(n: Int): Iterator[Array[A]] = mutable.ArraySeq.make(xs).combinations(n).map(_.toArray[A])

/** Iterates over distinct permutations.
*
Expand All @@ -1618,7 +1618,7 @@ final class ArrayOps[A](private val xs: Array[A]) extends AnyVal {
* Array("a", "b", "b").permutations == Iterator(Array(a, b, b), Array(b, a, b), Array(b, b, a))
* }}}
*/
def permutations: Iterator[Array[A]] = mutable.ArraySeq.make(xs).permutations.map(_.array.asInstanceOf[Array[A]])
def permutations: Iterator[Array[A]] = mutable.ArraySeq.make(xs).permutations.map(_.toArray[A])

// we have another overload here, so we need to duplicate this method
/** Tests whether this array contains the given sequence at a given index.
Expand Down
9 changes: 9 additions & 0 deletions test/files/run/t12403.scala
@@ -0,0 +1,9 @@

object Test extends App {
val xs =
Array.empty[Double]
val ys =
Array(0.0)
assert(xs.intersect(ys).getClass.getComponentType == classOf[Double])
assert(Array.empty[Double].intersect(Array(0.0)).getClass.getComponentType == classOf[Double])
}
20 changes: 20 additions & 0 deletions test/junit/scala/collection/ArrayOpsTest.scala
Expand Up @@ -122,4 +122,24 @@ class ArrayOpsTest {
val a: Array[Byte] = new Array[Byte](1000).sortWith { _ < _ }
assertEquals(0, a(0))
}

@Test
def `empty intersection has correct component type for array`(): Unit = {
val something = Array(3.14)
val nothing = Array[Double]()
val empty = Array.empty[Double]

assertEquals(classOf[Double], nothing.intersect(something).getClass.getComponentType)
assertTrue(nothing.intersect(something).isEmpty)

assertEquals(classOf[Double], empty.intersect(something).getClass.getComponentType)
assertTrue(empty.intersect(something).isEmpty)
assertEquals(classOf[Double], empty.intersect(nothing).getClass.getComponentType)
assertTrue(empty.intersect(nothing).isEmpty)

assertEquals(classOf[Double], something.intersect(nothing).getClass.getComponentType)
assertTrue(something.intersect(nothing).isEmpty)
assertEquals(classOf[Double], something.intersect(empty).getClass.getComponentType)
assertTrue(something.intersect(empty).isEmpty)
}
}

0 comments on commit 074cae1

Please sign in to comment.