Skip to content

Commit

Permalink
Visit all trees
Browse files Browse the repository at this point in the history
  • Loading branch information
mbovel committed Dec 29, 2021
1 parent a349775 commit 67d8995
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 47 deletions.
71 changes: 24 additions & 47 deletions compiler/src/dotty/tools/dotc/transform/TailRec.scala
Expand Up @@ -21,14 +21,12 @@ import scala.collection.mutable
*
* What it does:
*
* Finds method calls in tail-position and replaces them with jumps.
* A call is in a tail-position if it is the last instruction to be
* executed in the body of a method. This includes being in
* tail-position of a `return` from a `Labeled` block which is itself
* in tail-position (which is critical for tail-recursive calls in the
* cases of a `match`). To identify tail positions, we recurse over
* the trees that may contain calls in tail-position (trees that can't
* contain such calls are not transformed).
* Finds method calls in tail-position and replaces them with jumps. A call is
* in a tail-position if it is the last instruction to be executed in the body
* of a method. This includes being in tail-position inside a `return`
* expression. If the `return` targets a `Labeled` block, then the target block
* must itself be in tail-position (which is critical for tail-recursive calls
* in the cases of a `match`).
*
* When a method contains at least one tail-recursive call, its rhs
* is wrapped in the following structure:
Expand All @@ -49,7 +47,7 @@ import scala.collection.mutable
* reassigning the local `var`s substituting formal parameters and
* (b) a `return` from the `tailResult` labeled block, which has the
* net effect of looping back to the beginning of the method.
* If the receiver is modifed in a recursive call, an additional `var`
* If the receiver is modified in a recursive call, an additional `var`
* is used to replace `this`.
*
* As a complete example of the transformation, the classical `fact`
Expand Down Expand Up @@ -118,7 +116,7 @@ class TailRec extends MiniPhase {
override def transformDefDef(tree: DefDef)(using Context): Tree = {
val method = tree.symbol
val mandatory = method.hasAnnotation(defn.TailrecAnnot)
def noTailTransform(failureReported: Boolean) = {
def transform(failureReported: Boolean) = {
// FIXME: want to report this error on `tree.nameSpan`, but
// because of extension method getting a weird position, it is
// better to report on method symbol so there's no overlap.
Expand Down Expand Up @@ -212,9 +210,9 @@ class TailRec extends MiniPhase {
)
)
}
else noTailTransform(failureReported = transformer.failureReported)
else transform(failureReported = transformer.failureReported)
}
else noTailTransform(failureReported = false)
else transform(failureReported = false)
}

class TailRecElimination(method: Symbol, enclosingClass: ClassSymbol, paramSyms: List[Symbol], isMandatory: Boolean) extends TreeMap {
Expand Down Expand Up @@ -274,34 +272,13 @@ class TailRec extends MiniPhase {
finally inTailPosition = saved
}

def yesTailTransform(tree: Tree)(using Context): Tree =
transform(tree, tailPosition = true)

/** If not in tail position a tree traversal may not be needed.
*
* A recursive call may still be in tail position if within the return
* expression of a labeled block.
* A tree traversal may also be needed to report a failure to transform
* a recursive call of a @tailrec annotated method (i.e. `isMandatory`).
*/
private def isTraversalNeeded =
isMandatory || tailPositionLabeledSyms.size > 0

def noTailTransform(tree: Tree)(using Context): Tree =
if (isTraversalNeeded) transform(tree, tailPosition = false)
else tree

def noTailTransforms[Tr <: Tree](trees: List[Tr])(using Context): List[Tr] =
if (isTraversalNeeded) trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]]
else trees

override def transform(tree: Tree)(using Context): Tree = {
/* Rewrite an Apply to be considered for tail call transformation. */
def rewriteApply(tree: Apply): Tree = {
val arguments = noTailTransforms(tree.args)
val arguments = transform(tree.args)

def continue =
cpy.Apply(tree)(noTailTransform(tree.fun), arguments)
cpy.Apply(tree)(transform(tree.fun), arguments)

def fail(reason: String) = {
if (isMandatory) {
Expand Down Expand Up @@ -344,7 +321,7 @@ class TailRec extends MiniPhase {
if (prefix eq EmptyTree) assignParamPairs
else
// TODO Opt: also avoid assigning `this` if the prefix is `this.`
(getVarForRewrittenThis(), noTailTransform(prefix)) :: assignParamPairs
(getVarForRewrittenThis(), transform(prefix)) :: assignParamPairs

val assignments = assignThisAndParamPairs match {
case (lhs, rhs) :: Nil =>
Expand Down Expand Up @@ -377,22 +354,22 @@ class TailRec extends MiniPhase {
case tree @ Apply(fun, args) =>
val meth = fun.symbol
if (meth == defn.Boolean_|| || meth == defn.Boolean_&&)
cpy.Apply(tree)(noTailTransform(fun), transform(args))
cpy.Apply(tree)(transform(fun), transform(args))
else
rewriteApply(tree)

case tree @ Select(qual, name) =>
cpy.Select(tree)(noTailTransform(qual), name)
cpy.Select(tree)(transform(qual), name)

case tree @ Block(stats, expr) =>
cpy.Block(tree)(
noTailTransforms(stats),
transform(stats),
transform(expr)
)

case tree @ If(cond, thenp, elsep) =>
cpy.If(tree)(
noTailTransform(cond),
transform(cond),
transform(thenp),
transform(elsep)
)
Expand All @@ -402,33 +379,33 @@ class TailRec extends MiniPhase {

case tree @ Match(selector, cases) =>
cpy.Match(tree)(
noTailTransform(selector),
transform(selector),
transformSub(cases)
)

case tree: Try =>
val expr = noTailTransform(tree.expr)
val expr = transform(tree.expr)
if (tree.finalizer eq EmptyTree)
// SI-1672 Catches are in tail position when there is no finalizer
cpy.Try(tree)(expr, transformSub(tree.cases), EmptyTree)
else cpy.Try(tree)(
expr,
noTailTransforms(tree.cases),
noTailTransform(tree.finalizer)
transformSub(tree.cases),
transform(tree.finalizer)
)

case tree @ WhileDo(cond, body) =>
cpy.WhileDo(tree)(
noTailTransform(cond),
noTailTransform(body)
transform(cond),
transform(body)
)

case _: Alternative | _: Bind =>
assert(false, "We should never have gotten inside a pattern")
tree

case tree: ValOrDefDef =>
if (isMandatory) noTailTransform(tree.rhs)
if (isMandatory) transform(tree.rhs)
tree

case _: Super | _: This | _: Literal | _: TypeTree | _: TypeDef | EmptyTree =>
Expand Down
4 changes: 4 additions & 0 deletions tests/run/tailrec-return.check
@@ -1,2 +1,6 @@
6
false
true
false
true
Ada Lovelace, Alan Turing
42 changes: 42 additions & 0 deletions tests/run/tailrec-return.scala
Expand Up @@ -11,6 +11,48 @@ object Test:
if n == 1 then return false
true

@annotation.tailrec
def isEvenApply(n: Int): Boolean =
// Return inside an `Apply.fun`
(
if n != 0 && n != 1 then return isEvenApply(n - 2)
else if n == 1 then return false
else (x: Boolean) => x
)(true)

@annotation.tailrec
def isEvenWhile(n: Int): Boolean =
// Return inside a `WhileDo.cond`
while(
if n != 0 && n != 1 then return isEvenWhile(n - 2)
else if n == 1 then return false
else true
) {}
true

@annotation.tailrec
def isEvenReturn(n: Int): Boolean =
// Return inside a `Return`
return
if n != 0 && n != 1 then return isEvenReturn(n - 2)
else if n == 1 then return false
else true

@annotation.tailrec
def names(l: List[(String, String) | Null], acc: List[String] = Nil): List[String] =
l match
case Nil => acc.reverse
case x :: xs =>
if x == null then return names(xs, acc)

val displayName = x._1 + " " + x._2
names(xs, displayName :: acc)


def main(args: Array[String]): Unit =
println(sum(3))
println(isEven(5))
println(isEvenApply(6))
println(isEvenWhile(7))
println(isEvenReturn(8))
println(names(List(("Ada", "Lovelace"), null, ("Alan", "Turing"))).mkString(", "))

0 comments on commit 67d8995

Please sign in to comment.