Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plug many variance holes (in higher-kinded types, refined types and private inner classes) #8545

Merged
merged 4 commits into from Mar 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 4 additions & 8 deletions src/compiler/scala/tools/nsc/typechecker/RefChecks.scala
Expand Up @@ -893,9 +893,9 @@ abstract class RefChecks extends Transform {
case ClassInfoType(parents, _, clazz) => "supertype "+intersectionType(parents, clazz.owner)
case _ => "type "+tp
}
override def issueVarianceError(base: Symbol, sym: Symbol, required: Variance): Unit = {
override def issueVarianceError(base: Symbol, sym: Symbol, required: Variance, tpe: Type): Unit = {
reporter.error(base.pos,
s"${sym.variance} $sym occurs in $required position in ${tpString(base.info)} of $base")
s"${sym.variance} $sym occurs in $required position in ${tpString(tpe)} of $base")
}
}

Expand Down Expand Up @@ -1564,8 +1564,6 @@ abstract class RefChecks extends Transform {

if (!sym.exists)
devWarning("Select node has NoSymbol! " + tree + " / " + tree.tpe)
else if (sym.isLocalToThis)
varianceValidator.checkForEscape(sym, currentClass)

def checkSuper(mix: Name) =
// term should have been eliminated by super accessors
Expand Down Expand Up @@ -1772,13 +1770,11 @@ abstract class RefChecks extends Transform {
result.transform(this)
}
result1 match {
case ClassDef(_, _, _, _)
| TypeDef(_, _, _, _)
| ModuleDef(_, _, _) =>
case ClassDef(_, _, _, _) | TypeDef(_, _, _, _) | ModuleDef(_, _, _) =>
if (result1.symbol.isLocalToBlock || result1.symbol.isTopLevel)
varianceValidator.traverse(result1)
case tt @ TypeTree() if tt.original != null =>
varianceValidator.traverse(tt.original) // See scala/bug#7872
joroKr21 marked this conversation as resolved.
Show resolved Hide resolved
varianceValidator.validateVarianceOfPolyTypesIn(tt.tpe)
case _ =>
}

Expand Down
160 changes: 112 additions & 48 deletions src/reflect/scala/reflect/internal/Variances.scala
Expand Up @@ -28,7 +28,6 @@ trait Variances {
* TODO - eliminate duplication with varianceInType
*/
class VarianceValidator extends InternalTraverser {
private[this] val escapedLocals = mutable.HashSet[Symbol]()
// A flag for when we're in a refinement, meaning method parameter types
// need to be checked.
private[this] var inRefinement = false
Expand All @@ -38,34 +37,26 @@ trait Variances {
try body finally inRefinement = saved
}

/** Is every symbol in the owner chain between `site` and the owner of `sym`
* either a term symbol or private[this]? If not, add `sym` to the set of
* escaped locals.
* @pre sym.isLocalToThis
*/
@tailrec final def checkForEscape(sym: Symbol, site: Symbol): Unit = {
if (site == sym.owner || site == sym.owner.moduleClass || site.hasPackageFlag) () // done
else if (site.isTerm || site.isPrivateLocal) checkForEscape(sym, site.owner) // ok - recurse to owner
else escapedLocals += sym
}

protected def issueVarianceError(base: Symbol, sym: Symbol, required: Variance): Unit = ()
protected def issueVarianceError(base: Symbol, sym: Symbol, required: Variance, tpe: Type): Unit = ()

// Flip occurrences of type parameters and parameters, unless
// - it's a constructor, or case class factory or extractor
// - it's a type parameter / parameter of a local definition
// - it's a type parameter of tvar's owner.
def shouldFlip(sym: Symbol, tvar: Symbol) = (
sym.isParameter
&& !sym.owner.isLocalToThis
&& !(tvar.isTypeParameterOrSkolem && sym.isTypeParameterOrSkolem && tvar.owner == sym.owner)
)

// Is `sym` is local to a term or is private[this] or protected[this]?
def isExemptFromVariance(sym: Symbol): Boolean = !sym.owner.isClass || (
(sym.isLocalToThis || sym.isSuperAccessor) // super accessors are implicitly local #4345
&& !escapedLocals(sym)
)
def isExemptFromVariance(sym: Symbol): Boolean =
// super accessors are implicitly local #4345
!sym.owner.isClass || sym.isLocalToThis || sym.isSuperAccessor

private object ValidateVarianceMap extends VariancedTypeMap {
private[this] var base: Symbol = _
private[this] var inLowerBoundOf: Symbol = _

/** The variance of a symbol occurrence of `tvar` seen at the level of the definition of `base`.
* The search proceeds from `base` to the owner of `tvar`.
Expand All @@ -85,12 +76,16 @@ trait Variances {
else v

@tailrec
def loop(sym: Symbol, v: Variance): Variance = (
if (sym == tvar.owner || v.isBivariant) v
def loop(sym: Symbol, v: Variance): Variance =
if (v.isBivariant) v
else if (sym == tvar.owner)
// We can't move this to `shouldFlip`, because it's needed only once at the end.
if (inLowerBoundOf == sym) v.flip else v
else loop(sym.owner, nextVariance(sym, v))
)

loop(base, Covariant)
}

def isUncheckedVariance(tp: Type) = tp match {
case AnnotatedType(annots, _) => annots exists (_ matches definitions.uncheckedVarianceClass)
case _ => false
Expand All @@ -104,17 +99,19 @@ trait Variances {
def base_s = s"$base in ${base.owner}" + (if (base.owner.isClass) "" else " in " + base.owner.enclClass)
log(s"verifying $sym_s is $required at $base_s")
if (sym.variance != required)
issueVarianceError(base, sym, required)
issueVarianceError(base, sym, required, base.info)
}
}

override def mapOver(decls: Scope): Scope = {
decls foreach (sym => withVariance(if (sym.isAliasType) Invariant else variance)(this(sym.info)))
decls
}

private def resultTypeOnly(tp: Type) = tp match {
case mt: MethodType => !inRefinement
case pt: PolyType => true
case _ => false
case _: MethodType => !inRefinement
case _: PolyType => true
case _ => false
}

/** For PolyTypes, type parameters are skipped because they are defined
Expand All @@ -125,75 +122,141 @@ trait Variances {
def apply(tp: Type): Type = {
tp match {
case _ if isUncheckedVariance(tp) =>
case _ if resultTypeOnly(tp) => this(tp.resultType)
case TypeRef(_, sym, _) if shouldDealias(sym) => this(tp.normalize)
case TypeRef(_, sym, _) if !sym.variance.isInvariant => checkVarianceOfSymbol(sym) ; tp.mapOver(this)
case _ if resultTypeOnly(tp) => apply(tp.resultType)
case TypeRef(_, sym, _) if shouldDealias(sym) => apply(tp.normalize)
case TypeRef(_, sym, _) if !sym.variance.isInvariant => checkVarianceOfSymbol(sym); tp.mapOver(this)
case RefinedType(_, _) => withinRefinement(tp.mapOver(this))
case ClassInfoType(parents, _, _) => parents foreach this
case mt @ MethodType(_, result) => flipped(mt.paramTypes foreach this) ; this(result)
case ClassInfoType(parents, _, _) => parents.foreach(apply)
case mt @ MethodType(_, result) => flipped(mt.paramTypes.foreach(apply)); apply(result)
case _ => tp.mapOver(this)
}
// We're using TypeMap here for type traversal only. To avoid wasteful symbol
// cloning during the recursion, it is important to return the input `tp`, rather
// than the result of the pattern match above, which normalizes types.
tp
}

private def shouldDealias(sym: Symbol): Boolean = {
// The RHS of (private|protected)[this] type aliases are excluded from variance checks. This is
// implemented in relativeVariance.
// As such, we need to expand references to them to retain soundness. Example: neg/t8079a.scala
sym.isAliasType && isExemptFromVariance(sym)
}

/** Validate the variance of types in the definition of `base`.
*
* Traverse the type signature of `base` and for each type parameter:
* - Calculate the relative variance between `base` and the type parameter's owner by
* walking the owner chain of `base`.
* - Calculate the required variance of the type parameter which is the product of the
* relative variance and the current variance in the type signature of `base`.
* - Ensure that the declared variance of the type parameter is compatible with the
* required variance, otherwise issue an error.
*
* Lower bounds need special handling. By default the variance is flipped when entering a
* lower bound. In most cases this is the correct behaviour except for the type parameters
* of higher-kinded types. E.g. in `Foo` below `x` occurs in covariant position:
* `class Foo[F[+_]] { type G[+x] >: F[x] }`
*
* To handle this special case, track when entering the lower bound of a HKT in a variable
* and flip the relative variance for its type parameters. (flipping the variance a second
* time negates the first flip).
*/
def validateDefinition(base: Symbol): Unit = {
val saved = this.base
this.base = base
try apply(base.info)
finally this.base = saved
base.info match {
case PolyType(_, TypeBounds(lo, hi)) =>
inLowerBoundOf = base
try flipped(apply(lo))
finally inLowerBoundOf = null
apply(hi)
case other =>
apply(other)
}
}
}

/** Validate variance of info of symbol `base` */
private def validateVariance(base: Symbol): Unit = {
ValidateVarianceMap validateDefinition base
private object PolyTypeVarianceMap extends TypeMap {

private def ownerOf(pt: PolyType): Symbol =
pt.typeParams.head.owner

private def checkPolyTypeParam(pt: PolyType, tparam: Symbol, tpe: Type): Unit =
if (!tparam.isInvariant) {
val required = varianceInType(tpe)(tparam)
if (!required.isBivariant && tparam.variance != required)
issueVarianceError(ownerOf(pt), tparam, required, pt)
}

def apply(tp: Type): Type = {
tp match {
case pt @ PolyType(typeParams, TypeBounds(lo, hi)) =>
typeParams.foreach { tparam =>
checkPolyTypeParam(pt, tparam, lo)
checkPolyTypeParam(pt, tparam, hi)
}

pt.mapOver(this)

case pt @ PolyType(typeParams, resultType) =>
typeParams.foreach(checkPolyTypeParam(pt, _, resultType))
pt.mapOver(this)

case _ =>
tp.mapOver(this)
}

tp
}
}

/** Validate the variance of (the type parameters of) PolyTypes in `tpe`.
*
* `validateDefinition` cannot handle PolyTypes in arbitrary position, because in general
* the relative variance of such types cannot be computed by walking the owner chain.
*
* Instead this method applies a naive algorithm which is correct but less efficient:
* use `varianceInType` to check each type parameter of a PolyType separately.
*/
def validateVarianceOfPolyTypesIn(tpe: Type): Unit =
PolyTypeVarianceMap(tpe)

override def traverse(tree: Tree): Unit = {
def sym = tree.symbol
// No variance check for object-private/protected methods/values.
// Or constructors, or case class factory or extractor.
def skip = (
sym == NoSymbol
|| sym.isLocalToThis
|| sym.owner.isConstructor
|| sym.owner.isCaseApplyOrUnapply
|| sym.owner.isConstructor // FIXME: this is unsound - scala/bug#8737
|| sym.owner.isCaseApplyOrUnapply // same treatment as constructors
|| sym.isParamAccessor && sym.isLocalToThis // local class parameters are construction only
)

tree match {
case defn: MemberDef if skip =>
case _: MemberDef if skip =>
debuglog(s"Skipping variance check of ${sym.defString}")
case ClassDef(_, _, _, _) | TypeDef(_, _, _, _) =>
validateVariance(sym)
ValidateVarianceMap.validateDefinition(sym)
tree.traverse(this)
case ModuleDef(_, _, _) =>
validateVariance(sym.moduleClass)
ValidateVarianceMap.validateDefinition(sym.moduleClass)
tree.traverse(this)
case ValDef(_, _, _, _) =>
validateVariance(sym)
ValidateVarianceMap.validateDefinition(sym)
case DefDef(_, _, tparams, vparamss, _, _) =>
validateVariance(sym)
ValidateVarianceMap.validateDefinition(sym)
traverseTrees(tparams)
traverseTreess(vparamss)
case Template(_, _, _) =>
tree.traverse(this)
case CompoundTypeTree(templ) =>
case CompoundTypeTree(_) =>
tree.traverse(this)

// scala/bug#7872 These two cases make sure we don't miss variance exploits
// in originals, e.g. in `foo[({type l[+a] = List[a]})#l]`
case tt @ TypeTree() if tt.original != null =>
tt.original.traverse(this)
case tt : TypTree =>
case tt: TypTree =>
tt.traverse(this)

case _ =>
}
}
Expand Down Expand Up @@ -230,7 +293,7 @@ trait Variances {
case ThisType(_) | ConstantType(_) => Bivariant
case TypeRef(_, tparam, _) if tparam eq this.tparam => Covariant
case NullaryMethodType(restpe) => inType(restpe)
case SingleType(pre, sym) => inType(pre)
case SingleType(pre, _) => inType(pre)
case TypeRef(pre, _, _) if tp.isHigherKinded => inType(pre) // a type constructor cannot occur in tp's args
case TypeRef(pre, sym, args) => inType(pre) & inArgs(sym, args)
case TypeBounds(lo, hi) => inType(lo).flip & inType(hi)
Expand All @@ -239,6 +302,7 @@ trait Variances {
case PolyType(tparams, restpe) => inSyms(tparams).flip & inType(restpe)
case ExistentialType(tparams, restpe) => inSyms(tparams) & inType(restpe)
case AnnotatedType(annots, tp) => inAnnots(annots) & inType(tp)
case SuperType(thistpe, supertpe) => inType(thistpe) & inType(supertpe)
}

def apply(tp: Type, tparam: Symbol): Variance = {
Expand Down
2 changes: 1 addition & 1 deletion test/files/neg/t7872.check
@@ -1,7 +1,7 @@
t7872.scala:6: error: contravariant type a occurs in covariant position in type [-a]Cov[a] of type l
type x = {type l[-a] = Cov[a]}
^
t7872.scala:8: error: covariant type a occurs in contravariant position in type [+a]Inv[a] of type l
t7872.scala:8: error: covariant type a occurs in contravariant position in type [+a]Inv[a] of value <local l>
foo[({type l[+a] = Inv[a]})#l]
^
t7872.scala:5: error: contravariant type a occurs in covariant position in type [-a]Cov[a] of type l
Expand Down
4 changes: 2 additions & 2 deletions test/files/neg/t7872b.check
@@ -1,7 +1,7 @@
t7872b.scala:8: error: contravariant type a occurs in covariant position in type [-a]List[a] of type l
t7872b.scala:8: error: contravariant type a occurs in covariant position in type [-a]List[a] of value <local l>
def oops1 = down[({type l[-a] = List[a]})#l](List('whatever: Object)).head + "oops"
^
t7872b.scala:19: error: covariant type a occurs in contravariant position in type [+a]coinv.Stringer[a] of type l
t7872b.scala:19: error: covariant type a occurs in contravariant position in type [+a]a => String of value <local l>
def oops2 = up[({type l[+a] = Stringer[a]})#l]("printed: " + _)
^
2 errors
4 changes: 4 additions & 0 deletions test/files/neg/t9911.check
@@ -0,0 +1,4 @@
t9911.scala:23: error: super may not be used on value source
super.source.getSomething
^
1 error
28 changes: 28 additions & 0 deletions test/files/neg/t9911.scala
@@ -0,0 +1,28 @@
// This should say:
// Error: super may not be used on value source
class ScalacBug {

class SomeClass {

type U

// Changing T or U stops the problem
def getSomething[T]: U = ???
}

trait Base {

// Changing this to a def like it should be stops the problem
val source: SomeClass = ???
}

class Bug extends Base {

override val source = {
// Not calling the function stops the problem
super.source.getSomething
???
}
}

}
22 changes: 22 additions & 0 deletions test/files/neg/variance-holes.check
@@ -0,0 +1,22 @@
variance-holes.scala:9: error: covariant type x occurs in contravariant position in type [+x, +y] >: F[x,y] of type F2
def asWiden[F2[+x, +y] >: F[x, y]]: F2[Int, Int] = v
^
variance-holes.scala:2: error: contravariant type A occurs in covariant position in type [-A] >: List[A] of type Lower1
type Lower1[-A] >: List[A]
^
variance-holes.scala:5: error: covariant type x occurs in contravariant position in type [+x] >: F[x] of type G
type G[+x] >: F[x]
^
variance-holes.scala:13: error: covariant type A occurs in contravariant position in type AnyRef{type T >: A} of method foo
def foo: { type T >: A }
^
variance-holes.scala:17: error: covariant type A occurs in contravariant position in type AnyRef{type T <: A} of value x
def foo(x: { type T <: A }): Unit
^
variance-holes.scala:20: error: covariant type A occurs in contravariant position in type <: AnyRef{type T >: A} of type x
class RefinedLower[+A, x <: { type T >: A }]
^
variance-holes.scala:21: error: covariant type A occurs in contravariant position in type A of value x_=
private[this] class PrivateThis[+A](var x: A)
^
7 errors