Skip to content

Commit

Permalink
Matching strings makes switches (#8451)
Browse files Browse the repository at this point in the history
Matching strings makes switches
  • Loading branch information
lrytz committed Nov 19, 2019
2 parents acd6c1c + 9f25ca0 commit 48f73bc
Show file tree
Hide file tree
Showing 10 changed files with 352 additions and 14 deletions.
106 changes: 103 additions & 3 deletions src/compiler/scala/tools/nsc/transform/CleanUp.scala
Expand Up @@ -61,9 +61,6 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
}
private def mkTerm(prefix: String): TermName = unit.freshTermName(prefix)

//private val classConstantMeth = new HashMap[String, Symbol]
//private val symbolStaticFields = new HashMap[String, (Symbol, Tree, Tree)]

private var localTyper: analyzer.Typer = null

private def typedWithPos(pos: Position)(tree: Tree) =
Expand Down Expand Up @@ -383,6 +380,106 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
}
}

object StringsPattern {
def unapply(arg: Tree): Option[List[String]] = arg match {
case Literal(Constant(value: String)) => Some(value :: Nil)
case Literal(Constant(null)) => Some(null :: Nil)
case Alternative(alts) => traverseOpt(alts)(unapply).map(_.flatten)
case _ => None
}
}

// transform scrutinee of all matches to ints
def transformSwitch(sw: Match): Tree = { import CODE._
sw.selector.tpe match {
case IntTpe => sw // can switch directly on ints
case StringTpe =>
// these assumptions about the shape of the tree are justified by the codegen in MatchOptimization
val Match(Typed(selTree: Ident, _), cases) = sw
val sel = selTree.symbol
val restpe = sw.tpe
val swPos = sw.pos.focus

/* From this:
* string match { case "AaAa" => 1 case "BBBB" | "c" => 2 case _ => 3}
* Generate this:
* string.## match {
* case 2031744 =>
* if ("AaAa" equals string) goto match1
* else if ("BBBB" equals string) goto match2
* else goto matchFailure
* case 99 =>
* if ("c" equals string) goto match2
* else goto matchFailure
* case _ => goto matchFailure
* }
* match1: goto matchSuccess (1)
* match2: goto matchSuccess (2)
* matchFailure: goto matchSuccess (3) // would be throw new MatchError(string) if no default was given
* matchSuccess(res: Int): res
* This proliferation of labels is needed to handle alternative patterns, since multiple branches in the
* resulting switch may need to correspond to a single case body.
*/

val stats = mutable.ListBuffer.empty[Tree]
var failureBody = Throw(New(definitions.MatchErrorClass.tpe_*, REF(sel))) : Tree

// genbcode isn't thrilled about seeing labels with Unit arguments, so `success`'s type is one of
// `${sw.tpe} => ${sw.tpe}` or `() => Unit` depending.
val success = {
val lab = currentOwner.newLabel(unit.freshTermName("matchEnd"), swPos)
if (restpe =:= UnitTpe) {
lab.setInfo(MethodType(Nil, restpe))
} else {
lab.setInfo(MethodType(lab.newValueParameter(nme.x_1).setInfo(restpe) :: Nil, restpe))
}
}
def succeed(res: Tree): Tree =
if (restpe =:= UnitTpe) BLOCK(res, REF(success) APPLY Nil) else REF(success) APPLY res

val failure = currentOwner.newLabel(unit.freshTermName("matchEnd"), swPos).setInfo(MethodType(Nil, restpe))
def fail(): Tree = atPos(swPos) { Apply(REF(failure), Nil) }

val newSel = atPos(sel.pos) { IF (sel OBJ_EQ NULL) THEN LIT(0) ELSE (Apply(REF(sel) DOT Object_hashCode, Nil)) }
val casesByHash =
cases.flatMap {
case cd@CaseDef(StringsPattern(strs), _, body) =>
val jump = currentOwner.newLabel(unit.freshTermName("case"), swPos).setInfo(MethodType(Nil, restpe))
stats += LabelDef(jump, Nil, succeed(body))
strs.map((_, jump, cd.pat.pos))
case cd@CaseDef(Ident(nme.WILDCARD), _, body) =>
failureBody = succeed(body)
None
case cd => globalError(s"unhandled in switch: $cd"); None
}.groupBy(_._1.##)
val newCases = casesByHash.toList.sortBy(_._1).map {
case (hash, cases) =>
val newBody = cases.foldLeft(fail()) {
case (next, (pat, jump, pos)) =>
val comparison = if (pat == null) Object_eq else Object_equals
atPos(pos) {
IF(LIT(pat) DOT comparison APPLY REF(sel)) THEN (REF(jump) APPLY Nil) ELSE next
}
}
CaseDef(LIT(hash), EmptyTree, newBody)
}

stats += LabelDef(failure, Nil, failureBody)

stats += (if (restpe =:= UnitTpe) {
LabelDef(success, Nil, gen.mkLiteralUnit)
} else {
LabelDef(success, success.info.params.head :: Nil, REF(success.info.params.head))
})

stats prepend Match(newSel, newCases :+ CaseDef(Ident(nme.WILDCARD), EmptyTree, fail()))

val res = Block(stats.result : _*)
localTyper.typedPos(sw.pos)(res)
case _ => globalError(s"unhandled switch scrutinee type ${sw.selector.tpe}: $sw"); sw
}
}

override def transform(tree: Tree): Tree = tree match {
case _: ClassDef if genBCode.codeGen.CodeGenImpl.isJavaEntryPoint(tree.symbol, currentUnit, settings.mainClass.valueSetByUser.map(_.toString)) =>
// collecting symbols for entry points here (as opposed to GenBCode where they are used)
Expand Down Expand Up @@ -498,6 +595,9 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
super.transform(localTyper.typedPos(tree.pos)(consed))
}

case switch: Match =>
super.transform(transformSwitch(switch))

case _ =>
super.transform(tree)
}
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/scala/tools/nsc/transform/Erasure.scala
Expand Up @@ -1285,7 +1285,7 @@ abstract class Erasure extends InfoTransform
treeCopy.Template(tree, parents, noSelfType, addBridgesToTemplate(body, currentOwner))

case Match(selector, cases) =>
Match(Typed(selector, TypeTree(selector.tpe)), cases)
treeCopy.Match(tree, Typed(selector, TypeTree(selector.tpe)), cases)

case Literal(ct) =>
// We remove the original tree attachments in pre-erasure to free up memory
Expand Down
Expand Up @@ -499,19 +499,33 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
}
}

class RegularSwitchMaker(scrutSym: Symbol, matchFailGenOverride: Option[Tree => Tree], val unchecked: Boolean) extends SwitchMaker {
val switchableTpe = Set(ByteTpe, ShortTpe, IntTpe, CharTpe)
class RegularSwitchMaker(scrutSym: Symbol, matchFailGenOverride: Option[Tree => Tree], val unchecked: Boolean) extends SwitchMaker { import CODE._
val switchableTpe = Set(ByteTpe, ShortTpe, IntTpe, CharTpe, StringTpe)
val alternativesSupported = true
val canJump = true

// Constant folding sets the type of a constant tree to `ConstantType(Constant(folded))`
// The tree itself can be a literal, an ident, a selection, ...
object SwitchablePattern { def unapply(pat: Tree): Option[Tree] = pat.tpe match {
case const: ConstantType if const.value.isIntRange =>
Some(Literal(Constant(const.value.intValue))) // TODO: Java 7 allows strings in switches
case const: ConstantType =>
if (const.value.isIntRange)
Some(LIT(const.value.intValue) setPos pat.pos)
else if (const.value.tag == StringTag)
Some(LIT(const.value.stringValue) setPos pat.pos)
else if (const.value.tag == NullTag)
Some(LIT(null) setPos pat.pos)
else None
case _ => None
}}

def scrutRef(scrut: Symbol): Tree = dealiasWiden(scrut.tpe) match {
case subInt if subInt =:= IntTpe =>
REF(scrut)
case subInt if definitions.isNumericSubClass(subInt.typeSymbol, IntClass) =>
REF(scrut) DOT nme.toInt
case _ => REF(scrut)
}

object SwitchableTreeMaker extends SwitchableTreeMakerExtractor {
def unapply(x: TreeMaker): Option[Tree] = x match {
case EqualityTestTreeMaker(_, SwitchablePattern(const), _) => Some(const)
Expand All @@ -525,8 +539,8 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
}

def defaultSym: Symbol = scrutSym
def defaultBody: Tree = { import CODE._; matchFailGenOverride map (gen => gen(REF(scrutSym))) getOrElse Throw(MatchErrorClass.tpe, REF(scrutSym)) }
def defaultCase(scrutSym: Symbol = defaultSym, guard: Tree = EmptyTree, body: Tree = defaultBody): CaseDef = { import CODE._; atPos(body.pos) {
def defaultBody: Tree = { matchFailGenOverride map (gen => gen(REF(scrutSym))) getOrElse Throw(MatchErrorClass.tpe, REF(scrutSym)) }
def defaultCase(scrutSym: Symbol = defaultSym, guard: Tree = EmptyTree, body: Tree = defaultBody): CaseDef = { atPos(body.pos) {
(DEFAULT IF guard) ==> body
}}
}
Expand All @@ -539,12 +553,9 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
if (caseDefsWithDefault.isEmpty) None // not worth emitting a switch.
else {
// match on scrutSym -- converted to an int if necessary -- not on scrut directly (to avoid duplicating scrut)
val scrutToInt: Tree =
if (scrutSym.tpe =:= IntTpe) REF(scrutSym)
else (REF(scrutSym) DOT (nme.toInt))
Some(BLOCK(
ValDef(scrutSym, scrut),
Match(scrutToInt, caseDefsWithDefault) // a switch
Match(regularSwitchMaker.scrutRef(scrutSym), caseDefsWithDefault) // a switch
))
}
} else None
Expand Down
2 changes: 2 additions & 0 deletions test/files/run/string-switch-defaults-null.check
@@ -0,0 +1,2 @@
2
-1
16 changes: 16 additions & 0 deletions test/files/run/string-switch-defaults-null.scala
@@ -0,0 +1,16 @@
import annotation.switch

object Test {
def test(s: String): Int = {
(s : @switch) match {
case "1" => 0
case null => -1
case _ => s.toInt
}
}

def main(args: Array[String]): Unit = {
println(test("2"))
println(test(null))
}
}
73 changes: 73 additions & 0 deletions test/files/run/string-switch-pos.check
@@ -0,0 +1,73 @@
[[syntax trees at end of patmat]] // newSource1.scala
[6]package [6]<empty> {
[6]class Switch extends [13][187]scala.AnyRef {
[187]def <init>(): [13]Switch = [187]{
[187][187][187]Switch.super.<init>();
[13]()
};
[21]def switch([28]s: [31]<type: [31]scala.Predef.String>, [39]cond: [45]<type: [45]scala.Boolean>): [21]Int = [56]{
[56]case <synthetic> val x1: [56]String = [56]s;
[56][56]x1 match {
[56]case [75]"AaAa" => [93]1
[56]case [104]"asdf" => [122]2
[133]case [133]"BbBb" => [133]if ([143]cond)
[151]3
else
[180]4
[56]case [56]_ => [56]throw [56][56][56]new [56]MatchError([56]x1)
}
}
}
}

[[syntax trees at end of cleanup]] // newSource1.scala
[6]package [6]<empty> {
[6]class Switch extends [13][13]Object {
[21]def switch([28]s: [31]<type: [31]scala.Predef.String>, [39]cond: [45]<type: [45]scala.Boolean>): [21]Int = [56]{
[56]case <synthetic> val x1: [56]String = [56]s;
[56]{
[56][56]if ([56][56]x1.eq([56]null))
[56]0
else
[56][56]x1.hashCode() match {
[56]case [56]2031744 => [75]if ([75][75][75]"AaAa".equals([75]x1))
[75][75]case1()
else
[56][56]matchEnd2()
[56]case [56]2062528 => [133]if ([133][133][133]"BbBb".equals([133]x1))
[133][133]case3()
else
[56][56]matchEnd2()
[56]case [56]3003444 => [104]if ([104][104][104]"asdf".equals([104]x1))
[104][104]case2()
else
[56][56]matchEnd2()
[56]case [56]_ => [56][56]matchEnd2()
};
[56]case1(){
[56][56]matchEnd1([93]1)
};
[56]case2(){
[56][56]matchEnd1([122]2)
};
[56]case3(){
[56][56]matchEnd1([133]if ([143]cond)
[151]3
else
[180]4)
};
[56]matchEnd2(){
[56][56]matchEnd1([56]throw [56][56][56]new [56]MatchError([56]x1))
};
[56]matchEnd1(x$1: [NoPosition]Int){
[56]x$1
}
}
};
[187]def <init>(): [13]Switch = [187]{
[187][187][187]Switch.super.<init>();
[13]()
}
}
}

18 changes: 18 additions & 0 deletions test/files/run/string-switch-pos.scala
@@ -0,0 +1,18 @@
import scala.tools.partest._

object Test extends DirectTest {
override def extraSettings: String = "-usejavacp -stop:cleanup -Vprint:patmat,cleanup -Vprint-pos"

override def code =
"""class Switch {
| def switch(s: String, cond: Boolean) = s match {
| case "AaAa" => 1
| case "asdf" => 2
| case "BbBb" if cond => 3
| case "BbBb" => 4
| }
|}
""".stripMargin.trim

override def show(): Unit = Console.withErr(Console.out) { super.compile() }
}
29 changes: 29 additions & 0 deletions test/files/run/string-switch.check
@@ -0,0 +1,29 @@
fido Success(dog)
garfield Success(cat)
wanda Success(fish)
henry Success(horse)
felix Failure(scala.MatchError: felix (of class java.lang.String))
deuteronomy Success(cat)
=====
AaAa 2031744 Success(1)
BBBB 2031744 Success(2)
BBAa 2031744 Failure(scala.MatchError: BBAa (of class java.lang.String))
cCCc 3015872 Success(3)
ddDd 3077408 Success(4)
EEee 2125120 Failure(scala.MatchError: EEee (of class java.lang.String))
=====
A Success(())
X Failure(scala.MatchError: X (of class java.lang.String))
=====
Success(3)
null Success(2)
7 Failure(scala.MatchError: 7 (of class java.lang.String))
=====
pig Success(1)
dog Success(2)
=====
Ea 2236 Success(1)
FB 2236 Success(2)
cC 3136 Success(3)
xx 3840 Success(4)
null 0 Success(4)

0 comments on commit 48f73bc

Please sign in to comment.