Skip to content

Commit

Permalink
Support Match nodes with non-int-literals in the back-end.
Browse files Browse the repository at this point in the history
Since Scala 2.13.2, the pattern matcher will keep `Match` nodes
that match on `String`s and `null`s as is, to be desugared later
in `cleanup`. This was implemented upstream in
scala/scala#8451

We implement a more general translation that will accept any kind
of scalac `Literal`. If the scrutinee is an integer and all cases
are int literals, we emit a `js.Match` as before. Otherwise, we
emit an `if..else` chain with `===` comparisons.
  • Loading branch information
sjrd committed Nov 11, 2019
1 parent 5c23320 commit 24a22d1
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 28 deletions.
80 changes: 53 additions & 27 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala
Expand Up @@ -3230,15 +3230,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
implicit val pos = tree.pos
val Match(selector, cases) = tree

/* We adapt the selector to IntType so that we can use it in a js.Match,
* just like GenBCode does for the JVM. This seems to be redundant,
* though, as anything that comes out of the pattern matching has already
* been adapted to an Int (along with the cases). However, since GenBCode
* adapts, we do the same, to be on the safe side (for example, a
* compiler plugin could generate a Match with other types of
* primitives ...).
/* Although GenBCode adapts the scrutinee and the cases to `int`, only
* true `int`s can reach the back-end, as asserted by the String-switch
* transformation in `cleanup`. Therefore, we do not adapt, preserving
* the `string`s and `null`s that come out of the pattern matching in
* Scala 2.13.2+.
*/
val expr = adaptPrimitive(genExpr(selector), jstpe.IntType)
val genSelector = genExpr(selector)

val resultType = toIRType(tree.tpe)

Expand All @@ -3248,7 +3246,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
body.symbol
}.getOrElse(NoSymbol)

var clauses: List[(List[js.IntLiteral], js.Tree)] = Nil
var clauses: List[(List[js.Tree], js.Tree)] = Nil
var optElseClause: Option[js.Tree] = None
var optElseClauseLabel: Option[js.LabelIdent] = None

Expand Down Expand Up @@ -3299,16 +3297,9 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
genStatOrExpr(body, isStat)
}

/* value.intValue implicitly adapts the constant value to an Int. This
* is also what GenBCode for the JVM. See also the comment about
* adaptPrimitive at the beginning of this method.
*/
def genLiteral(lit: Literal): js.IntLiteral =
js.IntLiteral(lit.value.intValue)(lit.pos)

pat match {
case lit: Literal =>
clauses = (List(genLiteral(lit)), genBody(body)) :: clauses
clauses = (List(genExpr(lit)), genBody(body)) :: clauses
case Ident(nme.WILDCARD) =>
optElseClause = Some(body match {
case LabelDef(_, Nil, rhs) if hasSynthCaseSymbol(body) =>
Expand All @@ -3319,7 +3310,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
case Alternative(alts) =>
val genAlts = {
alts map {
case lit: Literal => genLiteral(lit)
case lit: Literal => genExpr(lit)
case _ =>
abort("Invalid case in alternative in switch-like pattern match: " +
tree + " at: " + tree.pos)
Expand All @@ -3341,31 +3332,66 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
* case is a typical product of `match`es that are full of
* `case n if ... =>`, which are used instead of `if` chains for
* convenience and/or readability.
*
* When no optimization applies, and any of the case values is not a
* literal int, we emit a series of `if..else` instead of a `js.Match`.
* This became necessary in 2.13.2 with strings and nulls.
*/
def buildMatch(selector: js.Tree,
cases: List[(List[js.IntLiteral], js.Tree)],
def buildMatch(cases: List[(List[js.Tree], js.Tree)],
default: js.Tree, tpe: jstpe.Type): js.Tree = {

def isInt(tree: js.Tree): Boolean = tree.tpe == jstpe.IntType

cases match {
case Nil =>
/* Completely remove the Match. Preserve the side-effects of
* `selector`.
* `genSelector`.
*/
js.Block(exprToStat(selector), default)
js.Block(exprToStat(genSelector), default)

case (uniqueAlt :: Nil, caseRhs) :: Nil =>
/* Simplify the `match` as an `if`, so that the optimizer has less
* work to do, and we emit less code at the end of the day.
* Use `Int_==` instead of `===` if possible, since it is a common
* case.
*/
js.If(js.BinaryOp(js.BinaryOp.Int_==, selector, uniqueAlt),
caseRhs, default)(tpe)
val op =
if (isInt(genSelector) && isInt(uniqueAlt)) js.BinaryOp.Int_==
else js.BinaryOp.===
js.If(js.BinaryOp(op, genSelector, uniqueAlt), caseRhs, default)(tpe)

case _ =>
js.Match(selector, cases, default)(tpe)
if (isInt(genSelector) &&
cases.forall(_._1.forall(_.isInstanceOf[js.IntLiteral]))) {
// We have int literals only: use a js.Match
val intCases = cases.asInstanceOf[List[(List[js.IntLiteral], js.Tree)]]
js.Match(genSelector, intCases, default)(tpe)
} else {
// We have other stuff: generate an if..else chain
val (tempSelectorDef, tempSelectorRef) = genSelector match {
case varRef: js.VarRef =>
(js.Skip(), varRef)
case _ =>
val varDef = js.VarDef(freshLocalIdent(), genSelector.tpe,
mutable = false, genSelector)
(varDef, varDef.ref)
}
val ifElseChain = cases.foldRight(default) { (caze, elsep) =>
val conds = caze._1.map { caseValue =>
js.BinaryOp(js.BinaryOp.===, tempSelectorRef, caseValue)
}
val cond = conds.reduceRight[js.Tree] { (left, right) =>
js.If(left, js.BooleanLiteral(true), right)(jstpe.BooleanType)
}
js.If(cond, caze._2, elsep)(tpe)
}
js.Block(tempSelectorDef, ifElseChain)
}
}
}

optElseClauseLabel.fold[js.Tree] {
buildMatch(expr, clauses.reverse, elseClause, resultType)
buildMatch(clauses.reverse, elseClause, resultType)
} { elseClauseLabel =>
val matchResultLabel = freshLabelIdent("matchResult")
val patchedClauses = for ((alts, body) <- clauses) yield {
Expand All @@ -3377,7 +3403,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
}
js.Labeled(matchResultLabel, resultType, js.Block(List(
js.Labeled(elseClauseLabel, jstpe.NoType, {
buildMatch(expr, patchedClauses.reverse, js.Skip(), jstpe.NoType)
buildMatch(patchedClauses.reverse, js.Skip(), jstpe.NoType)
}),
elseClause
)))
Expand Down
Expand Up @@ -477,7 +477,38 @@ class RegressionTest {
assertThrows(classOf[MatchError], bug.bug(2, false))
}

@Test def return_x_match_issue_2928(): Unit = {
@Test def return_x_match_issue_2928_ints(): Unit = {
// scalastyle:off return

def testNonUnit(x: Int): Boolean = {
return x match {
case 1 => true
case _ => false
}
}

var r: Option[Boolean] = None

def testUnit(x: Int): Unit = {
return x match {
case 1 => r = Some(true)
case _ => r = Some(false)
}
}

assertEquals(true, testNonUnit(1))
assertEquals(false, testNonUnit(2))

testUnit(1)
assertEquals(Some(true), r)
r = None
testUnit(2)
assertEquals(Some(false), r)

// scalastyle:on return
}

@Test def return_x_match_issue_2928_strings(): Unit = {
// scalastyle:off return

def testNonUnit(x: String): Boolean = {
Expand Down Expand Up @@ -508,6 +539,41 @@ class RegressionTest {
// scalastyle:on return
}

@Test def return_x_match_issue_2928_lists(): Unit = {
// scalastyle:off return

def testNonUnit(x: List[String]): Boolean = {
return x match {
case "True" :: Nil => true
case _ => false
}
}

var r: Option[Boolean] = None

def testUnit(x: List[String]): Unit = {
return x match {
case "True" :: Nil => r = Some(true)
case _ => r = Some(false)
}
}

assertEquals(true, testNonUnit("True" :: Nil))
assertEquals(false, testNonUnit("not true" :: Nil))
assertEquals(false, testNonUnit("True" :: "second" :: Nil))

testUnit("True" :: Nil)
assertEquals(Some(true), r)
r = None
testUnit("not true" :: Nil)
assertEquals(Some(false), r)
r = None
testUnit("True" :: "second" :: Nil)
assertEquals(Some(false), r)

// scalastyle:on return
}

@Test def null_asInstanceOf_Unit_should_succeed_issue_1691(): Unit = {
/* Avoid scalac's special treatment of `<literal null>.asInstanceOf[X]`.
* It does have the benefit to test our constant-folder of that pattern,
Expand Down

0 comments on commit 24a22d1

Please sign in to comment.