Skip to content

Commit

Permalink
Emit efficient code for switch over strings
Browse files Browse the repository at this point in the history
The pattern matcher will now emit `Match` with `String` scrutinee as
well as the existing `Int` scrutinee. The JVM backend handles this case
by emitting bytecode that switches on the String's `hashCode` (this
matches what Java does). The SJS already handles `String` matches.

The approach is similar to scala/scala#8451 (see scala/bug#11740 too),
except that instead of doing a transformation on the AST, we just emit the
right bytecode straight away. This is desirable since it means that
Scala.js (and any other backend) can choose their own optimised strategy
for compiling a match on strings.

Fixes scala#11923
  • Loading branch information
harpocrates committed Aug 18, 2021
1 parent 4af4ffe commit 3d8fa54
Show file tree
Hide file tree
Showing 8 changed files with 317 additions and 75 deletions.
202 changes: 156 additions & 46 deletions compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala
Expand Up @@ -3,6 +3,7 @@ package backend
package jvm

import scala.annotation.switch
import scala.collection.mutable.SortedMap

import scala.tools.asm
import scala.tools.asm.{Handle, Label, Opcodes}
Expand Down Expand Up @@ -840,61 +841,170 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
generatedType
}

/*
* A Match node contains one or more case clauses,
* each case clause lists one or more Int values to use as keys, and a code block.
* Except the "default" case clause which (if it exists) doesn't list any Int key.
*
* On a first pass over the case clauses, we flatten the keys and their targets (the latter represented with asm.Labels).
* That representation allows JCodeMethodV to emit a lookupswitch or a tableswitch.
*
* On a second pass, we emit the switch blocks, one for each different target.
/* A Match node contains one or more case clauses, each case clause lists one or more
* Int/String values to use as keys, and a code block. The exception is the "default" case
* clause which doesn't list any key (there is exactly one of these per match).
*/
private def genMatch(tree: Match): BType = tree match {
case Match(selector, cases) =>
lineNumber(tree)
genLoad(selector, INT)
val generatedType = tpeTK(tree)
val postMatch = new asm.Label

var flatKeys: List[Int] = Nil
var targets: List[asm.Label] = Nil
var default: asm.Label = null
var switchBlocks: List[(asm.Label, Tree)] = Nil

// collect switch blocks and their keys, but don't emit yet any switch-block.
for (caze @ CaseDef(pat, guard, body) <- cases) {
assert(guard == tpd.EmptyTree, guard)
val switchBlockPoint = new asm.Label
switchBlocks ::= (switchBlockPoint, body)
pat match {
case Literal(value) =>
flatKeys ::= value.intValue
targets ::= switchBlockPoint
case Ident(nme.WILDCARD) =>
assert(default == null, s"multiple default targets in a Match node, at ${tree.span}")
default = switchBlockPoint
case Alternative(alts) =>
alts foreach {
case Literal(value) =>
flatKeys ::= value.intValue
targets ::= switchBlockPoint
case _ =>
abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}")
}
case _ =>
abort(s"Invalid pattern in Match node: $tree at: ${tree.span}")
// Only two possible selector types exist in `Match` trees at this point: Int and String
if (tpeTK(selector) == INT) {

/* On a first pass over the case clauses, we flatten the keys and their
* targets (the latter represented with asm.Labels). That representation
* allows JCodeMethodV to emit a lookupswitch or a tableswitch.
*
* On a second pass, we emit the switch blocks, one for each different target.
*/

var flatKeys: List[Int] = Nil
var targets: List[asm.Label] = Nil
var default: asm.Label = null
var switchBlocks: List[(asm.Label, Tree)] = Nil

genLoad(selector, INT)

// collect switch blocks and their keys, but don't emit yet any switch-block.
for (caze @ CaseDef(pat, guard, body) <- cases) {
assert(guard == tpd.EmptyTree, guard)
val switchBlockPoint = new asm.Label
switchBlocks ::= (switchBlockPoint, body)
pat match {
case Literal(value) =>
flatKeys ::= value.intValue
targets ::= switchBlockPoint
case Ident(nme.WILDCARD) =>
assert(default == null, s"multiple default targets in a Match node, at ${tree.span}")
default = switchBlockPoint
case Alternative(alts) =>
alts foreach {
case Literal(value) =>
flatKeys ::= value.intValue
targets ::= switchBlockPoint
case _ =>
abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}")
}
case _ =>
abort(s"Invalid pattern in Match node: $tree at: ${tree.span}")
}
}
}

bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY)
bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY)

// emit switch-blocks.
val postMatch = new asm.Label
for (sb <- switchBlocks.reverse) {
val (caseLabel, caseBody) = sb
markProgramPoint(caseLabel)
genLoad(caseBody, generatedType)
bc goTo postMatch
// emit switch-blocks.
for (sb <- switchBlocks.reverse) {
val (caseLabel, caseBody) = sb
markProgramPoint(caseLabel)
genLoad(caseBody, generatedType)
bc goTo postMatch
}
} else {

/* Since the JVM doesn't have a way to switch on a string, we switch
* on the `hashCode` of the string then do an `equals` check (with a
* possible second set of jumps if blocks can be reach from multiple
* string alternatives).
*
* This mirrors the way that Java compiles `switch` on Strings.
*/

var default: asm.Label = null
var indirectBlocks: List[(asm.Label, Tree)] = Nil

import scala.collection.mutable

// Cases grouped by their hashCode
val casesByHash = SortedMap.empty[Int, List[(String, Either[asm.Label, Tree])]]
var caseFallback: Tree = null

for (caze @ CaseDef(pat, guard, body) <- cases) {
assert(guard == tpd.EmptyTree, guard)
pat match {
case Literal(value) =>
val strValue = value.stringValue
casesByHash.updateWith(strValue.##) { existingCasesOpt =>
val newCase = (strValue, Right(body))
Some(newCase :: existingCasesOpt.getOrElse(Nil))
}
case Ident(nme.WILDCARD) =>
assert(default == null, s"multiple default targets in a Match node, at ${tree.span}")
default = new asm.Label
indirectBlocks ::= (default, body)
case Alternative(alts) =>
// We need an extra basic block since multiple strings can lead to this code
val indirectCaseGroupLabel = new asm.Label
indirectBlocks ::= (indirectCaseGroupLabel, body)
alts foreach {
case Literal(value) =>
val strValue = value.stringValue
casesByHash.updateWith(strValue.##) { existingCasesOpt =>
val newCase = (strValue, Left(indirectCaseGroupLabel))
Some(newCase :: existingCasesOpt.getOrElse(Nil))
}
case _ =>
abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}")
}

case _ =>
abort(s"Invalid pattern in Match node: $tree at: ${tree.span}")
}
}

// Organize the hashCode options into switch cases
var flatKeys: List[Int] = Nil
var targets: List[asm.Label] = Nil
var hashBlocks: List[(asm.Label, List[(String, Either[asm.Label, Tree])])] = Nil
for ((hashValue, hashCases) <- casesByHash) {
val switchBlockPoint = new asm.Label
hashBlocks ::= (switchBlockPoint, hashCases)
flatKeys ::= hashValue
targets ::= switchBlockPoint
}

// Push the hashCode of the string (or `0` it is `null`) onto the stack and switch on it
genLoadIf(
If(
tree.selector.select(defn.Any_==).appliedTo(nullLiteral),
Literal(Constant(0)),
tree.selector.select(defn.Any_hashCode).appliedToNone
),
INT
)
bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY)

// emit blocks for each hash case
for ((hashLabel, caseAlternatives) <- hashBlocks.reverse) {
markProgramPoint(hashLabel)
for ((caseString, indirectLblOrBody) <- caseAlternatives) {
val comparison = if (caseString == null) defn.Any_== else defn.Any_equals
val condp = Literal(Constant(caseString)).select(defn.Any_==).appliedTo(tree.selector)
val keepGoing = new asm.Label
indirectLblOrBody match {
case Left(jump) =>
genCond(condp, jump, keepGoing, targetIfNoJump = keepGoing)

case Right(caseBody) =>
val thisCaseMatches = new asm.Label
genCond(condp, thisCaseMatches, keepGoing, targetIfNoJump = thisCaseMatches)
markProgramPoint(thisCaseMatches)
genLoad(caseBody, generatedType)
bc goTo postMatch
}
markProgramPoint(keepGoing)
}
bc goTo default
}

// emit blocks for common patterns
for ((caseLabel, caseBody) <- indirectBlocks.reverse) {
markProgramPoint(caseLabel)
genLoad(caseBody, generatedType)
bc goTo postMatch
}
}

markProgramPoint(postMatch)
Expand Down
11 changes: 0 additions & 11 deletions compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala
Expand Up @@ -2872,12 +2872,6 @@ class JSCodeGen()(using genCtx: Context) {
def abortMatch(msg: String): Nothing =
throw new FatalError(s"$msg in switch-like pattern match at ${tree.span}: $tree")

/* Although GenBCode adapts the scrutinee and the cases to `int`, only
* true `int`s can reach the back-end, as asserted by the String-switch
* transformation in `cleanup`. Therefore, we do not adapt, preserving
* the `string`s and `null`s that come out of the pattern matching in
* Scala 2.13.2+.
*/
val genSelector = genExpr(selector)

// Sanity check: we can handle Ints and Strings (including `null`s), but nothing else
Expand Down Expand Up @@ -2934,11 +2928,6 @@ class JSCodeGen()(using genCtx: Context) {
* When no optimization applies, and any of the case values is not a
* literal int, we emit a series of `if..else` instead of a `js.Match`.
* This became necessary in 2.13.2 with strings and nulls.
*
* Note that dotc has not adopted String-switch-Matches yet, so these code
* paths are dead code at the moment. However, they already existed in the
* scalac, so were ported, to be immediately available and working when
* dotc starts emitting switch-Matches on Strings.
*/
def isInt(tree: js.Tree): Boolean = tree.tpe == jstpe.IntType

Expand Down
41 changes: 23 additions & 18 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Expand Up @@ -20,7 +20,7 @@ import util.Property._

/** The pattern matching transform.
* After this phase, the only Match nodes remaining in the code are simple switches
* where every pattern is an integer constant
* where every pattern is an integer or string constant
*/
class PatternMatcher extends MiniPhase {
import ast.tpd._
Expand Down Expand Up @@ -768,13 +768,15 @@ object PatternMatcher {
(tpe isRef defn.IntClass) ||
(tpe isRef defn.ByteClass) ||
(tpe isRef defn.ShortClass) ||
(tpe isRef defn.CharClass)
(tpe isRef defn.CharClass) ||
(tpe isRef defn.StringClass)

val seen = mutable.Set[Int]()
val seen = mutable.Set[Any]()

def isNewIntConst(tree: Tree) = tree match {
case Literal(const) if const.isIntRange && !seen.contains(const.intValue) =>
seen += const.intValue
def isNewSwitchableConst(tree: Tree) = tree match {
case Literal(const)
if (const.isIntRange || const.tag == Constants.StringTag) && !seen.contains(const.value) =>
seen += const.value
true
case _ =>
false
Expand All @@ -789,7 +791,7 @@ object PatternMatcher {
val alts = List.newBuilder[Tree]
def rec(innerPlan: Plan): Boolean = innerPlan match {
case SeqPlan(TestPlan(EqualTest(tree), scrut, _, ReturnPlan(`innerLabel`)), tail)
if scrut === scrutinee && isNewIntConst(tree) =>
if scrut === scrutinee && isNewSwitchableConst(tree) =>
alts += tree
rec(tail)
case ReturnPlan(`outerLabel`) =>
Expand All @@ -809,7 +811,7 @@ object PatternMatcher {

def recur(plan: Plan): List[(List[Tree], Plan)] = plan match {
case SeqPlan(testPlan @ TestPlan(EqualTest(tree), scrut, _, ons), tail)
if scrut === scrutinee && !canFallThrough(ons) && isNewIntConst(tree) =>
if scrut === scrutinee && !canFallThrough(ons) && isNewSwitchableConst(tree) =>
(tree :: Nil, ons) :: recur(tail)
case SeqPlan(AlternativesPlan(alts, ons), tail) =>
(alts, ons) :: recur(tail)
Expand All @@ -832,29 +834,32 @@ object PatternMatcher {

/** Emit a switch-match */
private def emitSwitchMatch(scrutinee: Tree, cases: List[(List[Tree], Plan)]): Match = {
/* Make sure to adapt the scrutinee to Int, as well as all the alternatives
* of all cases, so that only Matches on pritimive Ints survive this phase.
/* Make sure to adapt the scrutinee to Int or String, as well as all the
* alternatives, so that only Matches on pritimive Ints or Strings survive
* this phase.
*/

val intScrutinee =
if (scrutinee.tpe.widen.isRef(defn.IntClass)) scrutinee
else scrutinee.select(nme.toInt)
val (primScrutinee, scrutineeTpe) =
if (scrutinee.tpe.widen.isRef(defn.IntClass)) (scrutinee, defn.IntType)
else if (scrutinee.tpe.widen.isRef(defn.StringClass)) (scrutinee, defn.StringType)
else (scrutinee.select(nme.toInt), defn.IntType)

def intLiteral(lit: Tree): Tree =
def primLiteral(lit: Tree): Tree =
val Literal(constant) = lit
if (constant.tag == Constants.IntTag) lit
else if (constant.tag == Constants.StringTag) lit
else cpy.Literal(lit)(Constant(constant.intValue))

val caseDefs = cases.map { (alts, ons) =>
val pat = alts match {
case alt :: Nil => intLiteral(alt)
case Nil => Underscore(defn.IntType) // default case
case _ => Alternative(alts.map(intLiteral))
case alt :: Nil => primLiteral(alt)
case Nil => Underscore(scrutineeTpe) // default case
case _ => Alternative(alts.map(primLiteral))
}
CaseDef(pat, EmptyTree, emit(ons))
}

Match(intScrutinee, caseDefs)
Match(primScrutinee, caseDefs)
}

/** If selfCheck is `true`, used to check whether a tree gets generated twice */
Expand Down
22 changes: 22 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala
Expand Up @@ -118,6 +118,28 @@ class TestBCode extends DottyBytecodeTest {
}
}

@Test def switchOnStrings = {
val source =
"""
|object Foo {
| import scala.annotation.switch
| def foo(s: String) = s match {
| case "AaAa" => println(3)
| case "BBBB" | "c" => println(2)
| case "D" | "E" => println(1)
| case _ => println(0)
| }
|}
""".stripMargin

checkBCode(source) { dir =>
val moduleIn = dir.lookupName("Foo$.class", directory = false)
val moduleNode = loadClassNode(moduleIn.input)
val methodNode = getMethod(moduleNode, "foo")
assert(verifySwitch(methodNode))
}
}

@Test def matchWithDefaultNoThrowMatchError = {
val source =
"""class Test {
Expand Down
2 changes: 2 additions & 0 deletions tests/run/string-switch-defaults-null.check
@@ -0,0 +1,2 @@
2
-1
16 changes: 16 additions & 0 deletions tests/run/string-switch-defaults-null.scala
@@ -0,0 +1,16 @@
import annotation.switch

object Test {
def test(s: String): Int = {
(s : @switch) match {
case "1" => 0
case null => -1
case _ => s.toInt
}
}

def main(args: Array[String]): Unit = {
println(test("2"))
println(test(null))
}
}

0 comments on commit 3d8fa54

Please sign in to comment.