Skip to content

Commit

Permalink
Merge pull request #13348 from dotty-staging/dep-annots
Browse files Browse the repository at this point in the history
Remove anomalies and gaps in handling annotations
  • Loading branch information
odersky committed Sep 22, 2021
2 parents 321a92c + 00c9adb commit 1a84caa
Show file tree
Hide file tree
Showing 19 changed files with 231 additions and 42 deletions.
10 changes: 5 additions & 5 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,11 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
case _ => 0
}

/** The (last) list of arguments of an application */
def arguments(tree: Tree): List[Tree] = unsplice(tree) match {
case Apply(_, args) => args
case TypeApply(fn, _) => arguments(fn)
case Block(_, expr) => arguments(expr)
/** All term arguments of an application in a single flattened list */
def allArguments(tree: Tree): List[Tree] = unsplice(tree) match {
case Apply(fn, args) => allArguments(fn) ::: args
case TypeApply(fn, _) => allArguments(fn)
case Block(_, expr) => allArguments(expr)
case _ => Nil
}

Expand Down
48 changes: 44 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@ import StdNames._
import dotty.tools.dotc.ast.tpd
import scala.util.Try
import util.Spans.Span
import printing.{Showable, Printer}
import printing.Texts.Text
import annotation.internal.sharable

object Annotations {

def annotClass(tree: Tree)(using Context) =
if (tree.symbol.isConstructor) tree.symbol.owner
else tree.tpe.typeSymbol

abstract class Annotation {
abstract class Annotation extends Showable {
def tree(using Context): Tree

def symbol(using Context): Symbol = annotClass(tree)
Expand All @@ -26,7 +29,8 @@ object Annotations {
def derivedAnnotation(tree: Tree)(using Context): Annotation =
if (tree eq this.tree) this else Annotation(tree)

def arguments(using Context): List[Tree] = ast.tpd.arguments(tree)
/** All arguments to this annotation in a single flat list */
def arguments(using Context): List[Tree] = ast.tpd.allArguments(tree)

def argument(i: Int)(using Context): Option[Tree] = {
val args = arguments
Expand All @@ -44,15 +48,48 @@ object Annotations {
/** The tree evaluation has finished. */
def isEvaluated: Boolean = true

/** Normally, type map over all tree nodes of this annotation, but can
* be overridden. Returns EmptyAnnotation if type type map produces a range
* type, since ranges cannot be types of trees.
*/
def mapWith(tm: TypeMap)(using Context) =
val args = arguments
if args.isEmpty then this
else
val findDiff = new TreeAccumulator[Type]:
def apply(x: Type, tree: Tree)(using Context): Type =
if tm.isRange(x) then x
else
val tp1 = tm(tree.tpe)
foldOver(if tp1 =:= tree.tpe then x else tp1, tree)
val diff = findDiff(NoType, args)
if tm.isRange(diff) then EmptyAnnotation
else if diff.exists then derivedAnnotation(tm.mapOver(tree))
else this

/** Does this annotation refer to a parameter of `tl`? */
def refersToParamOf(tl: TermLambda)(using Context): Boolean =
val args = arguments
if args.isEmpty then false
else tree.existsSubTree {
case id: Ident => id.tpe match
case TermParamRef(tl1, _) => tl eq tl1
case _ => false
case _ => false
}

/** A string representation of the annotation. Overridden in BodyAnnotation.
*/
def toText(printer: Printer): Text = printer.annotText(this)

def ensureCompleted(using Context): Unit = tree

def sameAnnotation(that: Annotation)(using Context): Boolean =
symbol == that.symbol && tree.sameTree(that.tree)
}

case class ConcreteAnnotation(t: Tree) extends Annotation {
case class ConcreteAnnotation(t: Tree) extends Annotation:
def tree(using Context): Tree = t
}

abstract class LazyAnnotation extends Annotation {
protected var mySym: Symbol | (Context ?=> Symbol)
Expand Down Expand Up @@ -98,6 +135,7 @@ object Annotations {
if (tree eq this.tree) this else ConcreteBodyAnnotation(tree)
override def arguments(using Context): List[Tree] = Nil
override def ensureCompleted(using Context): Unit = ()
override def toText(printer: Printer): Text = "@Body"
}

class ConcreteBodyAnnotation(body: Tree) extends BodyAnnotation {
Expand Down Expand Up @@ -194,6 +232,8 @@ object Annotations {
apply(defn.SourceFileAnnot, Literal(Constant(path)))
}

@sharable val EmptyAnnotation = Annotation(EmptyTree)

def ThrowsAnnotation(cls: ClassSymbol)(using Context): Annotation = {
val tref = cls.typeRef
Annotation(defn.ThrowsAnnot.typeRef.appliedTo(tref), Ident(tref))
Expand Down
10 changes: 7 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,13 @@ object TypeOps:
// with Nulls (which have no base classes). Under -Yexplicit-nulls, we take
// corrective steps, so no widening is wanted.
simplify(l, theMap) | simplify(r, theMap)
case AnnotatedType(parent, annot)
if annot.symbol == defn.UncheckedVarianceAnnot && !ctx.mode.is(Mode.Type) && !theMap.isInstanceOf[SimplifyKeepUnchecked] =>
simplify(parent, theMap)
case tp @ AnnotatedType(parent, annot) =>
val parent1 = simplify(parent, theMap)
if annot.symbol == defn.UncheckedVarianceAnnot
&& !ctx.mode.is(Mode.Type)
&& !theMap.isInstanceOf[SimplifyKeepUnchecked]
then parent1
else tp.derivedAnnotatedType(parent1, annot)
case _: MatchType =>
val normed = tp.tryNormalize
if (normed.exists) normed else mapOver
Expand Down
19 changes: 11 additions & 8 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3601,6 +3601,9 @@ object Types {
case tp: AppliedType => tp.fold(status, compute(_, _, theAcc))
case tp: TypeVar if !tp.isInstantiated => combine(status, Provisional)
case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps
case AnnotatedType(parent, ann) =>
if ann.refersToParamOf(thisLambdaType) then TrueDeps
else compute(status, parent, theAcc)
case _: ThisType | _: BoundType | NoPrefix => status
case _ =>
(if theAcc != null then theAcc else DepAcc()).foldOver(status, tp)
Expand Down Expand Up @@ -3653,8 +3656,10 @@ object Types {
if (isResultDependent) {
val dropDependencies = new ApproximatingTypeMap {
def apply(tp: Type) = tp match {
case tp @ TermParamRef(thisLambdaType, _) =>
case tp @ TermParamRef(`thisLambdaType`, _) =>
range(defn.NothingType, atVariance(1)(apply(tp.underlying)))
case AnnotatedType(parent, ann) if ann.refersToParamOf(thisLambdaType) =>
mapOver(parent)
case _ => mapOver(tp)
}
}
Expand Down Expand Up @@ -5379,6 +5384,8 @@ object Types {
variance = saved
derivedLambdaType(tp)(ptypes1, this(restpe))

def isRange(tp: Type): Boolean = tp.isInstanceOf[Range]

/** Map this function over given type */
def mapOver(tp: Type): Type = {
record(s"TypeMap mapOver ${getClass}")
Expand Down Expand Up @@ -5422,8 +5429,9 @@ object Types {

case tp @ AnnotatedType(underlying, annot) =>
val underlying1 = this(underlying)
if (underlying1 eq underlying) tp
else derivedAnnotatedType(tp, underlying1, mapOver(annot))
val annot1 = annot.mapWith(this)
if annot1 eq EmptyAnnotation then underlying1
else derivedAnnotatedType(tp, underlying1, annot1)

case _: ThisType
| _: BoundType
Expand Down Expand Up @@ -5495,9 +5503,6 @@ object Types {
else newScopeWith(elems1: _*)
}

def mapOver(annot: Annotation): Annotation =
annot.derivedAnnotation(mapOver(annot.tree))

def mapOver(tree: Tree): Tree = treeTypeMap(tree)

/** Can be overridden. By default, only the prefix is mapped. */
Expand Down Expand Up @@ -5544,8 +5549,6 @@ object Types {

protected def emptyRange = range(defn.NothingType, defn.AnyType)

protected def isRange(tp: Type): Boolean = tp.isInstanceOf[Range]

protected def lower(tp: Type): Type = tp match {
case tp: Range => tp.lo
case _ => tp
Expand Down
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,10 @@ class PlainPrinter(_ctx: Context) extends Printer {
case _ => literalText(String.valueOf(const.value))
}

def toText(annot: Annotation): Text = s"@${annot.symbol.name}" // for now
/** Usual target for `Annotation#toText`, overridden in RefinedPrinter */
def annotText(annot: Annotation): Text = s"@${annot.symbol.name}"

def toText(annot: Annotation): Text = annot.toText(this)

def toText(param: LambdaParam): Text =
varianceSign(param.paramVariance)
Expand Down Expand Up @@ -574,7 +577,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
Text()

nodeName ~ "(" ~ elems ~ tpSuffix ~ ")" ~ (Str(tree.sourcePos.toString) provided printDebug)
}.close // todo: override in refined printer
}.close

def toText(pos: SourcePosition): Text =
if (!pos.exists) "<no position>"
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/printing/Printer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ abstract class Printer {
/** A description of sym's location */
def extendedLocationText(sym: Symbol): Text

/** Textual description of regular annotation in terms of its tree */
def annotText(annot: Annotation): Text

/** Textual representation of denotation */
def toText(denot: Denotation): Text

Expand Down
36 changes: 24 additions & 12 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import typer.ProtoTypes._
import Trees._
import TypeApplications._
import Decorators._
import NameKinds.WildcardParamName
import NameKinds.{WildcardParamName, DefaultGetterName}
import util.Chars.isOperatorPart
import transform.TypeUtils._
import transform.SymUtils._
Expand Down Expand Up @@ -607,7 +607,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
case tree: Template =>
toTextTemplate(tree)
case Annotated(arg, annot) =>
toTextLocal(arg) ~~ annotText(annot)
toTextLocal(arg) ~~ annotText(annot.symbol.enclosingClass, annot)
case EmptyTree =>
"<empty>"
case TypedSplice(t) =>
Expand Down Expand Up @@ -964,14 +964,22 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
keywordStr("package ") ~ toTextPackageId(tree.pid) ~ bodyText
}

/** Textual representation of an instance creation expression without the leading `new` */
protected def constrText(tree: untpd.Tree): Text = toTextLocal(tree).stripPrefix(keywordStr("new ")) // DD

protected def annotText(tree: untpd.Tree): Text = "@" ~ constrText(tree) // DD

override def annotsText(sym: Symbol): Text =
Text(sym.annotations.map(ann =>
if ann.symbol == defn.BodyAnnot then Str(simpleNameString(ann.symbol))
else annotText(ann.tree)))
protected def annotText(sym: Symbol, tree: untpd.Tree): Text =
def recur(t: untpd.Tree): Text = t match
case Apply(fn, Nil) => recur(fn)
case Apply(fn, args) =>
val explicitArgs = args.filterNot(_.symbol.name.is(DefaultGetterName))
recur(fn) ~ "(" ~ toTextGlobal(explicitArgs, ", ") ~ ")"
case TypeApply(fn, args) => recur(fn) ~ "[" ~ toTextGlobal(args, ", ") ~ "]"
case Select(qual, nme.CONSTRUCTOR) => recur(qual)
case New(tpt) => recur(tpt)
case _ =>
val annotSym = sym.orElse(tree.symbol.enclosingClass)
s"@${if annotSym.exists then annotSym.name.toString else t.show}"
recur(tree)

protected def modText(mods: untpd.Modifiers, sym: Symbol, kw: String, isType: Boolean): Text = { // DD
val suppressKw = if (enclDefIsClass) mods.isAllOf(LocalParam) else mods.is(Param)
Expand All @@ -984,12 +992,16 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
if (rawFlags.is(Param)) flagMask = flagMask &~ Given &~ Erased
val flags = rawFlags & flagMask
var flagsText = toTextFlags(sym, flags)
val annotations =
if (sym.exists) sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree)
else mods.annotations.filterNot(tree => dropAnnotForModText(tree.symbol))
Text(annotations.map(annotText), " ") ~~ flagsText ~~ (Str(kw) provided !suppressKw)
val annotTexts =
if sym.exists then
sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(toText)
else
mods.annotations.filterNot(tree => dropAnnotForModText(tree.symbol)).map(annotText(NoSymbol, _))
Text(annotTexts, " ") ~~ flagsText ~~ (Str(kw) provided !suppressKw)
}

override def annotText(annot: Annotation): Text = annotText(annot.symbol, annot.tree)

def optText(name: Name)(encl: Text => Text): Text =
if (name.isEmpty) "" else encl(toText(name))

Expand Down
23 changes: 15 additions & 8 deletions compiler/test/dotty/tools/dotc/printing/PrintingTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@ import scala.io.Source
import org.junit.Test

class PrintingTest {
val testsDir = "tests/printing"
val options = List("-Xprint:typer", "-color:never", "-classpath", TestConfiguration.basicClasspath)

private def compileFile(path: JPath): Boolean = {
def options(phase: String) =
List(s"-Xprint:$phase", "-color:never", "-classpath", TestConfiguration.basicClasspath)

private def compileFile(path: JPath, phase: String): Boolean = {
val baseFilePath = path.toString.stripSuffix(".scala")
val checkFilePath = baseFilePath + ".check"
val byteStream = new ByteArrayOutputStream()
val reporter = TestReporter.reporter(new PrintStream(byteStream), INFO)

try {
Main.process((path.toString::options).toArray, reporter, null)
Main.process((path.toString::options(phase)).toArray, reporter, null)
} catch {
case e: Throwable =>
println(s"Compile $path exception:")
Expand All @@ -40,11 +41,10 @@ class PrintingTest {
FileDiff.checkAndDump(path.toString, actualLines.toIndexedSeq, checkFilePath)
}

@Test
def printing: Unit = {
def testIn(testsDir: String, phase: String) =
val res = Directory(testsDir).list.toList
.filter(f => f.extension == "scala")
.map { f => compileFile(f.jpath) }
.map { f => compileFile(f.jpath, phase) }

val failed = res.filter(!_)

Expand All @@ -53,5 +53,12 @@ class PrintingTest {
assert(failed.length == 0, msg)

println(msg)
}

end testIn

@Test
def printing: Unit = testIn("tests/printing", "typer")

@Test
def untypedPrinting: Unit = testIn("tests/printing/untyped", "parser")
}
7 changes: 7 additions & 0 deletions tests/neg/annot-printing.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- [E007] Type Mismatch Error: tests/neg/annot-printing.scala:5:46 -----------------------------------------------------
5 |def x: Int @nowarn @main @Foo @Bar("hello") = "abc" // error
| ^^^^^
| Found: ("abc" : String)
| Required: Int @nowarn() @main @Foo @Bar("hello")

longer explanation available when compiling with `-explain`
6 changes: 6 additions & 0 deletions tests/neg/annot-printing.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import scala.annotation.*
class Foo() extends Annotation
class Bar(s: String) extends Annotation

def x: Int @nowarn @main @Foo @Bar("hello") = "abc" // error

7 changes: 7 additions & 0 deletions tests/pos/dependent-annot.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class C
class ann(x: Any*) extends annotation.Annotation

def f(y: C, z: C) =
def g(): C @ann(y, z) = ???
val ac: ((x: C) => Array[String @ann(x)]) = ???
val dc = ac(g())
22 changes: 22 additions & 0 deletions tests/printing/annot-printing.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[[syntax trees at end of typer]] // tests/printing/annot-printing.scala
package <empty> {
import scala.annotation.*
class Foo() extends annotation.Annotation() {}
class Bar(s: String) extends annotation.Annotation() {
private[this] val s: String
}
class Xyz(i: Int) extends annotation.Annotation() {
private[this] val i: Int
}
final lazy module val Xyz: Xyz = new Xyz()
final module class Xyz() extends AnyRef() { this: Xyz.type =>
def $lessinit$greater$default$1: Int @uncheckedVariance = 23
}
final lazy module val annot-printing$package: annot-printing$package =
new annot-printing$package()
final module class annot-printing$package() extends Object() {
this: annot-printing$package.type =>
def x: Int @nowarn() @main @Xyz() @Foo @Bar("hello") = ???
}
}

7 changes: 7 additions & 0 deletions tests/printing/annot-printing.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import scala.annotation.*

class Foo() extends Annotation
class Bar(s: String) extends Annotation
class Xyz(i: Int = 23) extends Annotation

def x: Int @nowarn @main @Xyz() @Foo @Bar("hello") = ???

0 comments on commit 1a84caa

Please sign in to comment.