Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ArrayOps bugs (by avoiding ArraySeq#array, which does not guarantee element type) #9641

Merged
merged 1 commit into from Jun 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/library/scala/collection/ArrayOps.scala
Expand Up @@ -1569,18 +1569,18 @@ final class ArrayOps[A](private val xs: Array[A]) extends AnyVal {
* ''n'' times in `that`, then the first ''n'' occurrences of `x` will not form
* part of the result, but any following occurrences will.
*/
def diff[B >: A](that: Seq[B]): Array[A] = mutable.ArraySeq.make(xs).diff(that).array.asInstanceOf[Array[A]]
def diff[B >: A](that: Seq[B]): Array[A] = mutable.ArraySeq.make(xs).diff(that).toArray[A]

/** Computes the multiset intersection between this array and another sequence.
*
* @param that the sequence of elements to intersect with.
* @return a new array which contains all elements of this array
* which also appear in `that`.
* If an element value `x` appears
* ''n'' times in `that`, then the first ''n'' occurrences of `x` will be retained
* in the result, but any following occurrences will be omitted.
*/
def intersect[B >: A](that: Seq[B]): Array[A] = mutable.ArraySeq.make(xs).intersect(that).array.asInstanceOf[Array[A]]
*
* @param that the sequence of elements to intersect with.
* @return a new array which contains all elements of this array
* which also appear in `that`.
* If an element value `x` appears
* ''n'' times in `that`, then the first ''n'' occurrences of `x` will be retained
* in the result, but any following occurrences will be omitted.
*/
def intersect[B >: A](that: Seq[B]): Array[A] = mutable.ArraySeq.make(xs).intersect(that).toArray[A]

/** Groups elements in fixed size blocks by passing a "sliding window"
* over them (as opposed to partitioning them, as is done in grouped.)
Expand All @@ -1592,7 +1592,7 @@ final class ArrayOps[A](private val xs: Array[A]) extends AnyVal {
* last element (which may be the only element) will be truncated
* if there are fewer than `size` elements remaining to be grouped.
*/
def sliding(size: Int, step: Int = 1): Iterator[Array[A]] = mutable.ArraySeq.make(xs).sliding(size, step).map(_.array.asInstanceOf[Array[A]])
def sliding(size: Int, step: Int = 1): Iterator[Array[A]] = mutable.ArraySeq.make(xs).sliding(size, step).map(_.toArray[A])

/** Iterates over combinations. A _combination_ of length `n` is a subsequence of
* the original array, with the elements taken in order. Thus, `Array("x", "y")` and `Array("y", "y")`
Expand All @@ -1609,7 +1609,7 @@ final class ArrayOps[A](private val xs: Array[A]) extends AnyVal {
* Array("a", "b", "b", "b", "c").combinations(2) == Iterator(Array(a, b), Array(a, c), Array(b, b), Array(b, c))
* }}}
*/
def combinations(n: Int): Iterator[Array[A]] = mutable.ArraySeq.make(xs).combinations(n).map(_.array.asInstanceOf[Array[A]])
def combinations(n: Int): Iterator[Array[A]] = mutable.ArraySeq.make(xs).combinations(n).map(_.toArray[A])

/** Iterates over distinct permutations.
*
Expand All @@ -1618,7 +1618,7 @@ final class ArrayOps[A](private val xs: Array[A]) extends AnyVal {
* Array("a", "b", "b").permutations == Iterator(Array(a, b, b), Array(b, a, b), Array(b, b, a))
* }}}
*/
def permutations: Iterator[Array[A]] = mutable.ArraySeq.make(xs).permutations.map(_.array.asInstanceOf[Array[A]])
def permutations: Iterator[Array[A]] = mutable.ArraySeq.make(xs).permutations.map(_.toArray[A])

// we have another overload here, so we need to duplicate this method
/** Tests whether this array contains the given sequence at a given index.
Expand Down
9 changes: 9 additions & 0 deletions test/files/run/t12403.scala
@@ -0,0 +1,9 @@

object Test extends App {
val xs =
Array.empty[Double]
val ys =
Array(0.0)
assert(xs.intersect(ys).getClass.getComponentType == classOf[Double])
assert(Array.empty[Double].intersect(Array(0.0)).getClass.getComponentType == classOf[Double])
}
20 changes: 20 additions & 0 deletions test/junit/scala/collection/ArrayOpsTest.scala
Expand Up @@ -122,4 +122,24 @@ class ArrayOpsTest {
val a: Array[Byte] = new Array[Byte](1000).sortWith { _ < _ }
assertEquals(0, a(0))
}

@Test
def `empty intersection has correct component type for array`(): Unit = {
val something = Array(3.14)
val nothing = Array[Double]()
val empty = Array.empty[Double]

assertEquals(classOf[Double], nothing.intersect(something).getClass.getComponentType)
assertTrue(nothing.intersect(something).isEmpty)

assertEquals(classOf[Double], empty.intersect(something).getClass.getComponentType)
assertTrue(empty.intersect(something).isEmpty)
assertEquals(classOf[Double], empty.intersect(nothing).getClass.getComponentType)
assertTrue(empty.intersect(nothing).isEmpty)

assertEquals(classOf[Double], something.intersect(nothing).getClass.getComponentType)
assertTrue(something.intersect(nothing).isEmpty)
assertEquals(classOf[Double], something.intersect(empty).getClass.getComponentType)
assertTrue(something.intersect(empty).isEmpty)
}
}