Skip to content

Commit

Permalink
Fix regression with non-awaiting patterns in tail position
Browse files Browse the repository at this point in the history
  • Loading branch information
retronym committed Mar 27, 2020
1 parent e160b8d commit 9a7f84e
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 6 deletions.
Expand Up @@ -164,6 +164,8 @@ private[async] trait AnfTransform extends TransformUtils {

case ld @ LabelDef(name, params, rhs) =>
treeCopy.LabelDef(tree, name, params, transformNewControlFlowBlock(rhs))
case t @ Typed(expr, tpt) =>
transform(expr).setType(t.tpe)
case _ =>
super.transform(tree)
}
Expand Down
2 changes: 2 additions & 0 deletions src/compiler/scala/tools/nsc/transform/async/AsyncNames.scala
Expand Up @@ -43,6 +43,7 @@ final class AsyncNames[U <: reflect.internal.Names with Singleton](val u: U) {
}
private val matchRes: TermNameCache = new TermNameCache("match")
private val ifRes: TermNameCache = new TermNameCache("if")
private val qual: TermNameCache = new TermNameCache("qual")
private val await: TermNameCache = new TermNameCache("await")


Expand All @@ -54,6 +55,7 @@ final class AsyncNames[U <: reflect.internal.Names with Singleton](val u: U) {
class AsyncName {
final val matchRes = new NameSource[U#TermName](self.matchRes)
final val ifRes = new NameSource[U#TermName](self.ifRes)
final val qual = new NameSource[U#TermName](self.qual)
final val await = new NameSource[U#TermName](self.await)

private val seenPrefixes = mutable.AnyRefMap[Name, AtomicInteger]()
Expand Down
14 changes: 10 additions & 4 deletions src/compiler/scala/tools/nsc/transform/async/ExprBuilder.scala
Expand Up @@ -113,7 +113,14 @@ trait ExprBuilder extends TransformUtils {
// An exception should bubble out to the enclosing handler, don't insert a complete call.
} else {
val expr = stats.remove(stats.size - 1)
stats += completeSuccess(expr)
def pushIntoMatchEnd(t: Tree): Tree = {
t match {
case MatchEnd(ld) => treeCopy.LabelDef(ld, ld.name, ld.params, pushIntoMatchEnd(ld.rhs))
case b@Block(caseStats, caseExpr) => assignUnitType(treeCopy.Block(b, caseStats, pushIntoMatchEnd(caseExpr)))
case expr => completeSuccess(expr)
}
}
stats += pushIntoMatchEnd(expr)
}
stats += typed(Return(literalUnit).setSymbol(currentTransformState.applySym))
allNextStates -= nextState
Expand Down Expand Up @@ -243,7 +250,7 @@ trait ExprBuilder extends TransformUtils {
}

case ld @ LabelDef(name, params, rhs) =>
if (isCaseLabel(ld.symbol) || isMatchEndLabel(ld.symbol)) {
if (isCaseLabel(ld.symbol) || (isMatchEndLabel(ld.symbol) && labelDefStates.contains(ld.symbol))) {
// LabelDefs from patterns are a bit trickier as they can (forward) branch to each other.

labelDefStates.get(ld.symbol).foreach { startLabelState =>
Expand All @@ -259,7 +266,7 @@ trait ExprBuilder extends TransformUtils {
}

val afterLabelState = afterState()
val (inlinedState, nestedStates) = buildNestedStatesFirstForInlining(rhs, afterLabelState)
val (inlinedState, nestedStates) = buildNestedStatesFirstForInlining(rhs, afterLabelState)

// Leave this label here for synchronous jumps from previous cases. This is
// allowed even if this case has its own state (ie if there is an asynchrounous path
Expand Down Expand Up @@ -288,7 +295,6 @@ trait ExprBuilder extends TransformUtils {
checkForUnsupportedAwait(stat)
stateBuilder += stat
}

case _ =>
checkForUnsupportedAwait(stat)
stateBuilder += stat
Expand Down
41 changes: 39 additions & 2 deletions test/junit/scala/tools/nsc/async/AnnotationDrivenAsync.scala
Expand Up @@ -34,6 +34,43 @@ class AnnotationDrivenAsync {
assertEquals(3, run(code))
}

@Test
def patternTailPosition(): Unit = {
val code =
"""
|import scala.concurrent._, duration.Duration, ExecutionContext.Implicits.global
|import scala.tools.partest.async.Async.{async, await}
|import Future.{successful => f}
|
|object Test {
| def test = async {
| {
| await(f(1))
| "foo" match {
| case x if "".isEmpty => x
| }
| }: AnyRef
| }
|}
|""".stripMargin
assertEquals("foo", run(code))
}

@Test
def awaitTyped(): Unit = {
val code =
"""
|import scala.concurrent._, duration.Duration, ExecutionContext.Implicits.global
|import scala.tools.partest.async.Async.{async, await}
|import Future.{successful => f}
|
|object Test {
| def test = async {(("msg: " + await(f(0))): String).toString}
|}
|""".stripMargin
assertEquals("msg: 0", run(code))
}

@Test
def testBooleanAndOr(): Unit = {
val code =
Expand Down Expand Up @@ -646,8 +683,8 @@ class AnnotationDrivenAsync {

// settings.debug.value = true
// settings.uniqid.value = true
// settings.processArgumentString("-Xprint:async -nowarn")
// settings.log.value = List("async")
settings.processArgumentString("-Xprint:async -nowarn")
settings.log.value = List("async")

// NOTE: edit ANFTransform.traceAsync to `= true` to get additional diagnostic tracing.

Expand Down

0 comments on commit 9a7f84e

Please sign in to comment.