Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add -Xmacro-check for Block constructors #13824

Merged
merged 1 commit into from Nov 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 24 additions & 2 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Expand Up @@ -753,9 +753,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object Block extends BlockModule:
def apply(stats: List[Statement], expr: Term): Block =
withDefaultPos(tpd.Block(stats, expr))
xCheckMacroBlockOwners(withDefaultPos(tpd.Block(stats, expr)))
def copy(original: Tree)(stats: List[Statement], expr: Term): Block =
tpd.cpy.Block(original)(stats, expr)
xCheckMacroBlockOwners(tpd.cpy.Block(original)(stats, expr))
def unapply(x: Block): (List[Statement], Term) =
(x.statements, x.expr)
end Block
Expand Down Expand Up @@ -2913,6 +2913,28 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
case _ => traverseChildren(t)
}.traverse(tree)

/** Checks that all definitions in this block have the same owner.
* Nested definitions are ignored and assumed to be correct by construction.
*/
private def xCheckMacroBlockOwners(tree: Tree): tree.type =
if xCheckMacro then
val defs = new tpd.TreeAccumulator[List[Tree]] {
def apply(defs: List[Tree], tree: Tree)(using Context): List[Tree] =
tree match
case tree: tpd.DefTree => tree :: defs
case _ => foldOver(defs, tree)
}.apply(Nil, tree)
val defOwners = defs.groupBy(_.symbol.owner)
assert(defOwners.size <= 1,
s"""Block contains definition with different owners.
|Found definitions ${defOwners.size} distinct owners: ${defOwners.keys.mkString(", ")}
|
|Block: ${Printer.TreeCode.show(tree)}
|
|${defOwners.map((owner, trees) => s"Definitions owned by $owner: \n${trees.map(Printer.TreeCode.show).mkString("\n")}").mkString("\n\n")}
|""".stripMargin)
tree

private def xCheckMacroValidExprs(terms: List[Term]): terms.type =
if xCheckMacro then terms.foreach(xCheckMacroValidExpr)
terms
Expand Down
265 changes: 265 additions & 0 deletions tests/neg-macros/i13809/Macros_1.scala
@@ -0,0 +1,265 @@
package x

import scala.annotation._
import scala.quoted._

trait CB[+T]

object CBM:
def pure[T](t:T):CB[T] = ???
def map[A,B](fa:CB[A])(f: A=>B):CB[B] = ???
def flatMap[A,B](fa:CB[A])(f: A=>CB[B]):CB[B] = ???
def spawn[A](op: =>CB[A]): CB[A] = ???


@compileTimeOnly("await should be inside async block")
def await[T](f: CB[T]): T = ???


trait CpsExpr[T:Type](prev: Seq[Expr[?]]):

def fLast(using Quotes): Expr[CB[T]]
def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T]
def append[A:Type](chunk: CpsExpr[A])(using Quotes): CpsExpr[A]
def syncOrigin(using Quotes): Option[Expr[T]]
def map[A:Type](f: Expr[T => A])(using Quotes): CpsExpr[A] =
MappedCpsExpr[T,A](Seq(),this,f)
def flatMap[A:Type](f: Expr[T => CB[A]])(using Quotes): CpsExpr[A] =
FlatMappedCpsExpr[T,A](Seq(),this,f)

def transformed(using Quotes): Expr[CB[T]] =
import quotes.reflect._
Block(prev.toList.map(_.asTerm), fLast.asTerm).asExprOf[CB[T]]


case class GenericSyncCpsExpr[T:Type](prev: Seq[Expr[?]],last: Expr[T]) extends CpsExpr[T](prev):

override def fLast(using Quotes): Expr[CB[T]] =
'{ CBM.pure(${last}:T) }

override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] =
copy(prev = exprs ++: prev)

override def syncOrigin(using Quotes): Option[Expr[T]] =
import quotes.reflect._
Some(Block(prev.toList.map(_.asTerm), last.asTerm).asExprOf[T])

override def append[A:Type](e: CpsExpr[A])(using Quotes) =
e.prependExprs(Seq(last)).prependExprs(prev)

override def map[A:Type](f: Expr[T => A])(using Quotes): CpsExpr[A] =
copy(last = '{ $f($last) })

override def flatMap[A:Type](f: Expr[T => CB[A]])(using Quotes): CpsExpr[A] =
GenericAsyncCpsExpr[A](prev, '{ CBM.flatMap(CBM.pure($last))($f) } )


abstract class AsyncCpsExpr[T:Type](
prev: Seq[Expr[?]]
) extends CpsExpr[T](prev):

override def append[A:Type](e: CpsExpr[A])(using Quotes): CpsExpr[A] =
flatMap( '{ (x:T) => ${e.transformed} })

override def syncOrigin(using Quotes): Option[Expr[T]] = None



case class GenericAsyncCpsExpr[T:Type](
prev: Seq[Expr[?]],
fLastExpr: Expr[CB[T]]
) extends AsyncCpsExpr[T](prev):

override def fLast(using Quotes): Expr[CB[T]] = fLastExpr

override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] =
copy(prev = exprs ++: prev)

override def map[A:Type](f: Expr[T => A])(using Quotes): CpsExpr[A] =
MappedCpsExpr(Seq(),this,f)

override def flatMap[A:Type](f: Expr[T => CB[A]])(using Quotes): CpsExpr[A] =
FlatMappedCpsExpr(Seq(),this,f)



case class MappedCpsExpr[S:Type, T:Type](
prev: Seq[Expr[?]],
point: CpsExpr[S],
mapping: Expr[S=>T]
) extends AsyncCpsExpr[T](prev):

override def fLast(using Quotes): Expr[CB[T]] =
'{ CBM.map(${point.transformed})($mapping) }

override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] =
copy(prev = exprs ++: prev)



case class FlatMappedCpsExpr[S:Type, T:Type](
prev: Seq[Expr[?]],
point: CpsExpr[S],
mapping: Expr[S => CB[T]]
) extends AsyncCpsExpr[T](prev):

override def fLast(using Quotes): Expr[CB[T]] =
'{ CBM.flatMap(${point.transformed})($mapping) }

override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] =
copy(prev = exprs ++: prev)


class ValRhsFlatMappedCpsExpr[T:Type, V:Type](using thisQuotes: Quotes)
(
prev: Seq[Expr[?]],
oldValDef: quotes.reflect.ValDef,
cpsRhs: CpsExpr[V],
next: CpsExpr[T]
)
extends AsyncCpsExpr[T](prev) {

override def fLast(using Quotes):Expr[CB[T]] =
import quotes.reflect._
next.syncOrigin match
case Some(nextOrigin) =>
// owner of this block is incorrect
'{
CBM.map(${cpsRhs.transformed})((vx:V) =>
${buildAppendBlockExpr('vx, nextOrigin)})
}
case None =>
'{
CBM.flatMap(${cpsRhs.transformed})((v:V)=>
${buildAppendBlockExpr('v, next.transformed)})
}


override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] =
ValRhsFlatMappedCpsExpr(using thisQuotes)(exprs ++: prev,oldValDef,cpsRhs,next)

override def append[A:quoted.Type](e: CpsExpr[A])(using Quotes) =
ValRhsFlatMappedCpsExpr(using thisQuotes)(prev,oldValDef,cpsRhs,next.append(e))


private def buildAppendBlock(using Quotes)(rhs:quotes.reflect.Term,
exprTerm:quotes.reflect.Term): quotes.reflect.Term =
import quotes.reflect._
import scala.quoted.Expr

val castedOldValDef = oldValDef.asInstanceOf[quotes.reflect.ValDef]
val valDef = ValDef(castedOldValDef.symbol, Some(rhs.changeOwner(castedOldValDef.symbol)))
exprTerm match
case Block(stats,last) =>
Block(valDef::stats, last)
case other =>
Block(valDef::Nil,other)

private def buildAppendBlockExpr[A:Type](using Quotes)(rhs: Expr[V], expr:Expr[A]):Expr[A] =
import quotes.reflect._
buildAppendBlock(rhs.asTerm,expr.asTerm).asExprOf[A]

}


object CpsExpr:

def sync[T:Type](f: Expr[T]): CpsExpr[T] =
GenericSyncCpsExpr[T](Seq(), f)

def async[T:Type](f: Expr[CB[T]]): CpsExpr[T] =
GenericAsyncCpsExpr[T](Seq(), f)


object Async:

transparent inline def transform[T](inline expr: T) = ${
Async.transformImpl[T]('expr)
}

def transformImpl[T:Type](f: Expr[T])(using Quotes): Expr[CB[T]] =
import quotes.reflect._
// println(s"before transformed: ${f.show}")
val cpsExpr = rootTransform[T](f)
val r = '{ CBM.spawn(${cpsExpr.transformed}) }
// println(s"transformed value: ${r.show}")
r

def rootTransform[T:Type](f: Expr[T])(using Quotes): CpsExpr[T] = {
import quotes.reflect._
f match
case '{ while ($cond) { $repeat } } =>
val cpsRepeat = rootTransform(repeat.asExprOf[Unit])
CpsExpr.async('{
def _whilefun():CB[Unit] =
if ($cond) {
${cpsRepeat.flatMap('{(x:Unit) => _whilefun()}).transformed}
} else {
CBM.pure(())
}
_whilefun()
}.asExprOf[CB[T]])
case _ =>
val fTree = f.asTerm
fTree match {
case fun@Apply(fun1@TypeApply(obj2,targs2), args1) =>
if (obj2.symbol.name == "await") {
val awaitArg = args1.head
CpsExpr.async(awaitArg.asExprOf[CB[T]])
} else {
???
}
case Assign(left,right) =>
left match
case id@Ident(x) =>
right.tpe.widen.asType match
case '[r] =>
val cpsRight = rootTransform(right.asExprOf[r])
CpsExpr.async(
cpsRight.map[T](
'{ (x:r) => ${Assign(left,'x.asTerm).asExprOf[T] }
}).transformed )
case _ => ???
case Block(prevs,last) =>
val rPrevs = prevs.map{ p =>
p match
case v@ValDef(vName,vtt,optRhs) =>
optRhs.get.tpe.widen.asType match
case '[l] =>
val cpsRight = rootTransform(optRhs.get.asExprOf[l])
ValRhsFlatMappedCpsExpr(using quotes)(Seq(), v, cpsRight, CpsExpr.sync('{}))
case t: Term =>
// TODO: rootTransform
t.asExpr match
case '{ $p: tp } =>
rootTransform(p)
case other =>
printf(other.show)
throw RuntimeException(s"can't handle term in block: $other")
case other =>
printf(other.show)
throw RuntimeException(s"unknown tree type in block: $other")
}
val rLast = rootTransform(last.asExprOf[T])
val blockResult = rPrevs.foldRight(rLast)((e,s) => e.append(s))
val retval = CpsExpr.async(blockResult.transformed)
retval
//BlockTransform(cpsCtx).run(prevs,last)
case id@Ident(name) =>
CpsExpr.sync(id.asExprOf[T])
case tid@Typed(Ident(name), tp) =>
CpsExpr.sync(tid.asExprOf[T])
case matchTerm@Match(scrutinee, caseDefs) =>
val nCases = caseDefs.map{ old =>
CaseDef.copy(old)(old.pattern, old.guard, rootTransform(old.rhs.asExprOf[T]).transformed.asTerm)
}
CpsExpr.async(Match(scrutinee, nCases).asExprOf[CB[T]])
case inlinedTerm@ Inlined(call,List(),body) =>
rootTransform(body.asExprOf[T])
case constTerm@Literal(_)=>
CpsExpr.sync(constTerm.asExprOf[T])
case _ =>
throw RuntimeException(s"language construction is not supported: ${fTree}")
}
}

19 changes: 19 additions & 0 deletions tests/neg-macros/i13809/Test_2.scala
@@ -0,0 +1,19 @@
package x

object VP1:

///*
def allocateServiceOperator(optInUsername: Option[String]): CB[Unit] = Async.transform { // error
val username = optInUsername match
case None =>
while(false) {
val nextResult = await(op1())
val countResult = await(op1())
}
case Some(inUsername) =>
val x = await(op1())
inUsername
}
//*/

def op1(): CB[String] = ???
4 changes: 3 additions & 1 deletion tests/pos-macros/i10151/Macro_1.scala
Expand Up @@ -55,7 +55,9 @@ object X:
)
)
)
case Block(stats, last) => Block(stats, transform(last))
case Block(stats, last) =>
val recoverdOwner = stats.headOption.map(_.symbol.owner).getOrElse(Symbol.spliceOwner) // hacky workaround to missing owner tracking in transform
Block(stats, transform(last).changeOwner(recoverdOwner))
case Inlined(x,List(),body) => transform(body)
case l@Literal(x) =>
l.asExpr match
Expand Down