Skip to content

Commit

Permalink
Matching strings makes switches
Browse files Browse the repository at this point in the history
Switchable matches with string-typed scrutinee survive the pattern
matcher in the same way as those on integer types do: as a series of
`CaseDef`s with empty guard and literal pattern.

Cleanup collates them by hash code and emits a switch on that.
No sooner, so scala.js can emit a more JS-friendly implementation.

Labels were used to avoid a proliferation of `throw new MatchError`.

Works with nulls. Works with Unit.

Enclosed "pos" test stands for positions, not positivity.

Fixes scala/bug#11740
  • Loading branch information
hrhino authored and retronym committed Oct 4, 2019
1 parent 520a88e commit 0fc5566
Show file tree
Hide file tree
Showing 12 changed files with 288 additions and 15 deletions.
76 changes: 73 additions & 3 deletions src/compiler/scala/tools/nsc/transform/CleanUp.scala
Expand Up @@ -61,9 +61,6 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
}
private def mkTerm(prefix: String): TermName = unit.freshTermName(prefix)

//private val classConstantMeth = new HashMap[String, Symbol]
//private val symbolStaticFields = new HashMap[String, (Symbol, Tree, Tree)]

private var localTyper: analyzer.Typer = null

private def typedWithPos(pos: Position)(tree: Tree) =
Expand Down Expand Up @@ -383,6 +380,77 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
}
}

// transform scrutinee of all matches to ints
def transformSwitch(sw: Match): Tree = { import CODE._
sw.selector.tpe match {
case IntTpe => sw // can switch directly on ints
case StringTpe =>
// these assumptions about the shape of the tree are justified by the codegen in MatchOptimization
val Match(Typed(selTree: Ident, _), cases) = sw
val sel = selTree.symbol
val restpe = sw.tpe
val swPos = sw.pos.focus

// genbcode isn't thrilled about seeing labels with Unit arguments, so `success`'s type is one of
// `${sw.tpe} => ${sw.tpe}` or `() => Unit` depending.
val success = {
val lab = currentOwner.newLabel(unit.freshTermName("matchEnd"), swPos)
if (restpe =:= UnitTpe) {
lab.setInfo(MethodType(Nil, restpe))
} else {
lab.setInfo(MethodType(lab.newValueParameter(nme.x_1).setInfo(restpe) :: Nil, restpe))
}
}
def succeed(res: Tree): Tree = {
if (restpe =:= UnitTpe) BLOCK(res, REF(success) APPLY Nil) else REF(success) APPLY res
}

val failure = currentOwner.newLabel(unit.freshTermName("matchEnd"), swPos).setInfo(MethodType(Nil, NothingTpe))
def fail(): Tree = atPos(swPos) { Apply(REF(failure), Nil) }

val newSel = atPos(sel.pos) { IF (sel OBJ_EQ NULL) THEN LIT(0) ELSE (Apply(REF(sel) DOT Object_hashCode, Nil)) }
var dfltBody = fail()
object StringsPattern {
def unapply(arg: Tree): Option[List[String]] = arg match {
case Literal(Constant(value: String)) => Some(value :: Nil)
case Literal(Constant(null)) => Some(null :: Nil)
case Alternative(alts) => traverseOpt(alts)(unapply).map(_.flatten)
case _ => None
}
}
val casesByHash =
cases.flatMap {
case cd@CaseDef(StringsPattern(strs), _, body) => strs.map((_, body, cd.pat.pos))
case cd@CaseDef(Ident(nme.WILDCARD), _, body) =>
if (!cd.hasAttachment[SynthDefaultCase.type]) dfltBody = body
None
case cd => globalError(s"unhandled in switch: $cd"); None
}.groupBy(_._1.##)
val newCases = casesByHash.toList.sortBy(_._1).map {
case (hash, cases) =>
val newBody = cases.foldLeft(fail()) {
case (next, (pat, body, pos)) =>
val comparison = if (pat == null) Object_eq else Object_equals
atPos(pos) {
IF(LIT(pat) DOT comparison APPLY REF(sel)) THEN body ELSE next
}
}
CaseDef(LIT(hash), EmptyTree, newBody)
}
val res = Block(
succeed(Match(newSel, newCases :+ CaseDef(Ident(nme.WILDCARD), EmptyTree, dfltBody))),
LabelDef(failure, Nil, Throw(New(definitions.MatchErrorClass.tpe_*, REF(sel)))),
if (restpe =:= UnitTpe) {
LabelDef(success, Nil, gen.mkLiteralUnit)
} else {
LabelDef(success, success.info.params.head :: Nil, REF(success.info.params.head))
}
)
localTyper.typedPos(sw.pos)(res)
case _ => globalError(s"unhandled switch scrutinee type ${sw.selector.tpe}: $sw"); sw
}
}

override def transform(tree: Tree): Tree = tree match {
case _: ClassDef if genBCode.codeGen.CodeGenImpl.isJavaEntryPoint(tree.symbol, currentUnit, settings.mainClass.valueSetByUser.map(_.toString)) =>
// collecting symbols for entry points here (as opposed to GenBCode where they are used)
Expand Down Expand Up @@ -498,6 +566,8 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
super.transform(localTyper.typedPos(tree.pos)(consed))
}

case switch: Match => transformSwitch(switch)

case _ =>
super.transform(tree)
}
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/scala/tools/nsc/transform/Erasure.scala
Expand Up @@ -1285,7 +1285,7 @@ abstract class Erasure extends InfoTransform
treeCopy.Template(tree, parents, noSelfType, addBridgesToTemplate(body, currentOwner))

case Match(selector, cases) =>
Match(Typed(selector, TypeTree(selector.tpe)), cases)
treeCopy.Match(tree, Typed(selector, TypeTree(selector.tpe)), cases)

case Literal(ct) =>
// We remove the original tree attachments in pre-erasure to free up memory
Expand Down
Expand Up @@ -235,7 +235,8 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
def isDefault(x: CaseDef): Boolean
def defaultSym: Symbol
def defaultBody: Tree
def defaultCase(scrutSym: Symbol = defaultSym, guard: Tree = EmptyTree, body: Tree = defaultBody): CaseDef
def defaultCase(): CaseDef = defaultCase(defaultSym, EmptyTree, defaultBody)
def defaultCase(scrutSym: Symbol, guard: Tree, body: Tree): CaseDef

object GuardAndBodyTreeMakers {
def unapply(tms: List[TreeMaker]): Option[(Tree, Tree)] = {
Expand Down Expand Up @@ -499,19 +500,31 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
}
}

class RegularSwitchMaker(scrutSym: Symbol, matchFailGenOverride: Option[Tree => Tree], val unchecked: Boolean) extends SwitchMaker {
val switchableTpe = Set(ByteTpe, ShortTpe, IntTpe, CharTpe)
class RegularSwitchMaker(scrutSym: Symbol, matchFailGenOverride: Option[Tree => Tree], val unchecked: Boolean) extends SwitchMaker { import CODE._
val switchableTpe = Set(ByteTpe, ShortTpe, IntTpe, CharTpe, StringTpe)
val alternativesSupported = true
val canJump = true

// Constant folding sets the type of a constant tree to `ConstantType(Constant(folded))`
// The tree itself can be a literal, an ident, a selection, ...
object SwitchablePattern { def unapply(pat: Tree): Option[Tree] = pat.tpe match {
case const: ConstantType if const.value.isIntRange =>
Some(Literal(Constant(const.value.intValue))) // TODO: Java 7 allows strings in switches
case const: ConstantType =>
if (const.value.isIntRange)
Some(LIT(const.value.intValue) setPos pat.pos)
else if (const.value.tag == StringTag)
Some(LIT(const.value.stringValue) setPos pat.pos)
else if (const.value.tag == NullTag)
Some(LIT(null) setPos pat.pos)
else None
case _ => None
}}

def scrutRef(scrut: Symbol): Tree = dealiasWiden(scrut.tpe) match {
case subInt if definitions.isNumericSubClass(subInt.typeSymbol, IntClass) =>
REF(scrut) DOT nme.toInt
case _ => REF(scrut)
}

object SwitchableTreeMaker extends SwitchableTreeMakerExtractor {
def unapply(x: TreeMaker): Option[Tree] = x match {
case EqualityTestTreeMaker(_, SwitchablePattern(const), _) => Some(const)
Expand All @@ -525,8 +538,9 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
}

def defaultSym: Symbol = scrutSym
def defaultBody: Tree = { import CODE._; matchFailGenOverride map (gen => gen(REF(scrutSym))) getOrElse Throw(MatchErrorClass.tpe, REF(scrutSym)) }
def defaultCase(scrutSym: Symbol = defaultSym, guard: Tree = EmptyTree, body: Tree = defaultBody): CaseDef = { import CODE._; atPos(body.pos) {
def defaultBody: Tree = { matchFailGenOverride map (gen => gen(REF(scrutSym))) getOrElse Throw(MatchErrorClass.tpe, REF(scrutSym)) }
override def defaultCase(): global.CaseDef = super.defaultCase().updateAttachment(SynthDefaultCase)
def defaultCase(scrutSym: Symbol, guard: Tree, body: Tree): CaseDef = { atPos(body.pos) {
(DEFAULT IF guard) ==> body
}}
}
Expand All @@ -539,12 +553,9 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
if (caseDefsWithDefault.isEmpty) None // not worth emitting a switch.
else {
// match on scrutSym -- converted to an int if necessary -- not on scrut directly (to avoid duplicating scrut)
val scrutToInt: Tree =
if (scrutSym.tpe =:= IntTpe) REF(scrutSym)
else (REF(scrutSym) DOT (nme.toInt))
Some(BLOCK(
ValDef(scrutSym, scrut),
Match(scrutToInt, caseDefsWithDefault) // a switch
Match(regularSwitchMaker.scrutRef(scrutSym), caseDefsWithDefault) // a switch
))
}
} else None
Expand Down
3 changes: 3 additions & 0 deletions src/reflect/scala/reflect/internal/StdAttachments.scala
Expand Up @@ -119,4 +119,7 @@ trait StdAttachments {
class QualTypeSymAttachment(val sym: Symbol)

case object ConstructorNeedsFence extends PlainAttachment

/** Attached to CaseDefs generated by the compiler to match when none other does. */
case object SynthDefaultCase extends PlainAttachment
}
7 changes: 7 additions & 0 deletions test/files/jvm/string-switch/Switch_1.scala
@@ -0,0 +1,7 @@
import annotation.switch
class Switches {
val cond = true
def two = ("foo" : @switch) match { case "foo" => case "bar" => }
def guard = ("foo" : @switch) match { case "z" => case "y" => case x if cond => }
def colli = ("foo" : @switch) match { case "DB" => case "Ca" => }
}
16 changes: 16 additions & 0 deletions test/files/jvm/string-switch/Test.scala
@@ -0,0 +1,16 @@
import scala.tools.partest.BytecodeTest
import scala.tools.asm
import scala.collection.JavaConverters._
import scala.PartialFunction.cond

object Test extends BytecodeTest {
def show: Unit = {
val clasz = loadClassNode("Switches")
List("two", "guard", "colli") foreach { meth =>
val mn = getMethod(clasz, meth)
assert(mn.instructions.iterator.asScala.exists(isSwitchInsn), meth)
}
}
def isSwitchInsn(insn: asm.tree.AbstractInsnNode) =
cond(insn.getOpcode) { case asm.Opcodes.LOOKUPSWITCH | asm.Opcodes.TABLESWITCH => true }
}
2 changes: 2 additions & 0 deletions test/files/run/string-switch-defaults-null.check
@@ -0,0 +1,2 @@
2
-1
16 changes: 16 additions & 0 deletions test/files/run/string-switch-defaults-null.scala
@@ -0,0 +1,16 @@
import annotation.switch

object Test {
def test(s: String): Int = {
(s : @switch) match {
case "1" => 0
case null => -1
case _ => s.toInt
}
}

def main(args: Array[String]): Unit = {
println(test("2"))
println(test(null))
}
}
64 changes: 64 additions & 0 deletions test/files/run/string-switch-pos.check
@@ -0,0 +1,64 @@
[[syntax trees at end of patmat]] // newSource1.scala
[6]package [6]<empty> {
[6]class Switch extends [13][187]scala.AnyRef {
[187]def <init>(): [13]Switch = [187]{
[187][187][187]Switch.super.<init>();
[13]()
};
[21]def switch([28]s: [31]<type: [31]scala.Predef.String>, [39]cond: [45]<type: [45]scala.Boolean>): [21]Int = [56]{
[56]case <synthetic> val x1: [56]String = [56]s;
[56][56]x1 match {
[56]case [75]"AaAa" => [93]1
[56]case [104]"asdf" => [122]2
[133]case [133]"BbBb" => [133]if ([143]cond)
[151]3
else
[180]4
[56]case [56]_ => [56]throw [56][56][56]new [56]MatchError([56]x1)
}
}
}
}

[[syntax trees at end of cleanup]] // newSource1.scala
[6]package [6]<empty> {
[6]class Switch extends [13][13]Object {
[21]def switch([28]s: [31]<type: [31]scala.Predef.String>, [39]cond: [45]<type: [45]scala.Boolean>): [21]Int = [56]{
[56]case <synthetic> val x1: [56]String = [56]s;
[56]{
[56][56]matchEnd1([56][56]if ([56][56]x1.eq([56]null))
[56]0
else
[56][56]x1.hashCode() match {
[56]case [56]2031744 => [75]if ([75][75][75]"AaAa".equals([75]x1))
[93]1
else
[56][56]matchEnd2()
[56]case [56]2062528 => [133]if ([133][133][133]"BbBb".equals([133]x1))
[133]if ([143]cond)
[151]3
else
[180]4
else
[56][56]matchEnd2()
[56]case [56]3003444 => [104]if ([104][104][104]"asdf".equals([104]x1))
[122]2
else
[56][56]matchEnd2()
[56]case [56]_ => [56][56]matchEnd2()
});
[56]matchEnd2(){
[56]throw [56][56][56]new [56]MatchError([56]x1)
};
[56]matchEnd1(x$1: [NoPosition]Int){
[56]x$1
}
}
};
[187]def <init>(): [13]Switch = [187]{
[187][187][187]Switch.super.<init>();
[13]()
}
}
}

18 changes: 18 additions & 0 deletions test/files/run/string-switch-pos.scala
@@ -0,0 +1,18 @@
import scala.tools.partest._

object Test extends DirectTest {
override def extraSettings: String = "-usejavacp -stop:cleanup -Vprint:patmat,cleanup -Vprint-pos"

override def code =
"""class Switch {
| def switch(s: String, cond: Boolean) = s match {
| case "AaAa" => 1
| case "asdf" => 2
| case "BbBb" if cond => 3
| case "BbBb" => 4
| }
|}
""".stripMargin.trim

override def show(): Unit = Console.withErr(Console.out) { super.compile() }
}
20 changes: 20 additions & 0 deletions test/files/run/string-switch.check
@@ -0,0 +1,20 @@
fido Success(dog)
garfield Success(cat)
wanda Success(fish)
henry Success(horse)
felix Failure(scala.MatchError: felix (of class java.lang.String))
deuteronomy Success(cat)
=====
AaAa 2031744 Success(1)
BBBB 2031744 Success(2)
BBAa 2031744 Failure(scala.MatchError: BBAa (of class java.lang.String))
cCCc 3015872 Success(3)
ddDd 3077408 Success(4)
EEee 2125120 Failure(scala.MatchError: EEee (of class java.lang.String))
=====
A Success(())
X Failure(scala.MatchError: X (of class java.lang.String))
=====
Success(3)
null Failure(scala.MatchError: null)
7 Failure(scala.MatchError: 7 (of class java.lang.String))
46 changes: 46 additions & 0 deletions test/files/run/string-switch.scala
@@ -0,0 +1,46 @@
// scalac: -Werror
import annotation.switch
import util.Try

object Test extends App {

def species(name: String) = (name.toLowerCase : @switch) match {
case "fido" => "dog"
case "garfield" | "deuteronomy" => "cat"
case "wanda" => "fish"
case "henry" => "horse"
}
List("fido", "garfield", "wanda", "henry", "felix", "deuteronomy").foreach { n => println(s"$n ${Try(species(n))}") }

println("=====")

def collide(in: String) = (in : @switch) match {
case "AaAa" => 1
case "BBBB" => 2
case "cCCc" => 3
case x if x == "ddDd" => 4
}
List("AaAa", "BBBB", "BBAa", "cCCc", "ddDd", "EEee").foreach { s =>
println(s"$s ${s.##} ${Try(collide(s))}")
}

println("=====")

def unitary(in: String) = (in : @switch) match {
case "A" =>
}
List("A","X").foreach { s =>
println(s"$s ${Try(unitary(s))}")
}

println("=====")

def nullFun(in: String) = (in : @switch) match {
case "1" => 1
case null => 2
case "" => 3
}
List("", null, "7").foreach { s =>
println(s"$s ${Try(nullFun(s))}")
}
}

0 comments on commit 0fc5566

Please sign in to comment.