Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matching strings makes switches in bytecode #8451

Merged
merged 1 commit into from Nov 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
106 changes: 103 additions & 3 deletions src/compiler/scala/tools/nsc/transform/CleanUp.scala
Expand Up @@ -61,9 +61,6 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
}
private def mkTerm(prefix: String): TermName = unit.freshTermName(prefix)

//private val classConstantMeth = new HashMap[String, Symbol]
//private val symbolStaticFields = new HashMap[String, (Symbol, Tree, Tree)]

private var localTyper: analyzer.Typer = null

private def typedWithPos(pos: Position)(tree: Tree) =
Expand Down Expand Up @@ -383,6 +380,106 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
}
}

/** Extractor for patterns built only from string literals, `null` literals, and
 *  alternatives of those.
 *
 *  On success it yields the flat list of constants the pattern can match; a `null`
 *  literal is represented by a `null` element in that list. Any other tree shape
 *  yields `None`.
 */
object StringsPattern {
  def unapply(arg: Tree): Option[List[String]] =
    arg match {
      case Alternative(branches) =>
        // Recurse into each alternative and concatenate the per-branch results.
        // (traverseOpt presumably gives None as soon as one branch does not match -- confirm.)
        traverseOpt(branches)(unapply).map(_.flatten)
      case Literal(Constant(null)) =>
        Some(List(null))
      case Literal(Constant(str: String)) =>
        Some(List(str))
      case _ =>
        None
    }
}

// transform scrutinee of all matches to ints
/** Rewrites a `Match` so that its selector is an Int.
 *
 *  - Selector of type `IntTpe`: the match is returned unchanged.
 *  - Selector of type `StringTpe`: the match is replaced by a block consisting of a switch on
 *    the selector's hash code (0 for `null`), followed by one label per original case body,
 *    a failure label and a final result label.
 *  - Any other selector type: an error is reported through `globalError` and the match is
 *    returned unchanged.
 *
 *  @param sw the match to transform
 *  @return   the original tree, or the typed replacement block for string matches
 */
def transformSwitch(sw: Match): Tree = { import CODE._
sw.selector.tpe match {
case IntTpe => sw // can switch directly on ints
case StringTpe =>
// these assumptions about the shape of the tree are justified by the codegen in MatchOptimization
// NOTE(review): this is a refutable pattern -- a selector that is not `Typed(<Ident>, _)` would
// fail here with a MatchError rather than a reported compiler error.
val Match(Typed(selTree: Ident, _), cases) = sw
val sel = selTree.symbol
val restpe = sw.tpe
val swPos = sw.pos.focus

/* From this:
* string match { case "AaAa" => 1 case "BBBB" | "c" => 2 case _ => 3}
* Generate this:
* string.## match {
* case 2031744 =>
* if ("AaAa" equals string) goto match1
* else if ("BBBB" equals string) goto match2
* else goto matchFailure
* case 99 =>
* if ("c" equals string) goto match2
* else goto matchFailure
* case _ => goto matchFailure
* }
* match1: goto matchSuccess (1)
* match2: goto matchSuccess (2)
* matchFailure: goto matchSuccess (3) // would be throw new MatchError(string) if no default was given
* matchSuccess(res: Int): res
* This proliferation of labels is needed to handle alternative patterns, since multiple branches in the
* resulting switch may need to correspond to a single case body.
*/

// Statements of the resulting block. Label definitions are appended as they are created;
// the switch itself is prepended at the very end so that it comes first.
val stats = mutable.ListBuffer.empty[Tree]
// Body of the failure label. Defaults to throwing a MatchError on the selector and is
// overwritten below when the original match has a wildcard case.
var failureBody = Throw(New(definitions.MatchErrorClass.tpe_*, REF(sel))) : Tree
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't remember why I thought I couldn't use the synth default case from patmat. This should never be used (the Ident(WILDCARD, _, body) case should always be hit but it seems safer.


// genbcode isn't thrilled about seeing labels with Unit arguments, so `success`'s type is one of
// `${sw.tpe} => ${sw.tpe}` or `() => Unit` depending.
val success = {
val lab = currentOwner.newLabel(unit.freshTermName("matchEnd"), swPos)
if (restpe =:= UnitTpe) {
lab.setInfo(MethodType(Nil, restpe))
} else {
lab.setInfo(MethodType(lab.newValueParameter(nme.x_1).setInfo(restpe) :: Nil, restpe))
}
}
// Jump to `success` with `res` as the match result. For a Unit result, `res` is evaluated
// as a statement and `success` is then applied to no arguments.
def succeed(res: Tree): Tree =
if (restpe =:= UnitTpe) BLOCK(res, REF(success) APPLY Nil) else REF(success) APPLY res

// NOTE(review): the failure label is created with the same "matchEnd" prefix as `success`;
// the two are told apart only by whatever `freshTermName` adds to the prefix.
val failure = currentOwner.newLabel(unit.freshTermName("matchEnd"), swPos).setInfo(MethodType(Nil, restpe))
def fail(): Tree = atPos(swPos) { Apply(REF(failure), Nil) }

// New Int selector: 0 when the string is null, otherwise its hashCode.
val newSel = atPos(sel.pos) { IF (sel OBJ_EQ NULL) THEN LIT(0) ELSE (Apply(REF(sel) DOT Object_hashCode, Nil)) }
// Maps each hash to the (string constant, jump label, pattern position) triples sharing it.
// Side effects while traversing the cases: a LabelDef per string case is appended to `stats`,
// and a wildcard case replaces `failureBody`.
val casesByHash =
cases.flatMap {
case cd@CaseDef(StringsPattern(strs), _, body) =>
// One label per case body; every alternative of the pattern jumps to the same label.
val jump = currentOwner.newLabel(unit.freshTermName("case"), swPos).setInfo(MethodType(Nil, restpe))
stats += LabelDef(jump, Nil, succeed(body))
strs.map((_, jump, cd.pat.pos))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used to copy the body of a case if there were alternatives, so

x match { case "a" | "b" => body }

would introduce tree sharing.

case cd@CaseDef(Ident(nme.WILDCARD), _, body) =>
failureBody = succeed(body)
None
case cd => globalError(s"unhandled in switch: $cd"); None
// The hash is computed here, at compile time, with `##` on the string constant.
}.groupBy(_._1.##)
// One case per distinct hash, emitted in ascending hash order.
val newCases = casesByHash.toList.sortBy(_._1).map {
case (hash, cases) =>
// Build the if/else chain from the inside out, starting with the jump to `failure`;
// each folded element wraps the chain built so far, so the last one is tested first.
val newBody = cases.foldLeft(fail()) {
case (next, (pat, jump, pos)) =>
// A null constant is compared with `eq`; string constants with `equals`.
val comparison = if (pat == null) Object_eq else Object_equals
atPos(pos) {
IF(LIT(pat) DOT comparison APPLY REF(sel)) THEN (REF(jump) APPLY Nil) ELSE next
}
}
CaseDef(LIT(hash), EmptyTree, newBody)
}

stats += LabelDef(failure, Nil, failureBody)

// The result label comes last: it yields unit, or returns its single parameter.
stats += (if (restpe =:= UnitTpe) {
LabelDef(success, Nil, gen.mkLiteralUnit)
} else {
LabelDef(success, success.info.params.head :: Nil, REF(success.info.params.head))
})

// The switch goes in front of all label definitions; unknown hashes jump to `failure`.
stats prepend Match(newSel, newCases :+ CaseDef(Ident(nme.WILDCARD), EmptyTree, fail()))

val res = Block(stats.result : _*)
localTyper.typedPos(sw.pos)(res)
case _ => globalError(s"unhandled switch scrutinee type ${sw.selector.tpe}: $sw"); sw
}
}

override def transform(tree: Tree): Tree = tree match {
case _: ClassDef if genBCode.codeGen.CodeGenImpl.isJavaEntryPoint(tree.symbol, currentUnit, settings.mainClass.valueSetByUser.map(_.toString)) =>
// collecting symbols for entry points here (as opposed to GenBCode where they are used)
Expand Down Expand Up @@ -498,6 +595,9 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
super.transform(localTyper.typedPos(tree.pos)(consed))
}

case switch: Match =>
super.transform(transformSwitch(switch))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks... I can't believe I forgot this


case _ =>
super.transform(tree)
}
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/scala/tools/nsc/transform/Erasure.scala
Expand Up @@ -1285,7 +1285,7 @@ abstract class Erasure extends InfoTransform
treeCopy.Template(tree, parents, noSelfType, addBridgesToTemplate(body, currentOwner))

case Match(selector, cases) =>
Match(Typed(selector, TypeTree(selector.tpe)), cases)
treeCopy.Match(tree, Typed(selector, TypeTree(selector.tpe)), cases)

case Literal(ct) =>
// We remove the original tree attachments in pre-erasure to free up memory
Expand Down
Expand Up @@ -499,19 +499,33 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
}
}

class RegularSwitchMaker(scrutSym: Symbol, matchFailGenOverride: Option[Tree => Tree], val unchecked: Boolean) extends SwitchMaker {
val switchableTpe = Set(ByteTpe, ShortTpe, IntTpe, CharTpe)
class RegularSwitchMaker(scrutSym: Symbol, matchFailGenOverride: Option[Tree => Tree], val unchecked: Boolean) extends SwitchMaker { import CODE._
val switchableTpe = Set(ByteTpe, ShortTpe, IntTpe, CharTpe, StringTpe)
val alternativesSupported = true
val canJump = true

// Constant folding sets the type of a constant tree to `ConstantType(Constant(folded))`
// The tree itself can be a literal, an ident, a selection, ...
object SwitchablePattern { def unapply(pat: Tree): Option[Tree] = pat.tpe match {
case const: ConstantType if const.value.isIntRange =>
Some(Literal(Constant(const.value.intValue))) // TODO: Java 7 allows strings in switches
case const: ConstantType =>
if (const.value.isIntRange)
Some(LIT(const.value.intValue) setPos pat.pos)
else if (const.value.tag == StringTag)
Some(LIT(const.value.stringValue) setPos pat.pos)
else if (const.value.tag == NullTag)
Some(LIT(null) setPos pat.pos)
else None
case _ => None
}}

/** Builds the tree used as the selector of the emitted switch for `scrut`.
 *
 *  After dealiasing and widening the symbol's type: a type equal to `IntTpe` is referenced
 *  directly; any other type whose symbol is a numeric subclass of `IntClass` is referenced
 *  through `.toInt`; everything else is referenced directly.
 */
def scrutRef(scrut: Symbol): Tree = {
  val widened = dealiasWiden(scrut.tpe)
  // Only consult isNumericSubClass when the type is not already Int, as the original match did.
  val needsConversion =
    !(widened =:= IntTpe) && definitions.isNumericSubClass(widened.typeSymbol, IntClass)

  if (needsConversion) REF(scrut) DOT nme.toInt
  else REF(scrut)
}

object SwitchableTreeMaker extends SwitchableTreeMakerExtractor {
def unapply(x: TreeMaker): Option[Tree] = x match {
case EqualityTestTreeMaker(_, SwitchablePattern(const), _) => Some(const)
Expand All @@ -525,8 +539,8 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
}

def defaultSym: Symbol = scrutSym
def defaultBody: Tree = { import CODE._; matchFailGenOverride map (gen => gen(REF(scrutSym))) getOrElse Throw(MatchErrorClass.tpe, REF(scrutSym)) }
def defaultCase(scrutSym: Symbol = defaultSym, guard: Tree = EmptyTree, body: Tree = defaultBody): CaseDef = { import CODE._; atPos(body.pos) {
def defaultBody: Tree = { matchFailGenOverride map (gen => gen(REF(scrutSym))) getOrElse Throw(MatchErrorClass.tpe, REF(scrutSym)) }
def defaultCase(scrutSym: Symbol = defaultSym, guard: Tree = EmptyTree, body: Tree = defaultBody): CaseDef = { atPos(body.pos) {
(DEFAULT IF guard) ==> body
}}
}
Expand All @@ -539,12 +553,9 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
if (caseDefsWithDefault.isEmpty) None // not worth emitting a switch.
else {
// match on scrutSym -- converted to an int if necessary -- not on scrut directly (to avoid duplicating scrut)
val scrutToInt: Tree =
if (scrutSym.tpe =:= IntTpe) REF(scrutSym)
else (REF(scrutSym) DOT (nme.toInt))
Some(BLOCK(
ValDef(scrutSym, scrut),
Match(scrutToInt, caseDefsWithDefault) // a switch
Match(regularSwitchMaker.scrutRef(scrutSym), caseDefsWithDefault) // a switch
))
}
} else None
Expand Down
2 changes: 2 additions & 0 deletions test/files/run/string-switch-defaults-null.check
@@ -0,0 +1,2 @@
2
-1
16 changes: 16 additions & 0 deletions test/files/run/string-switch-defaults-null.scala
@@ -0,0 +1,16 @@
import annotation.switch

// Run test for a `@switch`-annotated match on a String that has a string literal case,
// a `null` case and a default case. The source shape is the thing under test, so keep the
// match as written.
object Test {
// Returns 0 for "1", -1 for null, and `s.toInt` for any other string.
def test(s: String): Int = {
(s : @switch) match {
case "1" => 0
case null => -1
case _ => s.toInt
}
}

// Prints the result for a string that falls through to the default case ("2")
// and for the null case.
def main(args: Array[String]): Unit = {
println(test("2"))
println(test(null))
}
}
73 changes: 73 additions & 0 deletions test/files/run/string-switch-pos.check
@@ -0,0 +1,73 @@
[[syntax trees at end of patmat]] // newSource1.scala
[6]package [6]<empty> {
[6]class Switch extends [13][187]scala.AnyRef {
[187]def <init>(): [13]Switch = [187]{
[187][187][187]Switch.super.<init>();
[13]()
};
[21]def switch([28]s: [31]<type: [31]scala.Predef.String>, [39]cond: [45]<type: [45]scala.Boolean>): [21]Int = [56]{
[56]case <synthetic> val x1: [56]String = [56]s;
[56][56]x1 match {
[56]case [75]"AaAa" => [93]1
[56]case [104]"asdf" => [122]2
[133]case [133]"BbBb" => [133]if ([143]cond)
[151]3
else
[180]4
[56]case [56]_ => [56]throw [56][56][56]new [56]MatchError([56]x1)
}
}
}
}

[[syntax trees at end of cleanup]] // newSource1.scala
[6]package [6]<empty> {
[6]class Switch extends [13][13]Object {
[21]def switch([28]s: [31]<type: [31]scala.Predef.String>, [39]cond: [45]<type: [45]scala.Boolean>): [21]Int = [56]{
[56]case <synthetic> val x1: [56]String = [56]s;
[56]{
[56][56]if ([56][56]x1.eq([56]null))
[56]0
else
[56][56]x1.hashCode() match {
[56]case [56]2031744 => [75]if ([75][75][75]"AaAa".equals([75]x1))
[75][75]case1()
else
[56][56]matchEnd2()
[56]case [56]2062528 => [133]if ([133][133][133]"BbBb".equals([133]x1))
[133][133]case3()
else
[56][56]matchEnd2()
[56]case [56]3003444 => [104]if ([104][104][104]"asdf".equals([104]x1))
[104][104]case2()
else
[56][56]matchEnd2()
[56]case [56]_ => [56][56]matchEnd2()
};
[56]case1(){
[56][56]matchEnd1([93]1)
};
[56]case2(){
[56][56]matchEnd1([122]2)
};
[56]case3(){
[56][56]matchEnd1([133]if ([143]cond)
[151]3
else
[180]4)
};
[56]matchEnd2(){
[56][56]matchEnd1([56]throw [56][56][56]new [56]MatchError([56]x1))
};
[56]matchEnd1(x$1: [NoPosition]Int){
[56]x$1
}
}
};
[187]def <init>(): [13]Switch = [187]{
[187][187][187]Switch.super.<init>();
[13]()
}
}
}

18 changes: 18 additions & 0 deletions test/files/run/string-switch-pos.scala
@@ -0,0 +1,18 @@
import scala.tools.partest._

// Partest DirectTest that compiles the snippet in `code` with the flags in `extraSettings`
// and shows the compiler's output.
object Test extends DirectTest {
// Compiler flags for the run. Presumably: print the trees after the patmat and cleanup
// phases with positions attached and stop once cleanup is done -- confirm against the
// compiler's option documentation.
override def extraSettings: String = "-usejavacp -stop:cleanup -Vprint:patmat,cleanup -Vprint-pos"

// Source under test: a String match where one literal ("BbBb") appears twice, first
// guarded and then unguarded.
// NOTE(review): the output is printed with positions, so the exact text of this snippet
// presumably must stay in sync with the recorded expected output.
override def code =
"""class Switch {
| def switch(s: String, cond: Boolean) = s match {
| case "AaAa" => 1
| case "asdf" => 2
| case "BbBb" if cond => 3
| case "BbBb" => 4
| }
|}
""".stripMargin.trim

// Compile with Console.err redirected to Console.out.
override def show(): Unit = Console.withErr(Console.out) { super.compile() }
}
29 changes: 29 additions & 0 deletions test/files/run/string-switch.check
@@ -0,0 +1,29 @@
fido Success(dog)
garfield Success(cat)
wanda Success(fish)
henry Success(horse)
felix Failure(scala.MatchError: felix (of class java.lang.String))
deuteronomy Success(cat)
=====
AaAa 2031744 Success(1)
BBBB 2031744 Success(2)
BBAa 2031744 Failure(scala.MatchError: BBAa (of class java.lang.String))
cCCc 3015872 Success(3)
ddDd 3077408 Success(4)
EEee 2125120 Failure(scala.MatchError: EEee (of class java.lang.String))
=====
A Success(())
X Failure(scala.MatchError: X (of class java.lang.String))
=====
Success(3)
null Success(2)
7 Failure(scala.MatchError: 7 (of class java.lang.String))
=====
pig Success(1)
dog Success(2)
=====
Ea 2236 Success(1)
FB 2236 Success(2)
cC 3136 Success(3)
xx 3840 Success(4)
null 0 Success(4)