Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow return in tailrec position #14067

Merged
merged 2 commits into from Jan 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 3 additions & 15 deletions compiler/src/dotty/tools/dotc/transform/TailRec.scala
Expand Up @@ -277,23 +277,11 @@ class TailRec extends MiniPhase {
def yesTailTransform(tree: Tree)(using Context): Tree =
transform(tree, tailPosition = true)

/** If not in tail position a tree traversal may not be needed.
*
* A recursive call may still be in tail position if within the return
* expression of a labeled block.
* A tree traversal may also be needed to report a failure to transform
* a recursive call of a @tailrec annotated method (i.e. `isMandatory`).
*/
private def isTraversalNeeded =
isMandatory || tailPositionLabeledSyms.size > 0

def noTailTransform(tree: Tree)(using Context): Tree =
if (isTraversalNeeded) transform(tree, tailPosition = false)
else tree
transform(tree, tailPosition = false)

def noTailTransforms[Tr <: Tree](trees: List[Tr])(using Context): List[Tr] =
if (isTraversalNeeded) trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]]
else trees
trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]]

override def transform(tree: Tree)(using Context): Tree = {
/* Rewrite an Apply to be considered for tail call transformation. */
Expand Down Expand Up @@ -444,7 +432,7 @@ class TailRec extends MiniPhase {

case Return(expr, from) =>
val fromSym = from.symbol
val inTailPosition = fromSym.is(Label) && tailPositionLabeledSyms.contains(fromSym)
val inTailPosition = !fromSym.is(Label) || tailPositionLabeledSyms.contains(fromSym)
cpy.Return(tree)(transform(expr, inTailPosition), from)

case _ =>
Expand Down
7 changes: 7 additions & 0 deletions tests/run/tailrec-return.check
@@ -0,0 +1,7 @@
6
false
true
false
true
Ada Lovelace, Alan Turing
List(9, 10)
66 changes: 66 additions & 0 deletions tests/run/tailrec-return.scala
@@ -0,0 +1,66 @@
object Test:

@annotation.tailrec
def sum(n: Int, acc: Int = 0): Int =
if n != 0 then return sum(n - 1, acc + n)
acc

@annotation.tailrec
def isEven(n: Int): Boolean =
if n != 0 && n != 1 then return isEven(n - 2)
if n == 1 then return false
true

@annotation.tailrec
def isEvenApply(n: Int): Boolean =
// Return inside an `Apply.fun`
(
if n != 0 && n != 1 then return isEvenApply(n - 2)
else if n == 1 then return false
else (x: Boolean) => x
)(true)

@annotation.tailrec
def isEvenWhile(n: Int): Boolean =
// Return inside a `WhileDo.cond`
while(
if n != 0 && n != 1 then return isEvenWhile(n - 2)
else if n == 1 then return false
else true
) {}
true

@annotation.tailrec
def isEvenReturn(n: Int): Boolean =
// Return inside a `Return`
return
if n != 0 && n != 1 then return isEvenReturn(n - 2)
else if n == 1 then return false
else true

@annotation.tailrec
def names(l: List[(String, String) | Null], acc: List[String] = Nil): List[String] =
l match
case Nil => acc.reverse
case x :: xs =>
if x == null then return names(xs, acc)

val displayName = x._1 + " " + x._2
names(xs, displayName :: acc)

def nonTail(l: List[Int]): List[Int] =
l match
case Nil => Nil
case x :: xs =>
// The call to nonTail should *not* be eliminated
(x + 1) :: nonTail(xs)


def main(args: Array[String]): Unit =
println(sum(3))
println(isEven(5))
println(isEvenApply(6))
println(isEvenWhile(7))
println(isEvenReturn(8))
println(names(List(("Ada", "Lovelace"), null, ("Alan", "Turing"))).mkString(", "))
println(nonTail(List(8, 9)))